FLOPs:Floating Point Operations 指的是浮点运算次数,一般特指乘加运算次数,理解为计算量,可以用来衡量算法/模型时间的复杂度。更大的计算量单位通常包括:
张量运算的计算量通常与运算维度和操作类型有关,以pytorch
中线性层nn.Linear
的计算为例,设输入张量的维度为 B × S × D B \times S \times D B×S×D,线性层内部权重矩阵维度为 D × O D \times O D×O:
bias
,两个张量相乘的结果维度为 B × S × O B \times S \times O B×S×O,结果中的每个元素是由原始张量分别沿着 D D D维度进行了 D D D次乘法和 D − 1 D-1 D−1次加法而来的,因此总计算量为:( 2 D − 1 ) × B × S × O (2D-1)\times B \times S \times O (2D−1)×B×S×O
bias
,则每个元素由原始张量分别沿着 D D D维度进行 D D D次乘法和 D − 1 D-1 D−1次加法后,还需加上bias
,因此一共也执行了 D D D次加法,总计算量为:2 D × B × S × O 2D \times B \times S \times O 2D×B×S×O
为了简单起见,后续分析时均以考虑bias
来分析,这样FLOPs的计算可直接由相关维度的相乘而来。
张量计算顺序的不同会影响计算量。以下是一个例子:
假设有三个张量 A A A、 B B B 和 C C C,它们的形状分别为:
我们需要计算 A × B × C A \times B \times C A×B×C,其中 × \times × 表示矩阵乘法。
计算顺序 1:先计算 A × B A \times B A×B,再乘以 C C C
计算顺序 2:先计算 B × C B \times C B×C,再乘以 A A A
比较两种计算顺序:
计算顺序 1的总计算量为 2 m n p + 2 m p q 2mnp + 2mpq 2mnp+2mpq。
计算顺序 2的总计算量为 2 n p q + 2 m n q 2npq + 2mnq 2npq+2mnq。
将上述两式相减,有:
2 [ m n ( p − q ) + p q ( m − n ) ] 2[mn(p-q)+pq(m-n)] 2[mn(p−q)+pq(m−n)]
可见如果 p < q , m < n pp<q,m<n
则必定计算顺序1的计算量更小,如果 p > q , m > n p>q,m>n p>q,m>n则反之,其余情况 则需根据具体数值分析。
我们比较三种计算顺序:
假设原始序列 h \mathbf{h} h经Q低秩压缩后得到 c Q \mathbf{c}^Q cQ,经KV低秩压缩得到 c K V \mathbf{c}^{KV} cKV,它们的上投影矩阵分别为 W U Q W^{UQ} WUQ和 W U K W^{UK} WUK。
原始注意力计算如下:
( W U Q c Q ) T ( W U K c K V ) (W^{UQ}\mathbf{c}^Q)^T (W^{UK}\mathbf{c}^{KV}) (WUQcQ)T(WUKcKV)
上述张量的形状如下,箭头右边是简记的符号,并将n_heads × qk_nope_head_dim
进行了拆分:
(q_lora_rank, n_heads × qk_nope_head_dim) -> (q, h, d)
(bsz, q_seq_len, q_lora_rank) -> (b, s, q)
(kv_lora_rank, n_heads × qk_nope_head_dim) -> (k, h, d)
(bsz, k_seq_len, kv_lora_rank) -> (b, t, k)
(bsz, q_seq_len, n_heads, qk_nope_head_dim) -> (b, s, h, d)
(bsz, k_seq_len, n_heads, qk_nope_head_dim) -> (b, t, h, d)
(bsz, n_heads, q_seq_len, k_seq_len) -> (b, h, s, t)
这里区分q_seq_len和k_seq_len,训练或prefill时二者是一致的,decode时q_seq_len是1,k_seq_len是cache的长度。
根据张量计算量分析的规则,计算量如下:
FLOPs order 1 = 2 b s h d q + 2 b t h d k + 2 b h s t d \text{FLOPs}_{\text{order}_1}=2bshdq+2bthdk+2bhstd FLOPsorder1=2bshdq+2bthdk+2bhstd
[ ( W U Q c Q ) T W U K ] c K V [(W^{UQ}\mathbf{c}^Q)^T W^{UK}]\mathbf{c}^{KV} [(WUQcQ)TWUK]cKV
(bsz, q_seq_len, n_heads, qk_nope_head_dim) -> (b, s, h, d)
(bsz, q_seq_len, n_heads, kv_lora_rank) -> (b, s, h, k)
(bsz, n_heads, q_seq_len, k_seq_len) -> (b, h, s, t)
计算量如下:
FLOPs order 2 = 2 b s h d q + 2 b s h k d + 2 b h s t k \text{FLOPs}_{\text{order}_2}=2bshdq+2bshkd+2bhstk FLOPsorder2=2bshdq+2bshkd+2bhstk
c Q T ( W U Q T W U K ) c K V {\mathbf{c}^Q}^T(W^{UQ^T} W^{UK})\mathbf{c}^{KV} cQT(WUQTWUK)cKV
(n_heads, q_lora_rank, kv_lora_rank) -> (h, q, k)
(bsz, q_seq_len, n_heads, kv_lora_rank) -> (b, s, h, k)
(bsz, n_heads, q_seq_len, k_seq_len) -> (b, h, s, t)
计算量如下:
FLOPs order 3 = 2 h q k d + 2 b s h k q + 2 b h s t k \text{FLOPs}_{\text{order}_3}=2hqkd+2bshkq+2bhstk FLOPsorder3=2hqkd+2bshkq+2bhstk
首先比较 FLOPs order 1 \text{FLOPs}_{\text{order}_1} FLOPsorder1和 FLOPs order 2 \text{FLOPs}_{\text{order}_2} FLOPsorder2,有:
FLOPs order 1 − FLOPs order 2 = 2 b h d k ( t − s ) + 2 b h s t ( d − k ) \text{FLOPs}_{\text{order}_1}-\text{FLOPs}_{\text{order}_2}= 2bhdk(t-s)+2bhst(d-k) FLOPsorder1−FLOPsorder2=2bhdk(t−s)+2bhst(d−k)
其中:
t
:k_seq_len
s
:q_seq_len
d
:qk_nope_head_dim = 128
k
:kv_lora_rank = 512
h
:n_heads = 128
b
:bsz
由于第一项和第二项都有b
,为简单起见,设为1在训练或prefill阶段,t
=s
,上式结果为 − 98304 s 2 -98304s^2 −98304s2,此时顺序1的计算量更优。
在decode阶段,t
是缓存长度,而s
=1,上式结果为 16777216 ( t − 1 ) − 98304 t = 16678912 t − 16777216 16777216(t-1)-98304t=16678912t-16777216 16777216(t−1)−98304t=16678912t−16777216,可见,推理时随着缓存长度t
的变大,顺序1需要花费更大的计算量,因此才需要把 W U K W^{UK} WUK吸收进 W U Q c Q W^{UQ}\mathbf{c}^Q WUQcQ(也就是代码中的q_nope
)中,避免产生的中间量需要大量的计算。
然后比较 FLOPs order 2 \text{FLOPs}_{\text{order}_2} FLOPsorder2和 FLOPs order 3 \text{FLOPs}_{\text{order}_3} FLOPsorder3,有:
FLOPs order 2 − FLOPs order 3 = 2 h d q ( b s − k ) + 2 b s h k ( d − q ) \text{FLOPs}_{\text{order}_2}-\text{FLOPs}_{\text{order}_3}= 2hdq(bs-k)+2bshk(d-q) FLOPsorder2−FLOPsorder3=2hdq(bs−k)+2bshk(d−q)
其中:
q
:q_lora_rank = 1536
b
:bsz
第一项的b
无法作为因子提出,因此先不假定具体值上式结果中不包含t
,结果为 50331648 ( b s − 512 ) − 184549376 b s = − 134217728 b s − 25769803776 50331648(bs-512)-184549376bs=-134217728bs-25769803776 50331648(bs−512)−184549376bs=−134217728bs−25769803776,恒小于0,因此顺序2的计算量优于顺序3。其原因是 ( W U Q T W U K ) (W^{UQ^T} W^{UK}) (WUQTWUK)充当了新的 W U Q ′ W^{UQ'} WUQ′,其形状为(h, q, k)
,具有100663296个元素。而 W U Q W^{UQ} WUQ和 W U K W^{UK} WUK的形状分别为(q, h, d)
和(k, h, d)
,二者之和只有33554432个元素,约为 W U Q ′ W^{UQ'} WUQ′的33%,这就解释了虽然公式上直接将 W U K W^{UK} WUK吸收进了 W U Q W^{UQ} WUQ,但为什么代码实现上不这么做的原因。不论是从参数量占用还是计算量上,顺序3都没有优势。
同样比较三种计算顺序:
假设得到的score
形状大小为(bsz, n_heads, q_seq_len, k_seq_len)
, c K V \mathbf{c}^{KV} cKV向value
的上投影矩阵为 W U V W^{UV} WUV,输出维度变换 矩阵为 W O W^O WO。
原始计算顺序如下:
W O [ s c o r e ( W U V c K V ) ] W^O[score(W^{UV} \mathbf{c}^{KV})] WO[score(WUVcKV)]
上述张量的形状如下,将n_heads × v_head_dim
进行了拆分:
(bsz, n_heads, q_seq_len, k_seq_len) -> (b, h, s, t)
(bsz, k_seq_len, kv_lora_rank) -> (b, t, k)
(kv_lora_rank, n_heads × v_head_dim) -> (k, h, v)
(n_heads × v_head_dim, dim) -> (h, v, e)
(bsz, k_seq_len, n_heads, v_head_dim) -> (b, t, h, v)
(bsz, n_heads, q_seq_len, v_head_dim) -> (b, h, s, v)
(bsz, n_heads, q_seq_len, dim) -> (b, h, s, e)
计算量如下:
FLOPs order 1 = 2 b t h v k + 2 b h s v t + 2 b h s e v \text{FLOPs}_{\text{order}_1}=2bthvk+2bhsvt+2bhsev FLOPsorder1=2bthvk+2bhsvt+2bhsev
W O [ W U V ( s c o r e c K V ) ] W^O[W^{UV} (score\mathbf{c}^{KV})] WO[WUV(scorecKV)]
(bsz, n_heads, q_seq_len, kv_lora_rank) -> (b, h, s, k)
(bsz, n_heads, q_seq_len, v_head_dim) -> (b, h, s, v)
(bsz, n_heads, q_seq_len, dim) -> (b, h, s, e)
计算量如下:
FLOPs order 2 = 2 b h s k t + 2 b h s v k + 2 b h s e v \text{FLOPs}_{\text{order}_2}=2bhskt+2bhsvk+2bhsev FLOPsorder2=2bhskt+2bhsvk+2bhsev
( W O W U V ) ( s c o r e c K V ) (W^OW^{UV})(score\mathbf{c}^{KV}) (WOWUV)(scorecKV)
(n_heads, kv_lora_rank, dim) -> (h, k, e)
(bsz, n_heads, q_seq_len, kv_lora_rank) -> (b, h, s, k)
(bsz, n_heads, q_seq_len, dim) -> (b, h, s, e)
计算量如下:
FLOPs order 3 = 2 h k e v + 2 b h s k t + 2 b h s e k \text{FLOPs}_{\text{order}_3}=2hkev+2bhskt+2bhsek FLOPsorder3=2hkev+2bhskt+2bhsek
首先比较 FLOPs order 1 \text{FLOPs}_{\text{order}_1} FLOPsorder1和 FLOPs order 2 \text{FLOPs}_{\text{order}_2} FLOPsorder2,有:
FLOPs order 1 − FLOPs order 2 = 2 b h v k ( t − s ) + 2 b h s t ( v − k ) \text{FLOPs}_{\text{order}_1}-\text{FLOPs}_{\text{order}_2}=2bhvk(t-s)+2bhst(v-k) FLOPsorder1−FLOPsorder2=2bhvk(t−s)+2bhst(v−k)
其中:
t
:k_seq_len
s
:q_seq_len
v
:v_head_dim = 128
k
:kv_lora_rank = 512
h
:n_heads = 128
b
:bsz
由于第一项和第二项都有b
,为简单起见,设为1由于v
与d
值大小一样,因此计算结果与与第一次矩阵吸收一致。即在训练或prefill阶段,顺序1更优,在decode阶段,顺序2更优。
然后比较 FLOPs order 2 \text{FLOPs}_{\text{order}_2} FLOPsorder2和 FLOPs order 3 \text{FLOPs}_{\text{order}_3} FLOPsorder3,有:
FLOPs order 2 − FLOPs order 3 = 2 h v k ( b s − e ) + 2 b h s e ( v − k ) \text{FLOPs}_{\text{order}_2}-\text{FLOPs}_{\text{order}_3}=2hvk(bs-e)+2bhse(v-k) FLOPsorder2−FLOPsorder3=2hvk(bs−e)+2bhse(v−k)
其中:
e
:dim = 7168
b
:bsz
第一项的b
无法作为因子提出,因此先不假定具体值上式结果为 16777216 ( b s − 7168 ) − 704643072 b s = − 687865856 b s − 120259084288 16777216(bs-7168)-704643072bs=-687865856bs −120259084288 16777216(bs−7168)−704643072bs=−687865856bs−120259084288,可见仍然是顺序2的计算结果更优。