Bahdanau 注意力

Bahdanau 注意力(Additive Attention)

Bahdanau 注意力,也被称为 加性注意力(Additive Attention),是一种广泛用于序列到序列(Seq2Seq)模型中的注意力机制,由 Dzmitry Bahdanau 等人在 2014 年提出。它是处理 序列对序列(seq2seq)任务(如机器翻译)的关键技术之一。在Bahdanau提出的模型中,注意力机制通过引入一个额外的上下文信息来动态地对输入序列中的各个部分进行加权,从而使模型可以更好地处理长距离依赖。

1. Bahdanau 注意力的基本思想

Bahdanau 注意力的核心思想是根据每个解码步骤的当前状态(通常是解码器的隐藏状态)来为输入序列的不同位置分配不同的权重,从而将这些不同的输入信息加权后传递给解码器。在机器翻译任务中,它帮助模型更好地集中注意力于源语言中与当前翻译相关的部分。

与传统的 硬性对齐(hard alignment) 方式不同,Bahdanau 注意力是 软性对齐(soft alignment),它根据每个解码器状态来动态地计算注意力权重,使得不同时间步的输入可以被灵活地加权。

2. Bahdanau 注意力的计算过程

Bahdanau 注意力主要包括以下几个步骤:

  1. 计算注意力得分

    • 对于每个解码器的隐藏状态 h t h_t ht,通过计算该隐藏状态与输入序列中每个编码器隐藏状态 h s h_s hs 之间的相似性来获得注意力得分。
    • 加性注意力 中,得分是通过一个神经网络(通常是一个前馈网络)计算的,而不是直接计算点积。
  2. 计算注意力权重

    • 对于每个时间步的解码器,使用 softmax 将注意力得分归一化,得到注意力权重。
  3. 加权求和得到上下文向量

    • 使用计算得到的注意力权重对输入序列的编码器隐藏状态进行加权求和,生成一个上下文向量,作为解码器当前时间步的输入。

3. 数学公式

假设我们有如下变量:

  • h t h_t ht:解码器在时间步 t t t 的隐藏状态
  • h s h_s hs:编码器在时间步 s s s 的隐藏状态
  • e t e_t et:解码器在时间步 t t t 对编码器隐藏状态 h s h_s hs 的注意力得分
  • a t s a_{ts} ats:编码器第 s s s 个隐藏状态对于解码器第 t t t 个时间步的注意力权重
  • C t C_t Ct:解码器在时间步 t t t 的上下文向量
步骤 1:计算注意力得分(加性注意力)

Bahdanau 注意力使用一个前馈神经网络计算每个解码器隐藏状态 h t h_t ht 和编码器隐藏状态 h s h_s hs 之间的匹配得分:

e t s = v T tanh ⁡ ( W 1 h t + W 2 h s ) e_{ts} = v^T \tanh(W_1 h_t + W_2 h_s) ets=vTtanh(W1ht+W2hs)

其中:

  • W 1 W_1 W1 W 2 W_2 W2 是待学习的权重矩阵,
  • v v v 是一个待学习的向量(通常是一个列向量),
  • tanh ⁡ \tanh tanh 是激活函数,确保得分在非线性空间中。
步骤 2:归一化得分为注意力权重

接下来,我们通过 softmax 将得分 e t s e_{ts} ets 归一化为概率:

a t s = exp ⁡ ( e t s ) ∑ s ′ exp ⁡ ( e t s ′ ) a_{ts} = \frac{\exp(e_{ts})}{\sum_{s'} \exp(e_{ts'})} ats=sexp(ets)exp(ets)

这样,得到的注意力权重 a t s a_{ts} ats 可以解释为 编码器第 s s s 个位置对解码器第 t t t 个时间步的贡献

步骤 3:加权求和得到上下文向量

最后,解码器在时间步 t t t 的上下文向量 C t C_t Ct 是编码器所有隐藏状态的加权和,权重为 a t s a_{ts} ats

C t = ∑ s a t s h s C_t = \sum_{s} a_{ts} h_s Ct=satshs

这个上下文向量 C t C_t Ct 会被传递给解码器的下一层,用于生成当前时间步的输出。

4. 总结流程

  1. 对解码器在时间步 t t t 的隐藏状态 h t h_t ht 和编码器在时间步 s s s 的隐藏状态 h s h_s hs 计算匹配得分 e t s e_{ts} ets
  2. 通过 softmax 归一化得分,得到注意力权重 a t s a_{ts} ats
  3. 使用注意力权重 a t s a_{ts} ats 对编码器的隐藏状态 h s h_s hs 进行加权求和,得到上下文向量 C t C_t Ct
  4. 将上下文向量 C t C_t Ct 作为解码器输入的一部分,帮助生成下一步的输出。

5. Bahdanau 注意力的实现

以下是一个简单的 PyTorch 实现,展示了如何计算 Bahdanau 注意力:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauAttention(nn.Module):
    def __init__(self, d_model):
        super(BahdanauAttention, self).__init__()
        self.W1 = nn.Linear(d_model, d_model)
        self.W2 = nn.Linear(d_model, d_model)
        self.v = nn.Parameter(torch.randn(d_model))

    def forward(self, query, values):
        """
        :param query: 解码器的隐藏状态 (batch_size, d_model)
        :param values: 编码器的输出 (batch_size, seq_len, d_model)
        :return: 上下文向量 (batch_size, d_model), 注意力权重 (batch_size, seq_len)
        """
        # 扩展 query 到与 values 相同的维度
        query_with_time_axis = query.unsqueeze(1)  # (batch_size, 1, d_model)
        
        # 计算注意力得分
        score = torch.tanh(self.W1(query_with_time_axis) + self.W2(values))  # (batch_size, seq_len, d_model)
        score = score @ self.v  # (batch_size, seq_len)
        
        # 通过 softmax 计算注意力权重
        attention_weights = F.softmax(score, dim=1)  # (batch_size, seq_len)
        
        # 计算上下文向量
        context_vector = attention_weights.unsqueeze(1) @ values  # (batch_size, 1, d_model)
        context_vector = context_vector.squeeze(1)  # (batch_size, d_model)
        
        return context_vector, attention_weights

# 示例
batch_size = 2
seq_len = 5
d_model = 10

# 创建输入
query = torch.randn(batch_size, d_model)  # 解码器的隐藏状态
values = torch.randn(batch_size, seq_len, d_model)  # 编码器的输出

# 实例化注意力模块
attention = BahdanauAttention(d_model)

# 计算上下文向量
context_vector, attention_weights = attention(query, values)

print("Context Vector:", context_vector)
print("Attention Weights:", attention_weights)

6. 总结

  • Bahdanau 注意力(加性注意力)使用一个前馈神经网络来计算每个解码器隐藏状态和输入序列每个编码器隐藏状态之间的匹配得分。
  • 然后通过 softmax 归一化得分,生成注意力权重,并用这些权重加权求和输入序列的隐藏状态,得到上下文向量。
  • Bahdanau 注意力的关键优势在于它是 软性对齐,允许模型在解码时根据当前状态动态地关注输入的不同部分,适用于长距离依赖任务,如机器翻译等。

通过引入注意力机制,Bahdanau注意力使得Seq2Seq模型更加灵活且有表现力,能够在复杂的任务中实现更好的性能。

你可能感兴趣的:(自然语言处理,Bahdanau,Attention,注意力,Seq2Seq,pytorch,python,自然语言处理)