全面解析DeepSeek算法细节(3) —— Multi-head Latent Attention (MLA)

全面解析DeepSeek算法细节(3) —— Multi-head Latent Attention (MLA)_第1张图片

概述

Multihead Latent Attention (MLA) 是一种旨在提升计算效率的注意力机制,它通过将 Key-Query-Value (KQV) 矩阵投影到一个低维潜在空间,显著降低计算和内存成本。MLA 是 DeepSeek 系列模型(特别是 DeepSeek-R1)的关键组成部分,它通过低秩压缩技术最小化 Key-Value (KV) 缓存的存储开销,从而实现更快的推理速度,并支持更长的上下文长度或更大的批处理大小。

DeepSeek-R1 在 MLA 的基础上,进一步融入了强化学习 (RL) 增强的推理优化,同时保持较低的内存开销。通过利用解耦的旋转位置嵌入和潜在空间压缩,MLA 确保在保持计算效率的同时,最小化精度损失。

关键特性

  • 低秩 Key-Value 压缩:MLA 采用低秩潜在空间投影来压缩 KV 对,显著降低内存开销。这使得 DeepSeek-R1 能够仅存储压缩表示,而不是完整的 KV 状态,从而实现高效的长上下文处理。
  • 解耦旋转位置嵌入 (RoPE):标准 RoPE 引入了位置相关的变换,这会阻碍 KV 压缩。DeepSeek-R1 将 RoPE 与 key-value 存储解耦,确保位置编码在不干扰潜在空间效率的情况下保持有效。
  • 高效的多头注意力与压缩存储:与缓存所有 token 的完整 key-value 矩阵不同,MLA 仅存储其紧凑的潜在空间等效项。这大大降低了推理内存需求,同时保持了注意力保真度。
  • 自适应投影矩阵:MLA 为 queries、keys 和 values 利用单独的学习投影矩阵。这些矩阵在训练期间动态调整,确保最佳的存储效率和最小的精度损失(与全维注意力相比)。
  • 推理高效的缓存机制:通过选择性地仅缓存压缩的 key-value 表示,MLA 比传统的 Multi-Head Attention (MHA) 实现了 93.3% 的 KV 缓存减少。这使得 DeepSeek-R1 能够支持更长的上下文长度,同时最小化推理延迟。
  • 增强的在长上下文任务上的性能:DeepSeek-R1 使用 RL 驱动的优化(例如 GRPO)来改进 MLA,以优先处理关键 token。这提高了长上下文任务中的推理准确性,同时保持了计算效率。

从 DeepSeek-V2 到 DeepSeek-R1 的演变

DeepSeek-V2 中的 MLA

DeepSeek-V2 中的 MLA 旨在通过显著减少 KV 缓存大小,同时保持强大的模型性能来提高推理效率。它引入了相对于传统 Multi-Head Attention (MHA) 的多项关键创新,包括低秩 key-value 联合压缩和解耦旋转位置嵌入。

DeepSeek-V2 中的 MLA 实现为 DeepSeek-R1 中进一步的改进奠定了基础,在 DeepSeek-R1 中,它通过 FP8 量化、增强的压缩技术和改进的数值稳定性得到了进一步的完善。

低秩 Key-Value 联合压缩

Transformer 推理的主要瓶颈之一是存储过去 keys 和 values 所需的大型 KV 缓存。DeepSeek-V2 通过使用线性投影将 KV 表示压缩到低维潜在空间来解决此问题。

给定一个输入 token 表示 ( h t ∈ R d ) (h_t \in \mathbb{R}^d) (htRd),标准多头注意力计算 queries、keys 和 values 如下:
q t = W Q h t , k t = W K h t , v t = W V h t   q_t = W_Q h_t, \quad k_t = W_K h_t, \quad v_t = W_V h_t\ qt=WQht,kt=WKht,vt=WVht 

其中 W Q , W K , W V ∈ R d h n h × d W_Q, W_K, W_V \in \mathbb{R}^{d_h n_h \times d} WQ,WK,WVRdhnh×d

MLA 不是存储全维 k t k_t kt v t v_t vt,而是将它们压缩为潜在表示 c K V c_{KV} cKV
c K V t = W D K V h t c_{KV_t} = W_{D_{KV}} h_t cKVt=WDKVht

其中 W D K V ∈ R d c × d W_{D_{KV}} \in \mathbb{R}^{d_c \times d} WDKVRdc×d 是一个下投影矩阵,并且 d c ≪ d h n h d_c \ll d_h n_h dcdhnh

在推理期间,压缩的 key-value 表示被扩展回可用的 keys 和 values:
k t C = W U K c K V t , v t C = W U V c K V t k_t^C = W_{U_K} c_{KV_t}, \quad v_t^C = W_{U_V} c_{KV_t} ktC=WUKcKVt,vtC=WUVcKVt

