FlashAttention是一种IO感知精确注意力算法。通过感知显存读取/写入,FlashAttention的运行速度比PyTorch标准Attention快了2-4倍,所需内存也仅是其5%-20%。随着Transformer变得越来越大、越来越深,但它在长序列上仍然处理的很慢、且耗费内存。(自注意力时间和显存复杂度与序列长度成二次方),现有近似注意力方法,在试图通过去牺牲模型质量,以降低计算复杂度来解决该问题。但存在一定的局限性,即不能提升运行时的训练速度。
基于这样子背景,研究人员觉得应该让注意力算法具有IO感知,即考虑显存级间的读写,比如大但慢的HBM(High Bandwidth Memory)技术与小但快的SRAM。基于这样的背景,研究人员提出了FlashAttention,具体有两种加速技术:
Flash Attention主要运用了两种技术:分块和算子融合计算。
分块即将输入划分成块,并单独处理每个块的注意力;
算子融合,将多个算子融合成一个算子,将所有注意力操作融合到CUDA内核中。
在深度学习中,Attention Mechanism(注意力机制)用于在序列数据中学习不同位置的重要性权重,以便更好地捕捉序列中的关键信息。Attention Scores(注意力分数)是用来计算这些权重的指标。以下是关于计算Attention Scores的推导过程的文档:
先上经典的示意图
在注意力机制中,Q(Query)、K(Key)、V(Value)和softmax函数是核心概念,用于计算注意力权重并生成输出表示。
Query(查询)、Key(键)、Value(值)
Query Q(查询):查询是用来寻找与之相关性的键的向量。在注意力机制中,查询向量Q通常是通过对输入数据进行线性变换得到的,用于衡量查询与键的相似度,决定了在生成输出表示时每个键的重要程度。
Key K(键):键是用来表示输入数据的向量,与查询一起计算注意力权重。键向量K也是通过对输入数据进行线性变换得到的。
Value V(值):值是用来生成最终输出表示的向量。值向量V同样是通过对输入数据进行线性变换得到的,根据注意力权重对值进行加权求和得到最终输出。
Softmax函数
Softmax函数是一个常用的激活函数,通常用于多分类问题中,将输入转化为概率分布。在注意力机制中,softmax函数被用来计算注意力权重,将注意力分数转化为概率分布,使得不同位置的输入在生成输出时得到不同的权重。
标准的Attention实现算法如下。
先给出整体的的代码框架:
为了更容易理解K、V、Q的分块计算推导,我们可以将其分解为几个关键步骤,并用直观的方式解释每个步骤的目的和计算方法。以下是基于FlashAttention V1的分块计算的核心推导:
首先,我们将输入序列的查询(Query)矩阵Q、键(Key)矩阵的转置K^T和值(Value)矩阵V分别进行分块。
Q分为Tr个块,每个块的大小为(Br, d),其中Br是块的长度,d是向量的维度。
K^T分为Tc个块,每个块的大小为(d, Bc)。
V分为Tc个块,每个块的大小为(Bc, d)。
对于每一对分块(Q_i, K_j^T),我们计算它们的点积来得到局部注意力分数S_ij:
S_ij = Q_i * K_j^T
这里的S_ij是一个(Br, Bc)的矩阵,表示第i个Q块与第j个K块的转置之间的相关性分数。
为了避免在计算softmax函数时发生数值溢出,我们使用安全softmax方法。对于每个S_ij,我们首先计算行的最大值m(S_ij)和行的指数和l(S_ij),然后应用到每一行上进行归一化:
使用归一化前的注意力分P_ij和对应的V_j块,我们可以计算输出O_ij:
这里的O_ij是一个(Br, d)的矩阵,表示通过注意力机制加权后的输出。
由于我们是对Q、K、V进行了分块计算,每个O_ij只是部分输出。为了得到最终的输出O,我们需要将所有分块的输出合并起来。这通常涉及到对所有O_ij块的累加操作。
在FlashAttention中,关键的优化是减少对显存的访问次数。在分块计算中,我们只加载必要的Q_i、K_j^T和V_j块到GPU的片上内存(SRAM),计算完成后直接使用,避免了频繁地从片下内存(HBM)读取和写入大型矩阵。
通过这种方式,FlashAttention能够在保持精确计算的同时,显著提高大模型的计算效率,尤其是在处理长序列时。分块计算不仅减少了内存占用,还降低了内存带宽的需求,从而提高了整体的计算性能。
分块向上取整是为了分块的存够大,有空间冗余,不会丢失数据, 保证同时加载四个Q, K, V, O tensor的时候,内存不爆掉。
這個Mij的維度正常理解每一行的最大值构成的向量。所以维度是行数。
分块计算是按照br的纵向进行分块计算的。
第十步是为了防止某个数过大溢出,然后进行的操作技巧。计算归一化前的一个分块的结果。
一个外圈的第j个循环,得到一个完整的O(n*d), 我们记作O_j, 第j+1次时候,会在 O_j 的基础上继续更新, 得到O_{j+1}, 最终所有 j 循环结束, 得到最新的完整的 O, 就是最终的结果,这个 j 循环过程中, 计算出了Tc 个完整的O(n*d), 每次都在原有基础上更新前面的O.
trick:
所有以 ij 作为下标的,都表示当前分块的计算结果
所有以 i作为下标的,都表示截止到前一个分块(包含前一个分块)的计算结果
所有以 new 为上标的,都表示引入当前分块做更新后的结果
所有没有下标的,都表示全局结果
分块示意图:
前两小节是伪代码的11, 12 的计算定义计算的背景。
前两小节是伪代码的11, 12 的计算定义计算的背景。
f(x)的计算,可以构造
第一步部分:f(x__1) = exp(x_1 - m(x_1) + m(x_1) - m(x)) = f(x_1)* exp(m(x_1) - m(x))
第二部分:f(x__2) = exp(x_2 - m(x_2) + m(x_2) - m(x)) = f(x_2)* exp(m(x_2) - m(x))
得到f(x) =[f(x__1), f(x__2)]
最終使用分塊的部分值
f(x_1): 分块[x_1]的attention的分子部分
f(x_2):分块[x_2]的attention的分子部分
l(x_1):分块[x_1]的attention的分母部分
l(x_2):分块[x_2]的attention的分母部分
m(x_1):分块[x_1]的最大值
m(x_2):分块[x_2]的最大值
m(x):[x_1, x_2]全局向量的最大值
表示了全局f(x)的attention 分數計算了。这个也是为了后面的遍历分块计算,使用各个子块的数据,计算得到全局的attention分数。
这一节就是对应伪代码的11,12 行的原理。
这里我们换个思路: O_i 不是每遍历一块就更新一次吗?那有没有一种办法,不断用当前最新的rowmax和rowsum去更新 O_i ,直到遍历完最后一块,这时的 O_i不就和标准场景下的结果完全一致了吗?也就是我们想构造形如下面这样的更新等式:
暂且接受这个公式说的,这个样子就得到了伪代码的第12行。
diag(l_j+1): 传入的l_j+1 是每一行的p_ij构成的向量,使用diag 函数,将它构建成一个方阵,对角线的元素是l_j+1的值,其实就是构建和 j+1的方阵,每一行只有对角线的位置一个元素。
对角矩阵的逆可以就是将对角元素的倒数(1/m)构成的对角方阵。那刚好就是attention score的分母
O_i * () 表示的是之前计算的O_i的attention score, pij * vj 这一部分是最新的一个循环计算的attention score。
上面是分块的各个变量的示意图和计算的步骤过程,最终计算得到
各个分块的各个行最大值构成的向量:m_ij
各个分块的每一个值和该行最大值max(x)的exp(x_ij-max(x))矩阵:p_ij
各个分块各行的p_ij和构成的向量l_ij。
根据上面推导的公式,就可以进行最小粒度分块的attention score的计算,再按照循环计算所有分块的值,当遍历结束,就可以将各个分块的值转成全局score的计算。
将各个分块的的m_ij的保存,得到Tr 个m, 每个长度是Br 。是为了下一次的计算
将各个分块的的l_ij的保存,得到Tr 个l, 每个长度是Br 。是为了下一次的计算
将各个O_i 保存,得到Tr 个O, 每个维度是Br*d. 就是每i 遍历一次,使用之前的保存的O_i和本次计算的P_ij * V_ij 的值更新一遍O_i, 当遍历i 结束的时候,第一遍的O_i的值全部更新完毕,得到第一遍的初步的N*d的全局O_i. 当 j 遍历结束的时,就刷新了j 遍O值,得到了最終的O值。
j=1, i 循环,得到第一遍的O1, O2, ....ON.
j=2, i循环,更新第二遍的O1, O2, ....ON.
j = n, i循环结束,计算完成,得到更新第j 遍的O1, O2, ....ON., 就是最终的结果。
可以理解是j次的纵向的更新O. 其中 l,m pij 在内循环中随着i增大,矩阵元素纵向的更新。
把FlashAttention V1的伪代码贴到这里,便于与上面的分析进行比对(需要求max, 所以I 之间不能并行)
kernel融合和尽可能利用起SRAM,以减少数据读取时间,都是flash attention的重要优化点。分块之后flash attention将矩阵乘法、mask、softmax、dropout操作合并成一个kernel,做到了只读一次和只写回一次,节省了数据读取时间
一句话总结就是:将k,v, q的矩阵进行分块,然后逐个块拷贝到SRAM 中,利用SRAM 高速的读写速度,快速的进行k,q矩阵乘法,再使用 v和注意力分数矩阵计算,然后将计算得到的结果,直接写到HBM中。(如果能做到并行的拷贝块到SRAM 中,速度更多,但是需要进行更细致的分块的索引处理和考虑SRAM的并行计算的内存容量)。
Flash Attention 2
比Flash Attention 1
加速2x
, 计算效率达到GEMM性能的50~73%
改进1:减少非乘法计算
改进2: 优化QKV for
循环顺序
改进3: 采用shared memory
减少通信(q 分块之间需要相互数据通信)
并行:
i 之间的并行(Q),j 之间有依赖,不能并行。
一个Vj内部的矩阵计算并行, Vj与Vj-1之间不能并行。
错误的公式:
Flash Attention 1
最大的问题在于O[i]
需要频繁的读写SRAM
缓存,如左图 ,那么考虑O[i]
的在一个Q[i]
周期算完, 就可以减少O[i]
的频繁读写缓存
Flash Attention 2
将Q
当作外循环,KV
当作内循环, 将O[i]
的在一个Q[i]
周期算完。
就是v1的内外循环的互换,结果就是外循环一次,计算得到一最终的Oi.
j=1, i 循环,得到第一个的完整O1,
j=2, i循环,得到第二个的完整O2,
j = n, i循环结束,计算完成得到第n个的完整On, 就是最终的结果.
[O1,
O2,
O3,
On]
可以理解是j遍的横向更新O. 其中 l, m pij 在外循环, j次更新。
下图可以直接对比出来,Flash Attnetion 1
每次都要做Scaled
(对O * l^-1 ),都是额外的非乘法计算。
在Flash Attention中,计算O时存在非乘法(non-matmul)的计算,非常耗时,计算O
的过程un-scaled
,对O
最后时刻做scaled
.
跳过绝对需要完全mask
的block
的计算,也就是這些位置的值不進行計算。
这个L ,对于我们只有前向推理,也不需要保存