导入
为了快速应用大模型,需要采购商业大模型。采购前,对接销售时,我们描述了场景和需求:
Q:我们的 prompts token 数量在 1500-2000 之间,completion token 数量在 500 左右。这种情况下,prefilling 多久?每个 token 输出是多久?
销售回复:标准 3500 token Input,首包吐出来小于 1 秒;吞吐量(throughput)300token/s
你是否看出来答非所问了?
问的和延迟(latency)相关,回答的是吞吐量(throughput)。那这两个词有什么区别?导致连这个领域的销售都会搞混?
在搞懂 latency 和 throughput 之前,我们先来看看 KV cache 和 prefilling
KV Cache
KV Cache 采用以空间换时间的思想,复用上次推理的 KV 缓存,可以极大降低内存压力、提高推理性能,而且不会影响任何计算精度
decoder 架构里面最主要的就是 transformer 中的 self-attention 结构的堆叠,KV-cache 的实质是用之前计算过的 key-value 以及当前的 query 来生成下一个 token
prefill 指的是生成第一个 token 的时候,kv 是没有任何缓存的,需要预填充 prompt 对应的 KV 矩阵做缓存,所以第一个 token 生成的最慢,而从第二个 token 开始,都会快速获取缓存,并将前一个 token 的 kv 也缓存
可以看到,这是一个空间换时间的方案,缓存会不断变大,所以在私有化部署计算显存的时候,除了模型大小,还要要看你的应用中 prompt 和 completion 的大小(当然还有 batch-size)
Prefilling & Decoding
Latency VS. Throughput
- Latency:延迟,指的是从输入到输出的时间,即从输入到输出最后一个 token 的时间
- Throughput:吞吐量,指的是单位时间内处理的任务数,即每秒处理的 token 数
下面给出 latency 和 throughput 的计算方法:
# constants
max_tokens = 10
# observations
durations = []
throughputs = []
latencies = []
batch_sizes = [2**p for p in range(8)]
for batch_size in batch_sizes:
print(f"bs= {batch_size}")
# generate tokens for batch and record duration
t0 = time.time()
batch_prompts = [
prompts[i % len(prompts)] for i in range(batch_size)
]
inputs = tokenizer(
batch_prompts, padding=True, return_tensors="pt"
)
generated_tokens = generate_batch(inputs, max_tokens=max_tokens)
duration_s = time.time() - t0
ntokens = batch_size * max_tokens
throughput = ntokens / duration_s
avg_latency = duration_s / max_tokens
print("duration", duration_s)
print("throughput", throughput)
print("avg latency", avg_latency)
print()
durations.append(duration_s)
throughputs.append(throughput)
latencies.append(avg_latency)
Navie batching
Navie batching 是指将多个输入合并成一个 batch,然后一次性输入模型,这样可以减少模型的前向传播次数,提高效率
有的人也称其为 synchronous batching 或者 static batching,区别于后面的 continuous batching
Navie batching 的缺点是,如果一个 batch 中有一个输入很大,那么整个 batch 的计算时间就会被拉长,这样会导致整个 batch 的计算时间变长
Continuous batching
在传统的批处理方法中,一批请求必须全部完成处理后才能一起返回结果。这就意味着较短请求需要等待较长请求处理完成,导致了 GPU 资源的浪费和推理延迟的增加。而 Continuous Batching 技术允许模型在处理完当前迭代后,如果有请求已经处理完成,则可以立即返回该请求的结果,而不需要等待整个批次的请求都处理完成,这样可以显著提高硬件资源的利用率并减少空闲时间
此外,Continuous Batching 还能够解决不同请求计算量不同导致的资源浪费问题,通过迭代级别的调度动态调整批处理大小,适应不同请求的复杂程度,有效降低高复杂度请求的等待时间
值得注意的是,实现 Continuous Batching 需要考虑一些关键问题,如对 Early-finished Requests、Late-joining Requests 的处理,以及如何处理不同长度的请求 Batching。OCRA 提出的两个设计思路:Iteration-level Batching 和 Selective Batching,就是为了解决这些问题
在实际应用中,不同的框架可能对 Continuous Batching 有不同的实现方式。例如,vLLM 框架采用了一种简化的实现,将 prefill 和 decoding 分开处理,而 FastGen 框架则采用了 SplitFuse 方法,将长 prompt 分解成小块并在多个 step 中调度。这些不同的实现方式都旨在提高推理性能,降低延迟,同时优化资源的利用
给出生成 continous batching 的代码:
# seed the random number generator so our results are deterministic
random.seed(42)
# constants
queue_size = 32
batch_size = 8
# requests waiting to be processed
# this time requests are tuples (prompt, max_tokens)
request_queue = [
(prompts[0], 100 if i % batch_size == 0 else 10)
for i in range(queue_size)
]
t0 = time.time()
with tqdm(total=len(request_queue), desc=f"bs={batch_size}") as pbar:
# first, let's seed the initial cached_batch
# with the first `batch_size` inputs
# and run the initial prefill step
batch = init_batch(request_queue[:batch_size])
cached_batch = generate_next_token(batch)
request_queue = request_queue[batch_size:]
# continue until both the request queue is
# fully drained and every input
# within the cached_batch has completed generation
while (
len(request_queue) > 0 or
cached_batch["input_ids"].size(0) > 0
):
batch_capacity = (
batch_size - cached_batch["input_ids"].size(0)
)
if batch_capacity > 0 and len(request_queue) > 0:
# prefill
new_batch = init_batch(request_queue[:batch_capacity])
new_batch = generate_next_token(new_batch)
request_queue = request_queue[batch_capacity:]
# merge
cached_batch = merge_batches(cached_batch, new_batch)
# decode
cached_batch = generate_next_token(cached_batch)
# remove any inputs that have finished generation
cached_batch, removed_indices = filter_batch(cached_batch)
pbar.update(len(removed_indices))
duration_s = time.time() - t0
print("duration", duration_s)