其中 W U K , W U V ∈ R d h n h × d c W_{U_K}, W_{U_V} \in \mathbb{R}^{d_h n_h \times d_c} WUK,WUVRdhnh×dc 是上投影矩阵。

这种压缩将 KV 缓存大小从 O ( n h d h l ) O(n_h d_h l) O(nhdhl) 减少到 O ( d c l ) O(d_c l) O(dcl),其中 l l l 是层数。

解耦旋转位置嵌入

RoPE 通常用于 transformer 架构中,以将位置信息编码到 queries 和 keys 中。然而,标准 RoPE 应用与 MLA 的 key-value 压缩不兼容,因为它引入了阻止有效缓存的位置相关变换。

DeepSeek-V2 通过将 RoPE 与 key 压缩解耦来解决此问题:

  1. 引入辅助共享 key k t R k_t^R ktR 和额外的多头 queries q t R q_t^R qtR

  2. 仅将 RoPE 应用于 q t R q_t^R qtR k t R k_t^R ktR
    q t R = RoPE ( W Q R c Q t ) , k t R = RoPE ( W K R h t ) q_t^R = \text{RoPE}(W_{Q_R} c_{Q_t}), \quad k_t^R = \text{RoPE}(W_{K_R} h_t) qtR=RoPE(WQRcQt),ktR=RoPE(WKRht)

    其中 W Q R , W K R W_{Q_R}, W_{K_R} WQR,WKR 是特定于解耦 RoPE 的投影矩阵。

  3. 连接压缩的和 RoPE 应用的 keys/queries:
    q t = [ q t C ; q t R ] , k t = [ k t C ; k t R ] q_t = [q_t^C; q_t^R], \quad k_t = [k_t^C; k_t^R] qt=[qtC;qtR],kt=[ktC;ktR]
    确保 RoPE 仅影响注意力机制的一个子集,同时保持 key-value 压缩完整。

KV 缓存需求比较

MLA 的一个关键优势是,它在需要显著更少的 KV 缓存的同时,实现了比标准 MHA 更强的性能。下表比较了不同注意力机制的缓存大小:

注意力机制 每个 Token 的 KV 缓存(元素)
MHA 2 n h d h l 2 n_h d_h l 2nhdhl
GQA (分组查询) 2 n g d h l 2 n_g d_h l 2ngdhl
MQA (多查询) 2 d h l 2 d_h l 2dhl
MLA (DeepSeek-V2) ( d c + d h R ) l (d_c + d_h^R) l (dc+dhR)l

对于 DeepSeek-V2,值设置为:

  • d c = 4 d h d_c = 4d_h dc=4dh
  • d h R = d h / 2 d_h^R = d_h / 2 dhR=dh/2

这意味着 MLA 实现了与具有 2.25 个组的 GQA 相似的效率,同时保持了 MHA 的性能水平。

DeepSeek-V3 中的增强

DeepSeek-V3 引入了对 Multihead Latent Attention (MLA) 的多项关键增强,这些增强显著提高了其效率、可伸缩性和精度,同时保持了较高的模型准确性。主要改进包括:

  • 通过优化的压缩技术进一步减少 KV 缓存
  • 用于激活内存节省的查询压缩
  • 通过 FP8 混合精度增强的数值稳定性
  • 用于 MLA 中负载平衡的自适应路由

通过这些改进,DeepSeek-V3 降低了内存开销,增强了数值精度,并实现了显著更快的推理速度,同时保持了较高的模型准确性。

DeepSeek-R1 中的增强

DeepSeek-R1 对 MLA 进行了多项改进,提高了推理效率和推理性能,同时保持了较低的内存开销。在 DeepSeek-V3 中 MLA 优化的基础上,DeepSeek-R1 进一步增强了 KQV 压缩、RL 引导的注意力分配和数值稳定性机制。

MLA vs. MHA:对比分析

特性 MHA MLA
KQV 存储 存储完整的 KQV 矩阵 仅存储压缩后的 Key-Value 表示
内存占用 高,存储开销大 低,存储开销显著减少
计算复杂度 高,尤其对于长上下文任务 低,通过压缩技术减少计算量
位置编码 直接应用于 KQV 矩阵 解耦 RoPE,确保不影响压缩效率
适用场景 适用于短上下文任务 尤其适用于长上下文任务
推理速度 较慢,尤其对于长序列 较快,通过压缩和优化提升推理速度

