MHA实现

(1)关键点说明:

多头注意力机制:将输入分成多个头,每个头独立计算注意力

(2)维度变换:

输入:(batch_size, sequence_length, embedding_dim)

多头处理:(batch_size, num_heads, sequence_length, dim_per_head) 后面两项用于 Q、K的计算

维度变换的核心是用 view和permute来进行

(3)注意力计算:

Q×K^T 得到注意力分数

除以维度的平方根进行缩放 torch.sqrt()

使用softmax归一化 torch.softmax()

与V相乘得到加权结果 torch.matmul()

(4)Mask机制:可选的mask功能,用于屏蔽某些位置的注意力

这是Transformer架构中的一个核心组件,用于捕获序列中的长距离依赖关系。

# 指定文件编码为UTF-8
# encoding: utf-8

# 导入PyTorch主库
import torch
# 导入神经网络模块
import torch.nn as nn

# 定义自注意力层,继承自nn.Module
class SelfAttention(nn.Module):

    def __init__(self, num_heads, embed_dim):
        """
        初始化参数:
        num_heads: 注意力头的数量
        embed_dim: 嵌入维度大小
        """
        # 调用父类初始化
        super().__init__()
        # 确保嵌入维度能被头数整除
        assert embed_dim % num_heads == 0
        # 计算每个头的维度
        self.dim = embed_dim / num_heads
        
        # 定义Query、Key、Value的线性变换层
        self.q = nn.Linear(embed_dim, embed_dim)  # Query变换
        self.k = nn.Linear(embed_dim, embed_dim)  # Key变换
        self.v = nn.Linear(embed_dim, embed_dim)  # Value变换
        # 定义输出的线性变换层
        self.o = nn.Linear(embed_dim, embed_dim)
        # 保存注意力头数量
        self.num_heads = num_heads

    def forward(self, x):
        """
        前向传播函数
        x: 输入张量,形状为(batch_size, sequence_length, embedding_dim)
        """
        # 获取输入张量的维度信息
        batch_size, sequence_length, embedding_dim = x.size()
        
        # 计算Q、K、V并重塑维度
        # 从(B,L,D)转换为(B,N,L,d),其中:
        # B=batch_size, N=num_heads, L=sequence_length, d=dim_per_head
        Q = self.q(x).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
        K = self.k(x).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
        V = self.v(x).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)

        # 计算注意力分数:(Q × K^T) / sqrt(dim)
        A = torch.matmul(Q, K.permute(0, 1, 3, 2)) / torch.sqrt(self.dim)
        
        # 如果有mask,将mask为0的位置设置为很小的负数(-1e30)
        if mask is not None:
            A = A.mask_fill(mask == 0, -1e30)
        
        # 对注意力分数进行softmax
        A = torch.softmax(A, dim=-1)

        # 计算最终输出:注意力分数 × V
        # 然后将维度转回(B,L,D)
        O = torch.matmul(A, V).permute(0, 2, 1, 3).view(batch_size, sequence_length, embedding_dim)
        # 通过输出线性层
        O = self.o(O)
        return O

你可能感兴趣的:(算法基础,模型基础)