普通Attention,FlashAttention和FlashAttention2的比较

FlashAttention、普通Attention和FlashAttention2的比较

一、普通Attention的实现与问题

普通Attention实现

# 标准注意力计算
def standard_attention(Q, K, V, mask=None):
    # 计算注意力分数
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 应用掩码
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    # 应用softmax获取权重
    weights = F.softmax(scores, dim=-1)
    
    # 计算加权和
    output = torch.matmul(weights, V)
    return output

主要问题

  1. 内存访问瓶颈:反复读写全局内存

    • Q、K、V矩阵
    • 注意力分数矩阵(N×N大小)
    • 注意力权重矩阵(softmax后)
    • 输出矩阵
  2. 内存瓶颈:对于长序列,注意力矩阵尺寸为O(N²)

    • 序列长度1K需要4MB
    • 序列长度10K需要400MB

二、FlashAttention的创新

FlashAttention通过重组计算并利用GPU内存层次结构解决上述问题:

关键技术

  1. 分块计算:将QKV矩阵分割成小块(blocks)

    对于每个Q块:
      对于每个K,V块:
        计算部分注意力
        在SRAM中累积结果
    
  2. 利用GPU内存层次

    • 高速SRAM(片上缓存):存储中间计算结果
    • 较慢的HBM(全局内存):仅用于读取输入和写入最终输出
  3. 重组softmax:使用数学技巧在线性空间内进行softmax操作

    softmax(x) = exp(x - m) / sum(exp(x - m))
    

    其中m是在线更新的最大值

代码示意

def flash_attention(Q, K, V, mask=None):
    # 将QKV分成块
    Q_blocks = split_into_blocks(Q)
    K_blocks = split_into_blocks(K)
    V_blocks = split_into_blocks(V)
    
    outputs = []
    # 对每个Q块分别计算
    for Qi in Q_blocks:
        # 在SRAM中保存部分结果
        O_i = torch.zeros_like(Qi)
        m_i = torch.ones(Qi.size(0), Qi.size(1)) * -float('inf')
        l_i = torch.zeros(Qi.size(0), Qi.size(1))
        
        # 对每个KV块累积计算
        for Kj, Vj in zip(K_blocks, V_blocks):
            # 计算部分注意力分数
            Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) / math.sqrt(d_k)
            
            # 处理掩码
            if mask is not None:
                # 只应用对应块的掩码
                block_mask = get_block_mask(mask, i, j)
                Sij = Sij.masked_fill(block_mask == 0, -float('inf'))
            
            # 更新最大值m
            m_new = torch.maximum(m_i, Sij.max(dim=-1)[0])
            
            # 更新累积和l
            l_i_new = l_i * torch.exp(m_i - m_new) + torch.sum(torch.exp(Sij - m_new.unsqueeze(-1)), dim=-1)
            
            # 更新输出
            O_i = (O_i * torch.exp(m_i - m_new).unsqueeze(-1) + 
                   torch.matmul(torch.exp(Sij - m_new.unsqueeze(-1)), Vj)) / l_i_new.unsqueeze(-1)
            
            # 更新m和l
            m_i, l_i = m_new, l_i_new
        
        outputs.append(O_i)
    
    # 合并所有输出块
    return torch.cat(outputs, dim=1)

性能优势

  1. 内存效率:不需要存储完整的N×N注意力矩阵
  2. IO效率:减少全局内存访问次数
  3. 速度提升:相比普通Attention快2-4倍
  4. 内存使用:O(N)而非O(N²)

三、FlashAttention 2的改进

def flash_attention_v2(Q, K, V, mask=None, causal=False, dropout_p=0.0):
    # 硬件感知的分块策略 - 根据GPU特性动态调整
    block_size_m = optimize_block_size_m()  # 根据GPU架构优化
    block_size_n = optimize_block_size_n()  # 为减少内存栅栏优化
    
    # 高效数据布局 - 为了提高内存访问效率
    Q = optimize_layout(Q)  # 改进的内存布局
    K = optimize_layout(K)  # 可能是以head维度交错排列
    V = optimize_layout(V)  # 优化GPU缓存命中率
    
    batch_size, num_heads, seq_len_q, head_dim = Q.shape
    _, _, seq_len_k, _ = K.shape
    
    # 检查是否可以使用非填充(varlen)版本
    if mask is not None and has_padding(mask):
        return flash_attention_v2_varlen(Q, K, V, mask, causal, dropout_p)
    
    # 预先计算缩放因子
    softmax_scale = 1.0 / math.sqrt(head_dim)
    
    # 根据形状分配输出空间
    O = torch.zeros_like(Q)
    
    # 对序列维度进行分块
    for q_start in range(0, seq_len_q, block_size_m):
        q_end = min(q_start + block_size_m, seq_len_q)
        
        # 在SRAM中保存的状态变量
        O_block = torch.zeros((batch_size, num_heads, q_end - q_start, head_dim), 
                              device=Q.device, dtype=Q.dtype)
        L_block = torch.zeros((batch_size, num_heads, q_end - q_start), 
                             device=Q.device, dtype=torch.float32)
        m_block = torch.ones((batch_size, num_heads, q_end - q_start), 
                            device=Q.device, dtype=torch.float32) * -float('inf')
        
        # KV缓存管理优化
        k_start_offset = 0 if not causal else q_start
        
        # 对KV序列维度进行分块
        for k_start in range(k_start_offset, seq_len_k, block_size_n):
            k_end = min(k_start + block_size_n, seq_len_k)
            
            # 应用因果掩码(优化的实现)
            if causal and q_start > k_start:
                # 整个块可以跳过 - 提前退出
                continue
            
            # 提取当前块
            Q_block = Q[:, :, q_start:q_end]
            K_block = K[:, :, k_start:k_end]
            V_block = V[:, :, k_start:k_end]
            
            # 分块注意力计算 - 使用优化的CUDA内核
            S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) * softmax_scale
            
            # 应用细粒度的因果掩码(仅当需要时)
            if causal and q_start <= k_end and k_start <= q_end:
                # 优化的因果掩码实现(无需显式掩码矩阵)
                causal_mask = create_optimized_causal_mask(q_start, q_end, k_start, k_end)
                S_block = apply_optimized_mask(S_block, causal_mask)
            
            # 应用注意力掩码(如果有)
            if mask is not None:
                block_mask = mask[:, :, q_start:q_end, k_start:k_end]
                S_block = S_block.masked_fill(block_mask == 0, -float('inf'))
            
            # 高效更新状态变量 - 降低重计算
            m_block_new = torch.maximum(m_block, S_block.max(dim=-1)[0])
            
            # 带稳定性优化的状态更新
            exp_factor = torch.exp(m_block - m_block_new.unsqueeze(-1))
            L_block_new = L_block.unsqueeze(-1) * exp_factor + torch.sum(
                torch.exp(S_block - m_block_new.unsqueeze(-1).unsqueeze(-1)), dim=-1)
            
            # 使用fused操作优化计算
            O_block = fused_softmax_and_matmul(
                O_block, S_block, V_block, 
                m_block, m_block_new, 
                L_block, L_block_new
            )
            
            # 更新聚合状态
            m_block, L_block = m_block_new, L_block_new
        
        # 完成当前Q块的计算,写回结果
        O[:, :, q_start:q_end] = O_block / L_block.unsqueeze(-1)
    
    # 应用dropout(优化的实现)
    if dropout_p > 0.0:
        O = optimized_dropout(O, dropout_p)
    
    return O