通过优化压缩技术进一步减少键值缓存(KV Cache)

  • DeepSeek-V3的多头潜在注意力(MLA)的主要改进之一是在保持模型性能的同时,对键值缓存进行更深度的压缩。这通过以下方式实现:
    • 动态键值压缩矩阵:DeepSeek-V3并非使用静态压缩矩阵,而是针对每个序列长度动态优化压缩过程。
    • 键值存储的分解投影:应用双矩阵分解来对键和值进行下投影,进一步减少键值存储量。

优化的压缩公式

  • 给定一个输入令牌表示 h t ∈ R d h_t \in \mathbb{R}^d htRd,DeepSeek-V2中的标准多头潜在注意力(MLA)计算压缩后的键值表示如下:
    c K V i = W D K V h t c_{KV_i} = W_{D_{KV}}h_t cKVi=WDKVht
    • 其中 W D K V ∈ R d c × d W_{D_{KV}} \in \mathbb{R}^{d_c\times d} WDKVRdc×d是一个静态下投影矩阵。
  • 在DeepSeek-V3中,压缩过程通过自适应双矩阵压缩得到增强:
    c K V i = W D K V , 1 W D K V , 2 h t c_{KV_i} = W_{D_{KV,1}}W_{D_{KV,2}}h_t cKVi=WDKV,1WDKV,2ht
    • 其中 W D K V , 1 ∈ R d m × d W_{D_{KV,1}} \in \mathbb{R}^{d_m\times d} WDKV,1Rdm×d W D K V , 2 ∈ R d c × d m W_{D_{KV,2}} \in \mathbb{R}^{d_c\times d_m} WDKV,2Rdc×dm d m d_m dm是一个中间维度。
    • 这种分解能够实现更有效的压缩,与DeepSeek-V2相比,存储需求最多可减少40%。

推理时的扩展

  • 在推理过程中,现在扩展后的键和值按如下方式计算:
    k t C = W U K W M K c K V i , v t C = W U V W M V c K V i k^C_t = W_{U_K}W_{M_K}c_{KV_i}, \quad v^C_t = W_{U_V}W_{M_V}c_{KV_i} ktC=WUKWMKcKVi,vtC=WUVWMVcKVi
    • 其中 W M K , W M V W_{M_K}, W_{M_V} WMK,WMV作为中间投影层,用于优化键值重建过程。
  • 这一改进确保只有压缩后的向量存储在内存中,显著减少了键值缓存开销。

查询压缩以节省激活内存

  • DeepSeek-V3将多头潜在注意力(MLA)的低秩压缩扩展到查询(queries)上,在不影响注意力精度的情况下,降低了激活内存需求。

