引言

如果你最近被 deepseek 刷屏了,你应该会去阅读它的技术报告,尤其是 v3 和 r1,这两篇技术报告默认读者对于当前的大模型训练技术很了解

所以对于初学者来讲,阅读这些技术报告会有痛苦和挑战,第一个挑战可能就是 MLA (Multi-Head Latent Attention),这是个来源于 deepseek v2 的技术

本文试图从头开始,为大家梳理从 MHA,MQA,GQA 到 MLA 一路发展的脉络以及背后的原因,并尽量将所需要的知识直接附上,免去递归查找之苦。为了初学者友好,本文的思路是线性的,如果你已经了解某一块的知识,可以直接跳到其它感兴趣的部分

通过本文,你会掌握 MLA 所需的前置知识包括:

  1. 多头注意力机制 MHA
  2. 什么是位置编码和 ROPE
  3. KV Cache,prefilling&decoding
  4. MHA 为什么需要继续发展出 MQA 和 GQA

并掌握 MLA 的技术核心:

  1. MLA 从 0-1 思路的猜测
  2. MLA 的技术要点

多头注意力机制 MHA

如果你看 GPT 系列的论文,你学习到的 self-attention 是 Multi-Head Attention(MHA),即多头注意力机制, MHA 包含 h 个 Query、Key 和 Value 矩阵,所有注意力头(head)的 Key 和 Value 矩阵权重不共享

更细一点:

KV Cache:推理阶段的工程优化

KV Cache 是 GPT2 开始就存在的工程优化。它主要用于在生成阶段(decode)缓存之前计算的键值对,避免重复计算,从而节省计算资源和时间

推理有两个阶段:prefill 和 decode

  • Prefill 阶段处理整个输入序列,生成第一个输出 token,并初始化 KV 缓存
  • Decode 阶段则逐个生成后续的 token,此时如果有 KV 缓存,每次只需处理新生成的 token,而无需重新计算之前所有 token 的键值

附录 1 是一个 KV Cache 的例子作为参考

MQA,GQA:多头注意力机制的降本增效

既然我们用空间换时间的方案,加快了推理速度,那占用显存空间又成了一个可以优化的点,有没有可能降低 KV Cache 的大小呢?

业界在 2019 年和 2023 年分别发明了 MQA(Multi Query Attention)和 GQA(Group Query Attention), 来降低 KV 缓存的大小

不可否认的是,这两者都会对大模型的能力产生影响,但两者都认为这部分的能力衰减可以通过进一步的训练或者增加 FFN/GLU 的规模来弥补

MQA 通过在 Attention 机制里面共享 keys 和 values 来减少 KV cache 的内容,query 的数量还是多个,而 keys 和 values 只有一个,所有的 query 共享一组 kv,这样 KV Cache 就变小了

GQA 不是所有的 query 共享一组 KV,而是一个 group 的 guery 共享一组 KV,这样既降低了 KV cache,又能满足精度,属于 MHA 和 MQA 之间的折中方案

MLA:山穷水尽疑无路,柳暗花明又一村

MHA,MQA,GQA 后下一个创新点在哪?

MQA 和 GQA 是在缓存多少数量 KV 的思路上进行优化:直觉是如果我缓存的 KV 个数少一些,显存就占用少一些,大模型能力的降低可以通过进一步的训练或者增加 FFN/GLU 的规模来弥补

如果想进一步降低 KV 缓存的占用,从数量上思考已经不行了,那就势必得从 KV 本身思考,有没有可能每个缓存的 KV 都比之前小?

我们知道,一个 的矩阵可以近似成两个 矩阵的乘积,那如果我把一个 K 或者 V 矩阵拆成两个小矩阵的乘积,缓存的时候显存占用不就变小了吗?

但这有一个问题,如果单纯的把一个大的 K/V 矩阵拆成 2 个小矩阵进行 cache,那在推理的时候,还是需要计算出完整的 K 矩阵,这样就失去了缓存的意义,毕竟缓存的意义就是减少计算!

有没有一种方法,即能减少缓存大小,又不增加推理时候的计算?

我们看看 deepseek v2 中是怎么解这个问题的

MLA 面临的问题与解法

在 v2 的论文中, 的表达从 变为 , 原来缓存的是 ,而现在缓存的是 的一部分 ,论文中把它定义成 ,这样就达到了降低 K 大小的目的

注意到 是 UP 的意思,指带将后面的矩阵维度升上去; 是 Down 的意思,指将维度降下去,而 指的是对 K,V 矩阵的采用相同的降维矩阵,这样只用缓存相同的

