推理时的 KV cache 已经是大模型长上下文部署的硬约束。MLA(Multi-head Latent Attention)用一招"低秩联合压缩 + 矩阵吸收",把 cache 缩到 GQA 2.25 组的水平,效果反而比 MHA 更好。
要解决的问题是:自回归推理时每生成一个 token 都要把所有历史 token 的 K 和 V从显存读一遍,KV cache 同时吃掉显存容量和带宽,长上下文 + 大 batch 直接撞墙。
MQA / GQA 用「让多个 query head 共享同一组 K/V」来缩 cache,但共享是粗暴的有损压缩,效果会掉。MLA 不走共享,而是把 K 和 V 一起低秩联合压缩成一个 d_c 维 latent 向量缓存起来;再借助矩阵乘法的结合律把上投影矩阵吸收进 W^Q 和 W^O,推理时根本不用还原出 K/V。RoPE 因为破坏可交换性需要单独处理:拆出一条「带 RoPE 的小 head」并行计算,再拼回去。
Transformer 推理时为了不重算前面 token 的 K 和 V,会把它们逐层缓存下来。这块缓存体量随上下文长度线性增长、随 batch 线性增长,最终撞到的不只是显存装得下装不下,还撞到「每个 step 都要把它从显存读一遍」的带宽。
自回归生成的运作方式很简单:模型每生成一个 token,都要让这个新 token 的 query 与所有历史 token的 key、value 做注意力。如果每生成一步都重新算一遍前面所有 token 的 K/V,复杂度是 O(n²),根本没法用。所以工程上一定要把历史的 K 和 V 缓存下来,只为新 token 算 K/V 然后追加进去——这就是 KV cache。
问题在于这块 cache 的体量:每层 2 × nh × dh 个元素,乘上层数 l,再乘上序列长度,再乘上 batch。以 DeepSeek 67B(标准 MHA)为例,单 token 单层 cache 是几 KB;推到 128K context、batch=8,一个请求就要吞几十 GB——大头反而不是模型权重。
更要命的是:每个 decoding step 都要把整个 KV cache 从显存里读一遍。decoding 是 memory-bound 的,不是 compute-bound 的。这意味着你哪怕用 H100 也喂不饱算力,瓶颈在显存带宽(HBM bandwidth)上。所以「缩小 KV cache」不只是省显存,更是直接降低单 step 延迟、提升 throughput 的杠杆。
Decoding 时新 token 的 query 是单条向量,K/V 矩阵是历史所有 token——这是个又长又瘦的 GEMV,HBM 带宽喂不动算力,绝大部分时间花在搬 cache 上。
h_t → 三组 projection 算出 q_t、k_t、v_t。计算量小、一次搞定。MLA 之前主流的 cache 压缩思路只有一条:既然每个 head 都各自有 K 和 V 太贵,那就让多个 query head 共享少量 K/V head。MQA 是极端共享(全部共享一组),GQA 是折中(按组共享)。但本质上都是在「质量 vs cache」之间做线性 trade-off。
原始的 MHA(Multi-Head Attention)给每个 head 单独算一组 K 和 V。这是 quality 的天花板,也是 cache 的"丑数字"。
2019 年 Noam Shazeer 提出 MQA(Multi-Query Attention):让所有 head 共享同一组 K/V,cache 一下子缩小 nh 倍。代价是质量肉眼可见地下降,且训练不稳定。
2023 年 Google 提出 GQA(Grouped-Query Attention):把 head 分成 ng 组,每组共享一组 K/V。本质是 MHA 和 MQA 中间的插值,绝大部分 LLaMA-2/3 系列、Qwen、Mistral 都在用。它解决了 MQA 质量太差的问题,但 KV cache 的下限是 GQA-1(即 MQA),上限是 GQA-H(即 MHA),它没有打破"共享 K/V"这个范式。
DeepSeek 的论文里做过一个直观对比:在 7B dense 上对齐参数量后,MMLU 分别是 MQA 37.9 / GQA 41.2 / MHA 45.2 ——共享得越狠,掉得越多。MLA 想问的是:能不能既不共享、又把 cache 缩到 GQA 的量级,同时质量持平甚至超过 MHA?
下面四张小图都画 4 个 query head(紫色)。差别只在于 K/V(青色 / 橙色)有几组、来源是什么。MLA 不再像前三种那样"按 head 摊分",而是先压成一个 latent 再 up-project。
MLA 真正的秘诀有两部分。第一部分让 cache 变小:把 K 和 V 联合压成一个低秩 latent;第二部分让计算也不变贵:利用矩阵乘法的结合律,把"还原 K/V 的 up-projection"提前吸收进 W^Q 和 W^O,推理时根本不需要解压回 K/V。
第一招:低秩联合压缩。给定输入 ht ∈ ℝd,MLA 不再像 MHA 那样直接算出全维度的 K 和 V,而是先经过一个 down-projection 矩阵 WDKV 把它压成一个小很多的 latent 向量 ctKV ∈ ℝdc,其中 dc ≪ nh·dh。需要 K 时再用 WUK up-project 回去,需要 V 时再用 WUV up-project 回去。
注意"联合"两个字:K 和 V 共用同一个 latent,意味着 cache 里只存这一个 latent 就行,cache 量直接从 2·nh·dh 砍到 dc——而且这是每个 token 都拥有自己的 latent,不是 MQA/GQA 那种"多个 head 共用一组 K/V"的有损共享。
第二招:矩阵吸收(associativity trick)。看起来 cache 是省了,但推理时每来一个新 query,还要把 WUK 和 WUV 作用在缓存的 latent 上去把 K/V "解压"出来——这步算量加回来怎么办?利用矩阵乘法结合律,WUK 可以提前合并进 query 投影矩阵 WQ,WUV 可以合并进输出投影矩阵 WO。这样 attention 算分数时直接用 query 和 latent 算内积,解压 K/V 这一步在数学上根本不存在(具体推导见下文 trick 块)。
结果是:cache 里只放一个 dc 维向量,attention 计算也不需要把它展回 K/V 全维度。这一步就是 MLA 区别于"普通低秩 KV"的关键——不只是省内存,连计算也是直接在 latent 空间里做的。
切换两个 tab 看清楚:训练时为了得到正确的 attention 输出,K 和 V 是真的被算出来过的;推理时由于矩阵吸收,那一步根本不发生,cache 里只有一个 latent。
关键是注意力分数 q · k 的形式。把 key 的定义代入并利用矩阵乘法结合律重组:
也就是说,WUK 可以被吸收进 WQ(实质是 WUQ),WUV 可以被吸收进 WO。这两步合并是离线的、一次性的;推理时 attention 直接在 latent 空间做,K 和 V 永远不需要被还原成全维度。这是 MLA 区别于普通 "low-rank KV" 想法的关键工程点。
矩阵吸收依赖于矩阵乘法的结合律和可交换的因子顺序。但 RoPE 是位置敏感的旋转矩阵,它会卡在 W^Q 和 W^UK 中间,不能交换位置——吸收技巧立刻失效。MLA 的解法是「解耦 RoPE」:让一小部分 head 维度专门承担位置编码,其余维度仍然走 latent 通路。
RoPE 的工作方式是给 query 和 key 各自乘一个由位置 t 决定的旋转矩阵 Rt。在 MHA 里这无伤大雅,因为 q 和 k 都是显式存在的。但在 MLA 里,如果想给 key 加 RoPE,就变成 kt = Rt · WUK · ct。这时 WQ 与 WUK 之间多了一个跟当前 query 位置有关的 Rt,矩阵乘法不再可交换——WUK 没法再被合并进 WQ,每生成一步都要重新 up-project 所有历史 token 的 K,吞吐立刻崩盘。
DeepSeek 的解法很巧:与其让 RoPE 和 latent 通路硬碰硬,不如把它们分开两条路。给每个 head 额外开一段维度 dhR(论文里 dhR = dh/2)专门用来承载 RoPE 信息,这部分的 query 仍然按 head 分开算(保留位置敏感性),但 key 是所有 head 共享一个(这部分回到 MQA 的做法,反正只是辅助位置信号)。
原本走 latent 通路的部分(content)继续保持矩阵吸收的红利;带 RoPE 的小尾巴单独缓存一份。最终 attention 时把 content 和 rope 两段拼接起来一起做内积。cache 多出来一份 dhR · l,但量级很小:DeepSeek-V2 总 cache = (dc + dhR) · l ≈ 4.5·dh·l,相当于 GQA 只有 2.25 组的水平,但效果比 MHA 还强。
attention 分数变成 qt⊤ · Rt · WUK · cjKV。Rt 与位置有关,没法跟 WUK、WQ 互换位置。"把 WUK 吸收进 WQ"的等价变形不再成立。
后果:每个 step 都要重新算所有历史 token 的真实 key,cache 缩了但计算量爆炸,工程上得不偿失。
Content 段(d_h 维):走 latent 通路,按原方案吸收 W^UK / W^UV,不带 RoPE。
RoPE 段(d_h^R 维):专门承载位置信息。query 按 head 分开算,key 是所有 head 共享一个 k_t^R。
attention 时拼接两段一起算内积。Content 享受吸收红利,RoPE 段单独 cache 一份。
每个 head 的 query 和 key 都是「Content 部分 ⨁ RoPE 部分」拼接而成。Content 部分走 MLA 的 latent 通路,RoPE 部分走类 MQA 的共享通路。
q^C ← W^UQ · c_t^Q(query 由低秩 latent up-project)
k^C ← W^UK · c_t^KV(key 由 KV latent up-project)
推理时整条路被矩阵吸收,cache 只放 c_t^KV。
q^R ← RoPE(W^QR · c_t^Q)(每 head 一份)
k^R ← RoPE(W^KR · h_t)(全 head 共享一份)
承担位置编码,cache 单独存 k_t^R。
论文里给了两组对照实验:4 种注意力机制的每 token 缓存大小公式,以及在小 / 大 MoE 上 MLA 与 MHA 的实测对比。MLA 在两种规模下都既省 cache 又涨点。
| 机制 | cache (元素数) | 能力 |
|---|---|---|
| MHA | 2·n_h·d_h·l | Strong |
| GQA | 2·n_g·d_h·l | Moderate |
| MQA | 2·d_h·l | Weak |
| MLA | (d_c+d_h^R)·l ≈ 9/2·d_h·l | Stronger |
DeepSeek-V2 取 d_c = 4·d_h,d_h^R = d_h/2。cache 量级介于 GQA-2 与 GQA-3 之间,但效果反而比 MHA 强。
| 指标 | MHA | MLA |
|---|---|---|
| KV cache / token | 860.2K | 34.6K |
| BBH (3-shot) | 46.6 | 50.7 |
| MMLU (5-shot) | 57.5 | 59.0 |
| C-Eval (5-shot) | 57.9 | 59.2 |
| CMMLU (5-shot) | 60.7 | 62.5 |
两个 ~250B 总参数的 MoE 模型,唯一差别是注意力机制。MLA 用 MHA 4% 的 cache,质量四项全胜。
| 机制 | MMLU | C-Eval | CMMLU |
|---|---|---|---|
| MQA | 37.9 | 30.0 | 34.6 |
| GQA-8 | 41.2 | 37.7 | 38.4 |
| MHA | 45.2 | 42.9 | 43.5 |
单看 dense 模型,"共享 KV 越狠 → 质量掉得越多"是非常清楚的趋势。GQA 比 MHA 也明显落后近 4 个点。
MLA 是「改 attention 数学」这条路的代表。但围绕 KV cache 的优化是个生态:还有量化它的、扔掉一部分的、跨层共享的、彻底换掉 attention 的、以及在系统层做内存管理的。下面是六个值得知道的方向。