查询压缩公式

  • 不再计算完整的查询:
    q t = W Q h t , k t = W K h t , v t = W V h t q_t = W_Qh_t, \quad k_t = W_Kh_t, \quad v_t = W_Vh_t qt=WQht,kt=WKht,vt=WVht
  • DeepSeek-V3引入了一个额外的压缩步骤:
    c Q t = W D Q h t , q t C = W U Q c Q t c_{Q_t} = W_{D_Q}h_t, \quad q^C_t = W_{U_Q}c_{Q_t} cQt=WDQht,qtC=WUQcQt
    • 其中:
      • c Q t ∈ R d c ′ c_{Q_t} \in \mathbb{R}^{d'_c} cQtRdc 是压缩后的查询表示。
      • d c ′ ≪ d h n h d'_c \ll d_hn_h dcdhnh,这确保激活内存的使用量显著降低。

解耦旋转位置嵌入(RoPE)

  • 为了保持位置嵌入的有效性,DeepSeek-V3对旋转位置嵌入(RoPE)的应用进行了解耦:
    q t R = RoPE ( W Q R c Q t ) , k t R = RoPE ( W K R h t ) q^R_t = \text{RoPE}(W_{Q_R}c_{Q_t}), \quad k^R_t = \text{RoPE}(W_{K_R}h_t) qtR=RoPE(WQRcQt),ktR=RoPE(WKRht)
    • 其中:
      • q t R q^R_t qtR k t R k^R_t ktR 存储应用了RoPE的压缩表示版本。
      • 这可以防止RoPE干扰MLA的低秩压缩。

激活内存的减少

  • 通过查询压缩,DeepSeek-V3将注意力激活内存减少了35%,从而能够对大规模模型进行高效训练 。

利用FP8混合精度增强数值稳定性

  • DeepSeek-V3利用FP8混合精度训练,在降低内存和计算成本的同时,提高了数值稳定性。

针对MLA组件的FP8训练

  • 在DeepSeek-V2中,MLA组件主要以BF16(16位二进制浮点数)运行。而DeepSeek-V3采用了细粒度的FP8(8位浮点数)量化,并应用了分组缩放策略:
    • 激活缩放:对激活值采用每个令牌、每128个通道块的量化方式。
    • 权重缩放:对权重采用128×128的块级缩放。
    • 这确保了在训练中减少舍入误差,并能更好地覆盖动态范围。

FP8注意力计算

  • DeepSeek-V3中的注意力输出使用与FP8兼容的缩放方式进行计算:
    o t = ∑ j = 1 t Softmax ( q t T k j d h + d R ) v j o_t = \sum_{j = 1}^{t} \text{Softmax} \left( \frac{q^T_t k_{j}}{\sqrt{d_h + d_R}} \right) v_{j} ot=j=1tSoftmax(dh+dR qtTkj)vj
    • 其中:
      • 缩放因子针对激活值进行在线计算。
      • 每128步将累加结果升级为FP32(32位浮点数),以提高数值精度。

精度对比

组件 DeepSeek-V2(BF16) DeepSeek-V3(FP8)
查询/键压缩 d c = 4 d h d_c = 4d_h dc=4dh d c = 3 d h d_c = 3d_h dc=3dh
KV缓存存储 BF16 FP8
RoPE应用 全精度 解耦,FP8
注意力计算 BF16 FP8 + FP32累加
  • 通过利用FP8量化,DeepSeek-V3的训练效率提高了2.3倍,在不降低性能的前提下减少了内存消耗。

多头潜在注意力(MLA)中用于负载均衡的自适应路由

  • DeepSeek-V3通过为查询 - 键计算引入动态负载均衡,提高了注意力计算效率。

负载自适应路由机制

  • 在DeepSeek-V2中,MLA使用静态的注意力头分配方式,这导致在处理长序列时偶尔会出现计算效率低下的问题。
  • DeepSeek-V3通过自适应路由对此进行了改进:
    s i , t = Sigmoid ( w t T e i + b i ) s_{i,t} = \text{Sigmoid}(w^T_te_i + b_i) si,t=Sigmoid(wtTei+bi)
    • 其中:
      • e i e_i ei 是被路由到的专家的质心向量。
      • b i b_i bi 是一个动态更新的偏置项,用于调整每个注意力头的工作负载平衡。
  • 偏置项的更新方式如下:
    b i ( t + 1 ) = b i ( t ) − γ ⋅ ( overloaded i − underloaded i ) b^{(t + 1)}_i = b^{(t)}_i - \gamma \cdot (\text{overloaded}_i - \text{underloaded}_i) bi(t+1)=bi(t)γ(overloadediunderloadedi)
    • 其中 γ \gamma γ 是一个调整参数。
  • 这确保了:
    • 令牌在各个注意力头之间的分布平衡。
    • 在推理过程中不进行令牌丢弃,避免效率损失。

计算收益

  • 通过集成自适应路由,DeepSeek-V3实现了:
    • 各个注意力头之间的计算负载均匀。
    • 每个令牌的推理延迟降低10%。

DeepSeek-R1的改进

  • DeepSeek-R1对MLA进行了多项优化,在保持低内存开销的同时,提高了推理效率和推理性能。基于DeepSeek-V3中对MLA的优化,DeepSeek-R1进一步增强了键值(KQV)压缩、强化学习(RL)引导的注意力分配以及数值稳定性机制。

强化学习(RL)引导的潜在注意力优化

  • DeepSeek-R1将强化学习技术集成到多头潜在注意力(MLA)中,通过分组相对策略优化(GRPO)来优化注意力机制。与以往的确定性注意力策略不同,DeepSeek-R1基于强化奖励动态调整注意力权重,优先处理对强化推理路径有更大贡献的令牌。
  • GRPO无需单独的评论家模型,从而减少了内存开销并提高了收敛效率。
  • GRPO不依赖于监督微调,而是直接从组级奖励中估计优势值:
    A i = r i − mean ( { r 1 , r 2 , … , r G } ) std ( { r 1 , r 2 , … , r G } ) A_i = \frac{r_i - \text{mean}(\{r_1, r_2, \ldots, r_G\})}{\text{std}(\{r_1, r_2, \ldots, r_G\})} Ai=std({r1,r2,,rG})rimean({r1,r2,,rG})
  • 通过最大化以下公式来更新策略模型 π θ \pi_{\theta} πθ
    J GRPO ( θ ) = E [ ∑ i = 1 G min ⁡ ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ϵ , 1 + ϵ ) A i ) − β D K L ( π θ ∥ π ref ) ] J_{\text{GRPO}}(\theta) = \mathbb{E} \left[ \sum_{i = 1}^{G} \min \left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip} \left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon \right) A_i \right) - \beta D_{KL}(\pi_{\theta} \| \pi_{\text{ref}}) \right] JGRPO(θ)=E[i=1Gmin(πθold(oiq)πθ(oiq)Ai,clip(πθold(oiq)πθ(oiq),1ϵ,1+ϵ)Ai)βDKL(πθπref)]
  • 这种方法使DeepSeek-R1能够自适应地优化MLA中的注意力机制,改善长文本推理中的令牌优先级排序。
  • 更多详细信息可在“RL算法:分组相对策略优化(GRPO)”部分中找到。