# 变长序列版本处理不等长输入
def flash_attention_v2_varlen(Q, K, V, mask, causal, dropout_p):
    # 解包批次中的序列长度信息
    batch_size, num_heads, max_seq_len_q, head_dim = Q.shape
    
    # 优化:分析掩码提取实际序列长度
    q_lens, k_lens = extract_sequence_lengths(mask)
    
    # 将填充的输入展平为压缩格式(去除填充)
    Q_unpad, Q_indices, cu_seqlens_q = unpad_input(Q, q_lens)
    K_unpad, K_indices, cu_seqlens_k = unpad_input(K, k_lens)
    V_unpad, _, _ = unpad_input(V, k_lens)
    
    # 使用优化的变长实现
    O_unpad = flash_attn_varlen_kernel(
        Q_unpad, K_unpad, V_unpad,
        cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q=max_seq_len_q,
        max_seqlen_k=max(k_lens),
        dropout_p=dropout_p,
        softmax_scale=1.0/math.sqrt(head_dim),
        causal=causal,
        return_attn_probs=False
    )
    
    # 将压缩输出恢复到原始填充格式
    O = pad_output(O_unpad, Q_indices, batch_size, max_seq_len_q)
    
    return O

# 优化的CUDA kernel的Python表示
def fused_softmax_and_matmul(O_block, S_block, V_block, m_old, m_new, L_old, L_new):
    # 实际上这会在单个优化的CUDA内核中执行
    # 以下是逻辑表示
    
    # 计算缩放因子
    exp_factor = torch.exp(m_old.unsqueeze(-1) - m_new.unsqueeze(-1).unsqueeze(-1))
    
    # 重新计算归一化输出
    O_scaled = O_block * exp_factor
    
    # 计算新的输出贡献
    P = torch.exp(S_block - m_new.unsqueeze(-1).unsqueeze(-1))
    O_new = torch.matmul(P, V_block)
    
    # 合并结果
    O_combined = O_scaled + O_new
    
    return O_combined

FlashAttention 2相比于原始FlashAttention的主要优化点包括:

  1. 数据布局优化

    • optimize_layout() 函数表示改进的内存布局,减少内存访问冲突
    • 优化了张量存储方式以提高缓存命中率
  2. 硬件感知分块

    • optimize_block_size_m()optimize_block_size_n() 动态选择最优块大小
    • 基于具体GPU架构特性调整计算方式
  3. 优化的因果掩码

    • 使用 create_optimized_causal_mask() 创建更高效的掩码
    • 对块之间的因果关系进行预检查,避免不必要的计算
  4. 融合操作

    • fused_softmax_and_matmul() 函数表示将多个操作融合为单个CUDA内核
    • 减少内核启动和中间结果的内存访问
  5. 变长序列支持

    • 专门优化的 flash_attention_v2_varlen() 函数高效处理带填充的序列
    • 使用 unpad_input()pad_output() 减少不必要的计算
  6. 减少同步障碍

    • 通过精心设计的计算顺序减少线程同步需求
    • 降低了内存栅栏操作的次数
  7. 精简的状态管理

    • 更高效地管理和更新 m_blockL_block 等中间状态
    • 减少了重复计算和冗余内存访问

这些优化使得FlashAttention 2在处理相同大小的输入时,相比原始FlashAttention速度提升约20-30%,并且在长序列上的优势更为显著。

四、性能对比

特性 普通Attention FlashAttention FlashAttention 2
内存复杂度 O(N²) O(N) O(N)
相对速度 1x 2-4x 2-6x
长序列支持 受限 较好 优秀
掩码处理 标准 优化 高度优化
硬件利用率
CUDA优化 基本 良好 深度优化
适用序列长度 <4K <64K <128K

注:FlashAttention2的相关内容请大家批评指正

你可能感兴趣的:(深度学习,人工智能)