大模型-FlashAttention 算法分析

一、FlashAttention的概述

FlashAttention是一种IO感知精确注意力算法。通过感知显存读取/写入,FlashAttention的运行速度比PyTorch标准Attention快了2-4倍,所需内存也仅是其5%-20%。随着Transformer变得越来越大、越来越深,但它在长序列上仍然处理的很慢、且耗费内存。(自注意力时间和显存复杂度与序列长度成二次方),现有近似注意力方法,在试图通过去牺牲模型质量,以降低计算复杂度来解决该问题。但存在一定的局限性,即不能提升运行时的训练速度。

基于这样子背景,研究人员觉得应该让注意力算法具有IO感知,即考虑显存级间的读写,比如大但慢的HBM(High Bandwidth Memory)技术与小但快的SRAM。基于这样的背景,研究人员提出了FlashAttention,具体有两种加速技术:

Flash Attention主要运用了两种技术:分块和算子融合计算。

  1. 分块即将输入划分成块,并单独处理每个块的注意力;

  2. 算子融合,将多个算子融合成一个算子,将所有注意力操作融合到CUDA内核中。

在深度学习中,Attention Mechanism(注意力机制)用于在序列数据中学习不同位置的重要性权重,以便更好地捕捉序列中的关键信息。Attention Scores(注意力分数)是用来计算这些权重的指标。以下是关于计算Attention Scores的推导过程的文档:

二、FlashAttention v1具体原理的探究

先上经典的示意图

 2.1 注意力机制快速了解

在注意力机制中,Q(Query)、K(Key)、V(Value)和softmax函数是核心概念,用于计算注意力权重并生成输出表示。

Query(查询)、Key(键)、Value(值)

  • Query Q(查询):查询是用来寻找与之相关性的键的向量。在注意力机制中,查询向量Q通常是通过对输入数据进行线性变换得到的,用于衡量查询与键的相似度,决定了在生成输出表示时每个键的重要程度。

  • Key K(键):键是用来表示输入数据的向量,与查询一起计算注意力权重。键向量K也是通过对输入数据进行线性变换得到的。

  • Value V(值):值是用来生成最终输出表示的向量。值向量V同样是通过对输入数据进行线性变换得到的,根据注意力权重对值进行加权求和得到最终输出。

Softmax函数

Softmax函数是一个常用的激活函数,通常用于多分类问题中,将输入转化为概率分布。在注意力机制中,softmax函数被用来计算注意力权重,将注意力分数转化为概率分布,使得不同位置的输入在生成输出时得到不同的权重。

标准的Attention实现算法如下。

 

 

2.2 Flash Attention V1

先给出整体的的代码框架:

 为了更容易理解K、V、Q的分块计算推导,我们可以将其分解为几个关键步骤,并用直观的方式解释每个步骤的目的和计算方法。以下是基于FlashAttention V1的分块计算的核心推导:

1. 核心步驟总结-分块

首先,我们将输入序列的查询(Query)矩阵Q、键(Key)矩阵的转置K^T和值(Value)矩阵V分别进行分块。

  • Q分为Tr个块,每个块的大小为(Br, d),其中Br是块的长度,d是向量的维度。

  • K^T分为Tc个块,每个块的大小为(d, Bc)。

  • V分为Tc个块,每个块的大小为(Bc, d)。

2. 分块计算注意力分数(Attention Score)

对于每一对分块(Q_i, K_j^T),我们计算它们的点积来得到局部注意力分数S_ij:

S_ij = Q_i * K_j^T

这里的S_ij是一个(Br, Bc)的矩阵,表示第i个Q块与第j个K块的转置之间的相关性分数。

3. 安全Softmax(Safe Softmax)

为了避免在计算softmax函数时发生数值溢出,我们使用安全softmax方法。对于每个S_ij,我们首先计算行的最大值m(S_ij)和行的指数和l(S_ij),然后应用到每一行上进行归一化:

4. 计算输出(Output)

使用归一化前的注意力分P_ij和对应的V_j块,我们可以计算输出O_ij:

这里的O_ij是一个(Br, d)的矩阵,表示通过注意力机制加权后的输出。

5. 分块输出的聚合

由于我们是对Q、K、V进行了分块计算,每个O_ij只是部分输出。为了得到最终的输出O,我们需要将所有分块的输出合并起来。这通常涉及到对所有O_ij块的累加操作。