通过强化学习实现自适应查询和键压缩

DeepSeek-R1的MLA的主要改进之一是强化学习引导的自适应查询和键压缩。DeepSeek-V3已经引入了用于键值(KV)存储的低秩压缩技术,但DeepSeek-R1将压缩扩展到查询,在不影响注意力准确性的情况下减少激活内存。

  • 优化的压缩公式
    • 在DeepSeek-V3中,键值缓存压缩是通过静态低秩投影实现的:
      c K V i = W D K V h t c_{KV_i} = W_{D_{KV}}h_t cKVi=WDKVht
    • DeepSeek-R1在推理过程中使用基于强化学习的奖励最大化来动态调整压缩矩阵:
      c K V i = W D K V , 1 W D K V , 2 h t c_{KV_i} = W_{D_{KV,1}}W_{D_{KV,2}}h_t cKVi=WDKV,1WDKV,2ht
    • 其中:
      • W D K V , 1 ∈ R d m × d W_{D_{KV,1}} \in \mathbb{R}^{d_m\times d} WDKV,1Rdm×d W D K V , 2 ∈ R d c × d m W_{D_{KV,2}} \in \mathbb{R}^{d_c\times d_m} WDKV,2Rdc×dm
      • d m d_m dm 是一个中间维度,允许实现更细粒度的潜在空间表示。
  • 推理时的扩展
    • DeepSeek-R1没有使用单个上投影矩阵,而是采用了多阶段扩展流水线:
      k t C = W U K W M K c K V i , v t C = W U V W M V c K V i k^C_t = W_{U_K}W_{M_K}c_{KV_i}, \quad v^C_t = W_{U_V}W_{M_V}c_{KV_i} ktC=WUKWMKcKVi,vtC=WUVWMVcKVi
    • 其中 W M K , W M V W_{M_K}, W_{M_V} WMK,WMV 优化重建后的查询 - 键值,确保只有压缩向量存储在内存中。
  • 压缩比改进:DeepSeek-R1在保持查询 - 键检索准确性的同时,相比DeepSeek-V3,进一步将键值缓存需求降低了25%。

带上下文特定缩放的解耦旋转位置嵌入

  • DeepSeek-V3引入了解耦的旋转位置嵌入(RoPE),将位置编码与压缩后的键值表示分离,而DeepSeek-R1通过上下文特定的缩放机制进一步优化了RoPE。
  • DeepSeek-R1采用了一种改进的RoPE公式,其中RoPE具有上下文感知能力,可根据序列长度动态调整缩放因子:
    λ t = 1 1 + α L t \lambda_t = \frac{1}{\sqrt{1 + \alpha L_t}} λt=1+αLt 1
    • 其中:
      • λ t \lambda_t λt 是位置嵌入的自适应缩放因子。
      • α \alpha α 是通过强化学习优化得到的超参数。
      • L t L_t Lt 表示时间步 t t t 时的序列长度。
  • 实现优势
    • RoPE缩放确保了在不同序列长度下注意力对齐的一致性。
    • 在压缩MLA的键值状态时,防止位置信息退化。

用于MLA稳定性的FP8混合精度

  • DeepSeek-R1在MLA计算中采用FP8量化,相比DeepSeek-V3基于BF16的方法,进一步提高了数值稳定性。
  • 在DeepSeek-R1的精度感知计算流水线中,查询(Q)、键(K)、值(V)矩阵通过分组缩放进行动态量化:
    Q ~ = Q s Q , K ~ = K s K , V ~ = V s V \tilde{Q} = \frac{Q}{s_Q}, \quad \tilde{K} = \frac{K}{s_K}, \quad \tilde{V} = \frac{V}{s_V} Q~=sQQ,K~=sKK,V~=sVV
    • 其中 s Q , s K , s V s_Q, s_K, s_V sQ,sK,sV 是通过学习得到的分组缩放因子。
  • 注意力输出采用混合精度累加进行计算:
    o t = ∑ j = 1 t Softmax ( q ~ t T k ~ j d h + d R ) v ~ j o_t = \sum_{j = 1}^{t} \text{Softmax} \left( \frac{\tilde{q}^T_t \tilde{k}_j}{\sqrt{d_h + d_R}} \right) \tilde{v}_j ot=j=1tSoftmax(dh+dR q~tTk~j)v~j
  • 累加过程每128步升级为FP32,在保持FP8效率的同时确保更好的数值精度。

