【Flux 中的 attention】缩放点积注意力机制的具体实现

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias.to(attn_weight.device)
    attn_weight = torch.softmax(attn_weight, dim=-1)

    return torch.dropout(attn_weight, dropout_p, train=True) @ value, attn_weight

注意力机制中的 query、key 和 value 这三个核心概念:

  1. Query(查询)
  • 可以理解为"我要找什么"
  • 代表当前需要关注的信息
  • 在代码中,query 的维度是 […, L, D],其中 L 是序列长度,D 是特征维度
  • 比如在机器翻译中,query 可能是目标语言中当前要翻译的词
  1. Key(键)
  • 可以理解为"我有什么"
  • 代表可以被查询的信息
  • 在代码中,key 的维度是 […, S, D],其中 S 是序列长度,D 是特征维度
  • 比如在机器翻译中,key 可能是源语言中的词
  1. Value(值)
  • 可以理解为"实际要用的信息"
  • 代表最终要使用的信息
  • 在代码中,value 的维度与 key 相同
  • 比如在机器翻译中,value 可能是源语言词的语义表示

在代码中的处理流程:

  1. 首先计算 query 和 key 的点积:query @ key.transpose(-2, -1)
  2. 应用缩放因子:* scale_factor
  3. 添加注意力偏置:attn_weight += attn_bias
  4. 通过 softmax 得到注意力权重:torch.softmax(attn_weight, dim=-1)
  5. 最后用这些权重对 value 进行加权求和:attn_weight @ value

举个生活中的例子:

  • 想象你在图书馆找书
  • Query 就像你的搜索条件(比如"Python编程")
  • Key 就像每本书的标签/索引
  • Value 就像书的具体内容
  • 注意力机制就是根据你的搜索条件(Query)与书的标签(Key)的匹配程度,来决定要读哪些书的内容(Value)

这种机制的优势在于:

  1. 可以动态地关注不同的信息
  2. 能够捕捉长距离依赖关系
  3. 计算效率高,可以并行处理
  4. 具有可解释性,可以看到模型关注了哪些部分

本文代码片段来源于:https://github.com/wooyeolbaek/attention-map-diffusers(Line1826,/path/attention-map-diffusers/attention_map_diffusers/modules.py)

你可能感兴趣的:(编程学习,Diffusion,AIGC,深度学习,人工智能)