# 标准注意力计算
def standard_attention(Q, K, V, mask=None):
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# 应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 应用softmax获取权重
weights = F.softmax(scores, dim=-1)
# 计算加权和
output = torch.matmul(weights, V)
return output
内存访问瓶颈:反复读写全局内存
内存瓶颈:对于长序列,注意力矩阵尺寸为O(N²)
FlashAttention通过重组计算并利用GPU内存层次结构解决上述问题:
分块计算:将QKV矩阵分割成小块(blocks)
对于每个Q块:
对于每个K,V块:
计算部分注意力
在SRAM中累积结果
利用GPU内存层次:
重组softmax:使用数学技巧在线性空间内进行softmax操作
softmax(x) = exp(x - m) / sum(exp(x - m))
其中m是在线更新的最大值
def flash_attention(Q, K, V, mask=None):
# 将QKV分成块
Q_blocks = split_into_blocks(Q)
K_blocks = split_into_blocks(K)
V_blocks = split_into_blocks(V)
outputs = []
# 对每个Q块分别计算
for Qi in Q_blocks:
# 在SRAM中保存部分结果
O_i = torch.zeros_like(Qi)
m_i = torch.ones(Qi.size(0), Qi.size(1)) * -float('inf')
l_i = torch.zeros(Qi.size(0), Qi.size(1))
# 对每个KV块累积计算
for Kj, Vj in zip(K_blocks, V_blocks):
# 计算部分注意力分数
Sij = torch.matmul(Qi, Kj.transpose(-2, -1)) / math.sqrt(d_k)
# 处理掩码
if mask is not None:
# 只应用对应块的掩码
block_mask = get_block_mask(mask, i, j)
Sij = Sij.masked_fill(block_mask == 0, -float('inf'))
# 更新最大值m
m_new = torch.maximum(m_i, Sij.max(dim=-1)[0])
# 更新累积和l
l_i_new = l_i * torch.exp(m_i - m_new) + torch.sum(torch.exp(Sij - m_new.unsqueeze(-1)), dim=-1)
# 更新输出
O_i = (O_i * torch.exp(m_i - m_new).unsqueeze(-1) +
torch.matmul(torch.exp(Sij - m_new.unsqueeze(-1)), Vj)) / l_i_new.unsqueeze(-1)
# 更新m和l
m_i, l_i = m_new, l_i_new
outputs.append(O_i)
# 合并所有输出块
return torch.cat(outputs, dim=1)
def flash_attention_v2(Q, K, V, mask=None, causal=False, dropout_p=0.0):
# 硬件感知的分块策略 - 根据GPU特性动态调整
block_size_m = optimize_block_size_m() # 根据GPU架构优化
block_size_n = optimize_block_size_n() # 为减少内存栅栏优化
# 高效数据布局 - 为了提高内存访问效率
Q = optimize_layout(Q) # 改进的内存布局
K = optimize_layout(K) # 可能是以head维度交错排列
V = optimize_layout(V) # 优化GPU缓存命中率
batch_size, num_heads, seq_len_q, head_dim = Q.shape
_, _, seq_len_k, _ = K.shape
# 检查是否可以使用非填充(varlen)版本
if mask is not None and has_padding(mask):
return flash_attention_v2_varlen(Q, K, V, mask, causal, dropout_p)
# 预先计算缩放因子
softmax_scale = 1.0 / math.sqrt(head_dim)
# 根据形状分配输出空间
O = torch.zeros_like(Q)
# 对序列维度进行分块
for q_start in range(0, seq_len_q, block_size_m):
q_end = min(q_start + block_size_m, seq_len_q)
# 在SRAM中保存的状态变量
O_block = torch.zeros((batch_size, num_heads, q_end - q_start, head_dim),
device=Q.device, dtype=Q.dtype)
L_block = torch.zeros((batch_size, num_heads, q_end - q_start),
device=Q.device, dtype=torch.float32)
m_block = torch.ones((batch_size, num_heads, q_end - q_start),
device=Q.device, dtype=torch.float32) * -float('inf')
# KV缓存管理优化
k_start_offset = 0 if not causal else q_start
# 对KV序列维度进行分块
for k_start in range(k_start_offset, seq_len_k, block_size_n):
k_end = min(k_start + block_size_n, seq_len_k)
# 应用因果掩码(优化的实现)
if causal and q_start > k_start:
# 整个块可以跳过 - 提前退出
continue
# 提取当前块
Q_block = Q[:, :, q_start:q_end]
K_block = K[:, :, k_start:k_end]
V_block = V[:, :, k_start:k_end]
# 分块注意力计算 - 使用优化的CUDA内核
S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) * softmax_scale
# 应用细粒度的因果掩码(仅当需要时)
if causal and q_start <= k_end and k_start <= q_end:
# 优化的因果掩码实现(无需显式掩码矩阵)
causal_mask = create_optimized_causal_mask(q_start, q_end, k_start, k_end)
S_block = apply_optimized_mask(S_block, causal_mask)
# 应用注意力掩码(如果有)
if mask is not None:
block_mask = mask[:, :, q_start:q_end, k_start:k_end]
S_block = S_block.masked_fill(block_mask == 0, -float('inf'))
# 高效更新状态变量 - 降低重计算
m_block_new = torch.maximum(m_block, S_block.max(dim=-1)[0])
# 带稳定性优化的状态更新
exp_factor = torch.exp(m_block - m_block_new.unsqueeze(-1))
L_block_new = L_block.unsqueeze(-1) * exp_factor + torch.sum(
torch.exp(S_block - m_block_new.unsqueeze(-1).unsqueeze(-1)), dim=-1)
# 使用fused操作优化计算
O_block = fused_softmax_and_matmul(
O_block, S_block, V_block,
m_block, m_block_new,
L_block, L_block_new
)
# 更新聚合状态
m_block, L_block = m_block_new, L_block_new
# 完成当前Q块的计算,写回结果
O[:, :, q_start:q_end] = O_block / L_block.unsqueeze(-1)
# 应用dropout(优化的实现)
if dropout_p > 0.0:
O = optimized_dropout(O, dropout_p)
return O
# 变长序列版本处理不等长输入
def flash_attention_v2_varlen(Q, K, V, mask, causal, dropout_p):
# 解包批次中的序列长度信息
batch_size, num_heads, max_seq_len_q, head_dim = Q.shape
# 优化:分析掩码提取实际序列长度
q_lens, k_lens = extract_sequence_lengths(mask)
# 将填充的输入展平为压缩格式(去除填充)
Q_unpad, Q_indices, cu_seqlens_q = unpad_input(Q, q_lens)
K_unpad, K_indices, cu_seqlens_k = unpad_input(K, k_lens)
V_unpad, _, _ = unpad_input(V, k_lens)
# 使用优化的变长实现
O_unpad = flash_attn_varlen_kernel(
Q_unpad, K_unpad, V_unpad,
cu_seqlens_q, cu_seqlens_k,
max_seqlen_q=max_seq_len_q,
max_seqlen_k=max(k_lens),
dropout_p=dropout_p,
softmax_scale=1.0/math.sqrt(head_dim),
causal=causal,
return_attn_probs=False
)
# 将压缩输出恢复到原始填充格式
O = pad_output(O_unpad, Q_indices, batch_size, max_seq_len_q)
return O
# 优化的CUDA kernel的Python表示
def fused_softmax_and_matmul(O_block, S_block, V_block, m_old, m_new, L_old, L_new):
# 实际上这会在单个优化的CUDA内核中执行
# 以下是逻辑表示
# 计算缩放因子
exp_factor = torch.exp(m_old.unsqueeze(-1) - m_new.unsqueeze(-1).unsqueeze(-1))
# 重新计算归一化输出
O_scaled = O_block * exp_factor
# 计算新的输出贡献
P = torch.exp(S_block - m_new.unsqueeze(-1).unsqueeze(-1))
O_new = torch.matmul(P, V_block)
# 合并结果
O_combined = O_scaled + O_new
return O_combined
FlashAttention 2相比于原始FlashAttention的主要优化点包括:
数据布局优化:
optimize_layout()
函数表示改进的内存布局,减少内存访问冲突硬件感知分块:
optimize_block_size_m()
和 optimize_block_size_n()
动态选择最优块大小优化的因果掩码:
create_optimized_causal_mask()
创建更高效的掩码融合操作:
fused_softmax_and_matmul()
函数表示将多个操作融合为单个CUDA内核变长序列支持:
flash_attention_v2_varlen()
函数高效处理带填充的序列unpad_input()
和 pad_output()
减少不必要的计算减少同步障碍:
精简的状态管理:
m_block
和 L_block
等中间状态这些优化使得FlashAttention 2在处理相同大小的输入时,相比原始FlashAttention速度提升约20-30%,并且在长序列上的优势更为显著。
特性 | 普通Attention | FlashAttention | FlashAttention 2 |
---|---|---|---|
内存复杂度 | O(N²) | O(N) | O(N) |
相对速度 | 1x | 2-4x | 2-6x |
长序列支持 | 受限 | 较好 | 优秀 |
掩码处理 | 标准 | 优化 | 高度优化 |
硬件利用率 | 低 | 中 | 高 |
CUDA优化 | 基本 | 良好 | 深度优化 |
适用序列长度 | <4K | <64K | <128K |
注:FlashAttention2的相关内容请大家批评指正