MLA精度策略对比

组件 DeepSeek-V3(BF16) DeepSeek-R1(FP8)
查询/键压缩 d c = 4 d h d_c = 4d_h dc=4dh d c = 3 d h d_c = 3d_h dc=3dh
KV缓存存储 BF16 FP8
RoPE应用 全精度 解耦,FP8
注意力计算 BF16 FP8 + FP32累加

效率提升

  • 与BF16相比,FP8可将内存占用减少约40%。
  • 对于长文本任务,推理吞吐量提升2.3倍。

用于负载均衡注意力的自适应/动态路由

  • DeepSeek-R1引入了负载均衡自适应路由机制,确保在各个注意力头之间,查询 - 键的计算负载均匀。
  • DeepSeek-R1使用基于Sigmoid的路由函数来优化每个注意力头的工作负载平衡:
    s i , t = Sigmoid ( u t T e i + b i ) s_{i,t} = \text{Sigmoid}(u^T_te_i + b_i) si,t=Sigmoid(utTei+bi)
    • 其中:
      • e i e_i ei 表示被路由的注意力专家的质心向量。
      • b i b_i bi 是一个自适应偏置项,用于确保工作负载均匀。
  • 性能提升
    • 各个注意力头之间的均衡计算避免了瓶颈问题。
    • 将每个令牌的推理延迟降低了10%。

对比分析

  • DeepSeek-V2引入了多头潜在注意力(MLA),具备显著的键值(KV)缓存压缩能力、解耦的旋转位置嵌入(RoPE),以及为提高效率而采用的基本低秩投影。DeepSeek-V3在此基础上,进一步减小了键值缓存大小,优化了查询压缩,并引入了FP8混合精度以增强数值稳定性。DeepSeek-R1通过集成诸如分组相对策略优化(GRPO)等强化学习技术,对MLA进行了更深入的优化,从而能够动态优化注意力分配。DeepSeek-R1的最新改进还提升了推理延迟和内存效率,使其成为目前MLA的最优化版本。
  • 下表对DeepSeek-V2、DeepSeek-V3和DeepSeek-R1的MLA进行了对比分析。该对比突出了各版本在压缩技术、精度、路由机制和推理效率方面的关键改进。
特性 DeepSeek-V2 DeepSeek-V3 DeepSeek-R1
低秩键值(KV)压缩 ✔(通过分解投影优化) ✔(强化学习优化的自适应压缩)
查询压缩 ✔(静态低秩查询压缩) ✔(强化学习引导的动态查询压缩)
键值缓存缩减 ✔(减少93.3%) ✔(在V2基础上再减少40%) ✔(在V3基础上再减少25%)
旋转位置嵌入(RoPE)应用 ✔(解耦RoPE) ✔(解耦并带有上下文特定缩放) ✔(增强的上下文感知缩放)
精度格式 BF16(16位二进制浮点数) FP8(细粒度混合精度) FP8(分组缩放,FP32累加)
多头潜在注意力(MLA)的自适应路由 ✔(静态自适应路由) ✔(负载均衡动态路由)
推理延迟降低 ✔(键值压缩降低延迟) ✔(比V2快10%) ✔(比V3快10%)
强化学习(RL)改进 ✔(使用分组相对策略优化(GRPO)进行自适应MLA优化)
数值稳定性改进 ✔(基本稳定性增强) ✔(FP8混合精度) ✔(FP8结合强化学习引导的稳定性机制)
长文本性能 ✔(支持更长文本) ✔(进一步优化) ✔(通过强化学习引导的令牌优先级排序增强)

实现方式

DeepSeek-R1中多头潜在注意力(MLA)的实现包含了多项优化,旨在在保持准确性的同时最大化效率。本节详细介绍了MLA的核心机制,包括键值压缩、查询变换、位置编码和计算优化。

背景:标准多头注意力(MHA)

  • 对于标准的多头注意力(MHA)机制,键(K)、查询(Q)和值(V)矩阵的计算如下:
    K , Q , V = W k X , W q X , W v X K, Q, V = W_kX, W_qX, W_vX K,Q,V=WkX,WqX,WvX
    • 其中 W k W_k Wk W q W_q Wq W v W_v Wv 分别是用于键、查询和值投影的权重矩阵。
  • 注意力权重的计算如下:
    A = Softmax ( Q K T d k ) A = \text{Softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) A=Softmax(dk QKT)
  • 输出为:
    O = A V O = AV O=AV
  • 这要求在推理过程中存储完整的键值缓存,从而导致显著的内存开销。