6. 优化内存访问

在FlashAttention中,关键的优化是减少对显存的访问次数。在分块计算中,我们只加载必要的Q_i、K_j^T和V_j块到GPU的片上内存(SRAM),计算完成后直接使用,避免了频繁地从片下内存(HBM)读取和写入大型矩阵。

通过这种方式,FlashAttention能够在保持精确计算的同时,显著提高大模型的计算效率,尤其是在处理长序列时。分块计算不仅减少了内存占用,还降低了内存带宽的需求,从而提高了整体的计算性能。

7. 细节点
  1. 分块向上取整是为了分块的存够大,有空间冗余,不会丢失数据, 保证同时加载四个Q, K, V, O tensor的时候,内存不爆掉。

  2. 這個Mij的維度正常理解每一行的最大值构成的向量。所以维度是行数。

  3. 分块计算是按照br的纵向进行分块计算的。

  4. 第十步是为了防止某个数过大溢出,然后进行的操作技巧。计算归一化前的一个分块的结果。

  5. 一个外圈的第j个循环,得到一个完整的O(n*d), 我们记作O_j, 第j+1次时候,会在 O_j 的基础上继续更新, 得到O_{j+1}, 最终所有 j 循环结束, 得到最新的完整的 O, 就是最终的结果,这个 j 循环过程中, 计算出了Tc 个完整的O(n*d), 每次都在原有基础上更新前面的O.

    1. trick:

    • 所有以 ij 作为下标的,都表示当前分块的计算结果

    • 所有以 i作为下标的,都表示截止到前一个分块(包含前一个分块)的计算结果

    • 所有以 new 为上标的,都表示引入当前分块做更新后的结果

    • 所有没有下标的,都表示全局结果

     分块示意图:

     

    2.4.1 分块推理

    前两小节是伪代码的11, 12 的计算定义计算的背景。

    2.4.1.1 分块的分子的计算定义
    2.4.1 分块推理

    前两小节是伪代码的11, 12 的计算定义计算的背景。

    2.4.1.1 分块的分子的计算定义

     f(x)的计算,可以构造

    第一步部分:f(x__1) = exp(x_1 - m(x_1) + m(x_1) - m(x)) = f(x_1)* exp(m(x_1) - m(x))

    第二部分:f(x__2) = exp(x_2 - m(x_2) + m(x_2) - m(x)) = f(x_2)* exp(m(x_2) - m(x))

    得到f(x) =[f(x__1), f(x__2)]

    2.4.1.2 分块的分母的计算定义

     

    最終使用分塊的部分值

    f(x_1): 分块[x_1]的attention的分子部分

    f(x_2):分块[x_2]的attention的分子部分

    l(x_1):分块[x_1]的attention的分母部分

    l(x_2):分块[x_2]的attention的分母部分

    m(x_1):分块[x_1]的最大值

    m(x_2):分块[x_2]的最大值

    m(x):[x_1, x_2]全局向量的最大值

    表示了全局f(x)的attention 分數計算了。这个也是为了后面的遍历分块计算,使用各个子块的数据,计算得到全局的attention分数。

    2.4.1.3 使用分块计算全局的计算推理

    这一节就是对应伪代码的11,12 行的原理。

    这里我们换个思路: O_i 不是每遍历一块就更新一次吗?那有没有一种办法,不断用当前最新的rowmax和rowsum去更新 O_i ,直到遍历完最后一块,这时的 O_i不就和标准场景下的结果完全一致了吗?也就是我们想构造形如下面这样的更新等式:

    暂且接受这个公式说的,这个样子就得到了伪代码的第12行。

    diag(l_j+1): 传入的l_j+1 是每一行的p_ij构成的向量,使用diag 函数,将它构建成一个方阵,对角线的元素是l_j+1的值,其实就是构建和 j+1的方阵,每一行只有对角线的位置一个元素。

    对角矩阵的逆可以就是将对角元素的倒数(1/m)构成的对角方阵。那刚好就是attention score的分母

    O_i * () 表示的是之前计算的O_i的attention score, pij * vj 这一部分是最新的一个循环计算的attention score。

    2.5 分块的计算示意图总结

     

    上面是分块的各个变量的示意图和计算的步骤过程,最终计算得到

    1. 各个分块的各个行最大值构成的向量:m_ij

    2. 各个分块的每一个值和该行最大值max(x)的exp(x_ij-max(x))矩阵:p_ij

    3. 各个分块各行的p_ij和构成的向量l_ij。

    4. 根据上面推导的公式,就可以进行最小粒度分块的attention score的计算,再按照循环计算所有分块的值,当遍历结束,就可以将各个分块的值转成全局score的计算。

    5. 将各个分块的的m_ij的保存,得到Tr 个m, 每个长度是Br 。是为了下一次的计算

    6. 将各个分块的的l_ij的保存,得到Tr 个l, 每个长度是Br 。是为了下一次的计算

    7. 将各个O_i 保存,得到Tr 个O, 每个维度是Br*d. 就是每i 遍历一次,使用之前的保存的O_i和本次计算的P_ij * V_ij 的值更新一遍O_i, 当遍历i 结束的时候,第一遍的O_i的值全部更新完毕,得到第一遍的初步的N*d的全局O_i. 当 j 遍历结束的时,就刷新了j 遍O值,得到了最終的O值。

    j=1, i 循环,得到第一遍的O1, O2, ....ON.

    j=2, i循环,更新第二遍的O1, O2, ....ON.

    j = n, i循环结束,计算完成,得到更新第j 遍的O1, O2, ....ON., 就是最终的结果。

    可以理解是j次的纵向的更新O. 其中 l,m pij 在内循环中随着i增大,矩阵元素纵向的更新。

    把FlashAttention V1的伪代码贴到这里,便于与上面的分析进行比对(需要求max, 所以I 之间不能并行)

    2.6 Flashv1总结

    kernel融合和尽可能利用起SRAM,以减少数据读取时间,都是flash attention的重要优化点。分块之后flash attention将矩阵乘法、mask、softmax、dropout操作合并成一个kernel,做到了只读一次和只写回一次,节省了数据读取时间

    一句话总结就是:将k,v, q的矩阵进行分块,然后逐个块拷贝到SRAM 中,利用SRAM 高速的读写速度,快速的进行k,q矩阵乘法,再使用 v和注意力分数矩阵计算,然后将计算得到的结果,直接写到HBM。(如果能做到并行的拷贝块到SRAM 中,速度更多,但是需要进行更细致的分块的索引处理和考虑SRAM的并行计算的内存容量)。

    三、Flash Attention V2

    1. Flash Attention 2Flash Attention 1加速2x, 计算效率达到GEMM性能的50~73%

    2. 改进1:减少非乘法计算

    3. 改进2: 优化QKV for循环顺序

    4. 改进3: 采用shared memory减少通信(q 分块之间需要相互数据通信)

    并行:

    1. i 之间的并行(Q),j 之间有依赖,不能并行。

    2. 一个Vj内部的矩阵计算并行, Vj与Vj-1之间不能并行。

    错误的公式:

     

     

    3.1 for循环的优化

    • Flash Attention 1 最大的问题在于O[i]需要频繁的读写SRAM 缓存,如左图 ,那么考虑O[i]的在一个Q[i]周期算完, 就可以减少O[i]的频繁读写缓存

    • Flash Attention 2Q当作外循环,KV当作内循环, 将O[i]的在一个Q[i]周期算完。

    • 就是v1的内外循环的互换,结果就是外循环一次,计算得到一最终的Oi.

    j=1, i 循环,得到第一个的完整O1,

    j=2, i循环,得到第二个的完整O2,

    j = n, i循环结束,计算完成得到第n个的完整On, 就是最终的结果.

    [O1,

    O2,

    O3,

    On]

    可以理解是j遍的横向更新O. 其中 l, m pij 在外循环, j次更新。

     

    3.2 减少非矩阵乘法运算

    下图可以直接对比出来,Flash Attnetion 1 每次都要做Scaled (对O * l^-1 ),都是额外的非乘法计算。

    在Flash Attention中,计算O时存在非乘法(non-matmul)的计算,非常耗时,计算O的过程un-scaled,对O最后时刻做scaled.

    3.3 忽略mask block的Attention计算

    • 跳过绝对需要完全maskblock的计算,也就是這些位置的值不進行計算。

    3.4 L的不保存

    这个L ,对于我们只有前向推理,也不需要保存

     

     

    你可能感兴趣的:(大模型推理,算法,算法,大模型推理,LLM,flashattention)