研究者与发布者为:CSDN wq舞s,知乎 wqws
github wqws
在深度学习领域,循环神经网络(RNN)及其变体GRU和LSTM长期以来一直是处理序列数据的首选架构。然而,它们在长序列任务中始终存在信息遗忘和梯度消失等问题。今天,我很高兴地宣布一种全新的RNN架构——AttnRNN,它在多个长序列基准测试中全面超越传统模型,性能提升显著!
AttnRNN的设计灵感来源于生物神经系统的信息处理机制。人脑在处理信息时,并非简单地将所有输入平等对待,而是让已有信息与新信息进行竞争,重要的信息被强化,无用的信息被清除。基于这一思想,AttnRNN的核心设计原则包括:
class EfficientAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# 仅投影Q值,KV不使用投影
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, context):
# 仅投影Q值
q = self.q_proj(query)
# K和V直接使用原始context
k = context
v = context
# 重塑为多头格式
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# 注意力计算
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
attn_weights = F.softmax(attn_scores, dim=-1)
attn_output = torch.matmul(attn_weights, v)
# 合并多头输出
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.embed_dim)
return self.out_proj(attn_output)
该注意力模块的创新点在于:
class AttnRNNCell(nn.Module):
def __init__(self, input_size, hidden_size, num_heads=4):
super().__init__()
self.hidden_size = hidden_size
# 投影层(当输入维度与隐藏维度不同时)
self.proj_layer = (
nn.Linear(input_size, hidden_size)
if input_size != hidden_size
else nn.Identity()
)
# 注意力机制
self.attn = EfficientAttention(
embed_dim=hidden_size,
num_heads=num_heads
)
# 精简门控机制
self.update_gate = nn.Sequential(
nn.Linear(2 * hidden_size, hidden_size),
nn.Sigmoid()
)
# 层归一化
self.layer_norm = nn.LayerNorm(hidden_size)
def forward(self, h_prev, x):
x_proj = self.proj_layer(x)
# 构建丰富上下文:拼接、相加、相乘
h_exp = h_prev.unsqueeze(1)
x_exp = x_proj.unsqueeze(1)
context = torch.cat([h_exp, x_exp, h_exp + x_exp, h_exp * x_exp], dim=1)
# 注意力计算
attn_out = self.attn(query=h_exp, context=context)
attn_out = attn_out.squeeze(1)
# 门控融合
gate_input = torch.cat([h_prev, attn_out], dim=-1)
update_gate = self.update_gate(gate_input)
# 门控更新(含残差连接)
h_candidate = self.layer_norm(attn_out + h_prev)
h_new = update_gate * h_candidate + (1 - update_gate) * h_prev
return h_new
模型 | 参数量公式 (d≠h) | 示例 (d=64, h=128) | 相对比例 |
---|---|---|---|
Vanilla RNN | dh + h² + 2h | 24,832 | 100% |
LSTM | 4(dh + h² + h) | 99,328 | 400% |
GRU | 3(dh + h² + h) | 74,496 | 300% |
AttnRNN | dh + 4h² + 6h | 74,496 | 300% |
关键发现:
模型 | 单步时间复杂度 | 序列时间复杂度 | 主导项系数 |
---|---|---|---|
RNN | O(B(dh + h²)) | O(TB(dh + h²)) | 1 |
LSTM | O(4B(dh + h²)) | O(4TB(dh + h²)) | 4 |
GRU | O(3B(dh + h²)) | O(3TB(dh + h²)) | 3 |
AttnRNN | O(B(3h² + dh)) | O(TB(3h² + dh)) | 3 |
注意:尽管AttnRNN的理论时间复杂度与GRU相当,但由于张量重塑操作,当前实现耗时约为GRU的2.5-3倍。我们正在开发优化版本,预计可将差距缩小到1.5倍以内。
我们在两个经典的长序列测试任务上进行了全面评估,所有模型使用相同的超参数设置(隐藏维度128,epochs=30,lr=0.001)。
模型需要记住序列中两个特定位置的值,并在序列结束时输出它们的和。评价指标为平均绝对误差(MAE),越低越好。
序列长度 | RNN | GRU | LSTM | AttnRNN | 相对提升 |
---|---|---|---|---|---|
50 | 0.3411 | 0.0111 | 0.0666 | 0.0069 | +38% |
100 | 0.3392 | 0.0108 | 0.0083 | 0.0010 | +88% |
200 | 0.3244 | 0.0048 | 0.0231 | 0.0014 | +71% |
400 | 0.3458 | 0.0208 | 0.0427 | 0.0011 | +95% |
模型需要记住序列前三个字符,并在序列结束时复制它们。评价指标为准确率(Acc),越高越好。
序列长度 | RNN | GRU | LSTM | AttnRNN | 相对提升 |
---|---|---|---|---|---|
30 | 0.1380 | 0.6620 | 0.5480 | 0.8140 | +23% |
60 | 0.2065 | 0.4550 | 0.3955 | 0.5780 | +27% |
90 | 0.1917 | 0.3877 | 0.3413 | 0.4720 | +22% |
120 | 0.1643 | 0.3463 | 0.3008 | 0.4383 | +27% |
AttnRNN在以下场景具有显著优势:
尤其适合处理超过200步的长序列任务,在这些场景中,传统RNN变体往往表现不佳。
我们正在以下方向继续改进AttnRNN:
AttnRNN已开源!欢迎社区贡献代码和反馈:
GitHub仓库:https://github.com/liluoyi666/Important_Memories-AttnRNN.git
AttnRNN通过创新的注意力竞争机制和精简门控设计,在长序列任务中实现了突破性性能提升。它不仅解决了传统RNN变体的长距离依赖问题,还保持了参数效率。实验证明,AttnRNN在多个基准测试中全面超越RNN、GRU和LSTM,特别是在长序列场景下优势显著。
随着序列数据处理需求的不断增长,AttnRNN有望成为处理长序列任务的新标杆架构。我们期待看到它在各个领域的应用成果!