低秩键值联合压缩

  • MLA的一项基础优化是将键值对压缩到一个更低维度的潜在空间,显著降低内存开销。具体如下:
    • 压缩机制
      • 键和值的表示在被投影回各自维度之前,先被压缩到一个共享的潜在空间。这通过两步变换实现:
        c K V i = W D K V h t c_{KV_i} = W_{D_{KV}}h_t cKVi=WDKVht
        k t C = W U K c K V i , v t C = W U V c K V i k^C_t = W_{U_K}c_{KV_i}, \quad v^C_t = W_{U_V}c_{KV_i} ktC=WUKcKVi,vtC=WUVcKVi
      • 其中:
        • c K V i ∈ R d c c_{KV_i} \in \mathbb{R}^{d_c} cKViRdc 是压缩后的潜在表示。
        • W D K V ∈ R d c × d W_{D_{KV}} \in \mathbb{R}^{d_c\times d} WDKVRdc×d 是一个下投影矩阵。
        • W U K , W U V ∈ R d h n h × d c W_{U_K}, W_{U_V} \in \mathbb{R}^{d_hn_h\times d_c} WUK,WUVRdhnh×dc 分别是用于键和值的上投影矩阵。
    • 内存缩减
      • 对于每个令牌,无需存储完整尺寸的键和值,仅缓存 c K V i c_{KV_i} cKVi
      • 内存占用的减少使得DeepSeek-R1能够以更低的计算成本处理更长的序列。

多阶段压缩

DeepSeek-R1通过引入一个额外的变换层来优化压缩机制,从而形成了一种多阶段压缩方法。具体如下:

  • 额外的投影层
    • 为了进一步降低存储成本,引入了一个二次压缩层:
      c K V i ′ = W D K V 2 f ( W D K V h t ) c'_{KV_i} = W_{DKV2}f(W_{DKV}h_t) cKVi=WDKV2f(WDKVht)
    • 其中:
      • W D K V 2 ∈ R d c ′ × d c W_{DKV2} \in \mathbb{R}^{d'_c\times d_c} WDKV2Rdc×dc 是第二个下投影矩阵。
      • f ( ⋅ ) f(\cdot) f() 是一个非线性激活函数,用于提升表示学习效果。
      • d c ′ < d c d'_c < d_c dc<dc,这确保键值(KV)缓存的大小更小。
  • 性能优势
    • 这个额外的步骤在为注意力机制保留足够信息的同时,进一步减少了键值存储量。
    • 实验表明,与DeepSeek-V3相比,这使得内存占用减少了10%-15%。

查询压缩与优化

  • 与键和值类似,查询也会被压缩,以便在训练过程中进行高效计算并减少激活内存的使用。具体如下:
    • 查询变换
      • 查询会经历类似于键和值的两步变换:
        c Q t = W D Q h t c_{Q_t} = W_{DQ}h_t cQt=WDQht
        q t C = W U Q c Q t q^C_t = W_{UQ}c_{Q_t} qtC=WUQcQt
      • 其中:
        • W D Q ∈ R d c ′ × d W_{DQ} \in \mathbb{R}^{d'_c\times d} WDQRdc×d 是用于查询的下投影矩阵。
        • W U Q ∈ R d h n h × d c ′ W_{UQ} \in \mathbb{R}^{d_hn_h\times d'_c} WUQRdhnh×dc 将压缩后的查询表示映射回其原始维度。
    • 多层查询优化
      • DeepSeek-R1通过额外的自适应缩放层来优化查询投影。
      • 在微调过程中,使用强化学习(RL)动态调整变换矩阵 W D Q W_{DQ} WDQ W U Q W_{UQ} WUQ

解耦旋转位置嵌入(RoPE)

为确保能够稳健地处理长文本内容,DeepSeek-R1以解耦的方式应用RoPE,将位置编码与潜在注意力机制分离。具体如下:

  • 键和查询的独立位置编码
    k R t = RoPE ( W K R h t ) k_{R_t} = \text{RoPE}(W_{KR}h_t) kRt=RoPE(WKRht)
    q R t = RoPE ( W Q R c Q t ) q_{R_t} = \text{RoPE}(W_{QR}c_{Q_t}) qRt=RoPE(WQRcQt)
    • 其中:
      • W K R ∈ R d r h × d W_{KR} \in \mathbb{R}^{d_{rh}\times d} WKRRdrh×d 用于生成键的位置嵌入。
      • W Q R ∈ R d r h n h × d c ′ W_{QR} \in \mathbb{R}^{d_{rh}n_h\times d'_c} WQRRdrhnh×dc 用于生成查询的位置嵌入。
      • RoPE变换可确保在保持键值缓存紧凑的同时,保留相对位置信息。
  • DeepSeek-R1中RoPE的计算效率
    • RoPE的应用被延迟到查询 - 键交互的最后阶段,从而避免不必要的内存占用。
    • 与DeepSeek-V2和V3相比,DeepSeek-R1的查询 - 键检索速度快25%。

多头潜在注意力(MLA)中的注意力计算

MLA中的最终注意力输出是通过在改进的注意力机制中整合压缩后的键、查询和值来计算的。具体如下:

  • 改进的注意力分数
    • 注意力分数是使用压缩后的潜在键和显式位置编码来计算的:
      A t , j , i = q t , i T k j , i d h + d R A_{t,j,i} = \frac{q^T_{t,i}k_{j,i}}{\sqrt{d_h + d_R}} At,j,i=dh+dR qt,iTkj,i
    • 该公式可确保位置嵌入对注意力强度的贡献成比例。
  • 加权值聚合
    • 注意力输出的计算方式如下:
      o t , i = ∑ j = 1 t Softmax j ( A t , j , i ) v j , i C o_{t,i} = \sum_{j = 1}^{t} \text{Softmax}_j(A_{t,j,i})v^C_{j,i} ot,i=j=1tSoftmaxj(At,j,i)vj,iC
    • Softmax操作会对序列中的注意力分数进行归一化处理。
  • 最终输出投影
    • 最终输出通过以下方式获得:
      u t = W O [ o t , 1 ; o t , 2 ; … ; o t , n h ] u_t = W_O[o_{t,1}; o_{t,2}; \ldots ; o_{t,n_h}] ut=WO[ot,1;ot,2;;ot,nh]
    • 其中:
      • W O W_O WO 是输出投影矩阵,用于将连接后的注意力输出映射回完整的嵌入空间。

强化学习优化的多头潜在注意力(MLA)

DeepSeek-R1整合了强化学习(RL)技术,以进一步优化MLA的变换矩阵。具体内容如下:

  • 基于强化学习的微调
    • 利用分组相对策略优化(GRPO),根据内存的高效使用和检索准确性对MLA进行奖励。
    • 策略更新公式为:
      J GRPO ( θ ) = E [ ∑ i = 1 G min ⁡ ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) A i , clip ( π θ ( o i ∣ q ) π θ old ( o i ∣ q ) , 1 − ϵ , 1 + ϵ ) A i ) ] J_{\text{GRPO}}(\theta) = \mathbb{E} \left[ \sum_{i = 1}^{G} \min \left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)} A_i, \text{clip} \left( \frac{\pi_{\theta}(o_i|q)}{\pi_{\theta_{\text{old}}}(o_i|q)}, 1 - \epsilon, 1 + \epsilon \right) A_i \right) \right] JGRPO(θ)=E[i=1Gmin(πθold(oiq)πθ(oiq)Ai,clip(πθold(oiq)πθ(oiq),1ϵ,1+ϵ)Ai)]
    • 其中:
      • π θ \pi_{\theta} πθ 表示更新后的策略。
      • A i A_i Ai 是引导优化的优势函数。
    • 更多详细信息可在“RL算法:分组相对策略优化(GRPO)”部分中找到。

