Flash attention入门

一、目录

  1. flash attention
  2. GPU运算流程
  3. flash attention 原理
  4. flash attention 与 standard attention 时间/内存 对比。
  5. flash attention 算法实现
  6. 比较flash attention 计算、memory-efficient attention 等不同内核下用时

二、实现

  1. flash attention
    目的: 提高运行速度,减少内存消耗。

  2. GPU运算流程
    见gpu 入门篇

  3. flash attention 原理
    3.1 原理:
    flashAtention其加速的原理是非常简单的,也是最基础和常见的系统性能优化的手段,即通过利用更高速的上层存储计算单元,减少对低速更下层存储器的访问次数,来提升模型的训练性能。Flash attention入门_第1张图片
    图片代表的为带宽大小与内存大小的关系,即从上面的数字可以看出SRAM的访问速率是HBM的10倍左右,然而其能承载的数据量却远远小于HBM。
    CPU 内存大小》GPU 高带宽内存>>GPU SRAM(静态内存)
    GPU SRAM速度>>GPU 高带宽 显存>>CPU 内存速度
    3.2. 创新点:将flashAttention 计算过程由HBM 转为SRAM 中,减少访问次数。
    3.3. 标准attention 计算方法 与flashAttention 计算方法
    标准attention计算:
    首先,从HBM中读取完整的Q和K矩阵(每个大小为N x d),计算点积得到相似度得分S(大小为N x N),需要进行O(Nd + N^2)次HBM访问。
    其次,计算注意力权重P(大小为N x N)时,需要对S进行softmax操作,这需要进行O(N^2)次HBM访问。
    最后,将注意力权重P和值向量V(每个大小为N x d)加权求和得到输出向量O(大小为N x d)时,需要进行O(Nd)次HBM访问。
    标准 Attention 算法的总HBM访问次数为O(Nd + N^2)
    flashAttention计算:
    将原始的注意力矩阵分解成更小的子矩阵,然后分别对这些子矩阵进行计算,只要这个子矩阵的大小可以在SRAM内存放,就可以在计算过程中只访问SRAM。
    计算过程中要尽量的利用SRAM进行计算,避免访问HBM操作。
    3.4. 什么时候使用HBM,什么时候使用SRAM?
    编程时,人为指定SRAM空间。

  4. flash attention 与 standard attention 时间/内存 对比。
    参考:https://zhuanlan.zhihu.com/p/638468472
    以 batch=32, seq_len=512, n_head=16,head_dim=64 为例,记录flash attention 与standard attention 时间/内存对比。Flash attention入门_第2张图片flash attention实现:

import torch
from xformers import ops as xops
import time
bs = 32
seq_len = 512
n_head = 16
head_dim = 64
query_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
key_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
value_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")

flash_query_states = query_states.transpose(1, 2)
flash_key_states = key_states.transpose(1, 2)
flash_value_states = value_states.transpose(1, 2)
start_time = time.time

你可能感兴趣的:(AI,自然语言处理)