(1)关键点说明:
多头注意力机制:将输入分成多个头,每个头独立计算注意力
(2)维度变换:
输入:(batch_size, sequence_length, embedding_dim)
多头处理:(batch_size, num_heads, sequence_length, dim_per_head) 后面两项用于 Q、K的计算
维度变换的核心是用 view和permute来进行
(3)注意力计算:
Q×K^T 得到注意力分数
除以维度的平方根进行缩放 torch.sqrt()
使用softmax归一化 torch.softmax()
与V相乘得到加权结果 torch.matmul()
(4)Mask机制:可选的mask功能,用于屏蔽某些位置的注意力
这是Transformer架构中的一个核心组件,用于捕获序列中的长距离依赖关系。
# 指定文件编码为UTF-8
# encoding: utf-8
# 导入PyTorch主库
import torch
# 导入神经网络模块
import torch.nn as nn
# 定义自注意力层,继承自nn.Module
class SelfAttention(nn.Module):
def __init__(self, num_heads, embed_dim):
"""
初始化参数:
num_heads: 注意力头的数量
embed_dim: 嵌入维度大小
"""
# 调用父类初始化
super().__init__()
# 确保嵌入维度能被头数整除
assert embed_dim % num_heads == 0
# 计算每个头的维度
self.dim = embed_dim / num_heads
# 定义Query、Key、Value的线性变换层
self.q = nn.Linear(embed_dim, embed_dim) # Query变换
self.k = nn.Linear(embed_dim, embed_dim) # Key变换
self.v = nn.Linear(embed_dim, embed_dim) # Value变换
# 定义输出的线性变换层
self.o = nn.Linear(embed_dim, embed_dim)
# 保存注意力头数量
self.num_heads = num_heads
def forward(self, x):
"""
前向传播函数
x: 输入张量,形状为(batch_size, sequence_length, embedding_dim)
"""
# 获取输入张量的维度信息
batch_size, sequence_length, embedding_dim = x.size()
# 计算Q、K、V并重塑维度
# 从(B,L,D)转换为(B,N,L,d),其中:
# B=batch_size, N=num_heads, L=sequence_length, d=dim_per_head
Q = self.q(x).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
K = self.k(x).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
V = self.v(x).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
# 计算注意力分数:(Q × K^T) / sqrt(dim)
A = torch.matmul(Q, K.permute(0, 1, 3, 2)) / torch.sqrt(self.dim)
# 如果有mask,将mask为0的位置设置为很小的负数(-1e30)
if mask is not None:
A = A.mask_fill(mask == 0, -1e30)
# 对注意力分数进行softmax
A = torch.softmax(A, dim=-1)
# 计算最终输出:注意力分数 × V
# 然后将维度转回(B,L,D)
O = torch.matmul(A, V).permute(0, 2, 1, 3).view(batch_size, sequence_length, embedding_dim)
# 通过输出线性层
O = self.o(O)
return O