计算和硬件优化

  • 推理时的效率
    • DeepSeek-R1中的MLA通过张量并行计算来实现,优化了跨GPU的吞吐量。
    • 通过低精度的键值(KV)存储(FP8格式),将内存开销降至最低。
  • 跨节点通信优化
    • 使用优化后的全连接通信内核,以充分利用InfiniBand(IB)和NVLink的带宽。
    • 将节点间的通信延迟降低30%,提升分布式推理性能。

效率对比分析

注意力机制 每个令牌的键值缓存 计算复杂度 性能影响
标准多头注意力(MHA) O ( N d h ) O(Nd_h) O(Ndh) O ( N 2 d h ) O(N^2d_h) O(N2dh) 高精度,高成本
多头查询注意力(MQA) O ( d h ) O(d_h) O(dh) O ( N d h ) O(Nd_h) O(Ndh) 内存占用低,性能下降
分组查询注意力(GQA) O ( g d h ) O(gd_h) O(gdh)(分组) O ( N d h ) O(Nd_h) O(Ndh) 性能平衡
MLA(DeepSeek-V2) O ( d L ) O(d_L) O(dL) O ( N d L ) O(Nd_L) O(NdL) 高效率,损失极小
MLA + 分层缓存(DeepSeek-R1) O ( d L ) O(d_L) O(dL)(可复用) O ( N d L ) O(Nd_L) O(NdL) 效率峰值,性能保持

你可能感兴趣的:(算法,人工智能,自然语言处理,DeepSeek)