推理阶段计算量的增加

到目前为止,似乎是 so far so good,但是注意到在推理的时候,为了得到 ,还得将 或者 乘上去,失去了缓存的意义,怎么办?

有没有什么办法在推理的时候,降低这一步的计算量呢?如果在推理中,一定要还原出 或者 ,那就无解了,但好在 也只是中间变量,我们可以通过一定的变形巧妙的避免推理计算量

我们看 Attention 因子的计算:

其中 代表第 次输入/输出, 代表第 个 head

利用矩阵的结合率,我们可以在推理的时候,提前算好 ,这样在 decode 的时候,计算量基本没有增加,这被称为矩阵吸收(absorb)

同理, 也可以被吸收到 中,注意这边我们需要小心的通过转置等手段保证数学上的恒等

这样 deepseek v2 就能既要又要了

MLA 和 RoPE 位置编码不兼容

但是 MLA 又有一个新的问题,那就是和 Rope 位置编码的兼容性问题,为此他们还找过 Rope 的发明人苏剑林讨论过,这在苏 2024 年 5 月的博客里有提到,引用如下:

Deepseek 最终的解决方案是在 Q,K 上新加 d 个维度,单独用来存储位置向量,在推理的时候,缓存

完整的 MLA 算法

我们可以看下完整的包含了 RoPE 的位置编码 MLA 算法,标框的是缓存的内容

附录 1:KV Cache

在推理阶段,KV Cache 的存在与否对大模型的计算流程和效率有显著影响

以下以 prompt 为“真忒修斯之船是一个”,生成 completion “分享平台”为例,分别说明有无 KV Cache 的差异:

无 KV Cache 的推理过程

原理:每次生成新 token 时,需要将整个历史序列(prompt+已生成 tokens)重新输入模型,并重新计算所有 token 的 Key 和 Value 向量

示例流程:

  1. 输入完整 prompt“真忒修斯之船是一个”,计算所有 token 的 Key/Value,生成第一个 token“分”
    • 计算量:需处理 9 个 token(假设“真忒修斯之船是一个”分词为 9 个 token)
  2. 输入“真忒修斯之船是一个分”,重新计算全部 10 个 token 的 Key/Value,生成“享”
    • 冗余计算:前 9 个 token 的 Key/Value 被重复计算
  3. 输入“真忒修斯之船是一个分享”,重新计算 11 个 token 的 Key/Value,生成“平”
  4. 输入“真忒修斯之船是一个分享平”,重新计算 12 个 token 的 Key/Value,生成“台”

问题:

  • 计算冗余:每生成一个 token 需重新计算所有历史 token 的 Key/Value,复杂度为 显存和计算时间随序列长度急剧增长
  • 显存占用高:显存需存储完整历史序列的中间结果,例如生成“台”时需缓存 10 个 token 的 Key/Value

有 KV Cache 的推理过程

原理:在 prefill 阶段计算 prompt 的 Key/Value 并缓存,后续 decode 阶段仅需计算新 token 的 Key/Value,复用缓存的旧结果

示例流程:

  • Prefill 阶段:输入完整 prompt“真忒修斯之船是一个”,计算其 9 个 token 的 Key/Value 并缓存,生成第一个 token“分”
  • Decode 阶段:
    1. 输入新 token“分”,仅计算其 Key/Value,与缓存的 9 个 Key/Value 合并,生成“享”
    2. 输入新 token“享”,计算其 Key/Value,与缓存的 10 个 Key/Value 合并,生成“平”
    3. 输入新 token“平”,计算其 Key/Value,与缓存的 11 个 Key/Value 合并,生成“台”

优势:

  • 计算量降低:复杂度从 降至 每个 decode 步骤仅需计算新 token 的 Key/Value
  • 显存优化:仅需存储缓存的 Key/Value,显存占用公式为 但通过复用缓存避免了冗余存储
  • 速度提升:实验显示,KV Cache 可使吞吐量提升数十倍

核心对比

维度无 KV Cache有 KV Cache
计算复杂度 随序列长度平方增长 仅需计算新 token
显存占用存储完整序列中间结果,显存需求高缓存 Key/Value,显存需求可控
生成速度慢(重复计算历史 token)快(仅计算新 token,复用缓存)
适用场景短序列生成(<100 tokens)长序列生成(如 API 输入、视频生成)

附录 2:RoPE 相对位置编码

旋转位置编码 RoPE

参考