摘要——循环神经网络(RNN)通过其固有的时序依赖捕捉能力,从根本上改变了序列建模范式。本文系统剖析RNN处理序列数据的结构优势,对比其与LSTM等现代方案的局限性,并提供PyTorch实现范例,最后探讨稀疏循环与连续时间架构等未来研究方向。
时序数据(时间序列、文本、语音)具有关键的时间依赖性——元素 x t x_t xt依赖于前序元素 { x 1 , . . . , x t − 1 } \{x_1, ..., x_{t-1}\} {x1,...,xt−1}。传统前馈神经网络存在固有缺陷:
RNN单元的核心计算:
h t = σ ( W h h t − 1 + W x x t + b ) h_t = \sigma(W_h h_{t-1} + W_x x_t + b) ht=σ(Whht−1+Wxxt+b)
其中:
关键优势:
通过时间反向传播(BPTT)计算梯度:
∂ L ∂ W = ∑ t = 1 T ∂ L t ∂ W \frac{\partial \mathcal{L}}{\partial W} = \sum_{t=1}^T \frac{\partial \mathcal{L}_t}{\partial W} ∂W∂L=t=1∑T∂W∂Lt
缺陷:当 W h W_h Wh的特征值远大于/小于1时,梯度消失/爆炸会限制长程依赖捕捉。
import torch
import torch.nn as nn
class RNNCell(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.W_h = nn.Parameter(torch.randn(hidden_size, hidden_size))
self.W_x = nn.Parameter(torch.randn(hidden_size, input_size))
self.b = nn.Parameter(torch.zeros(hidden_size))
def forward(self, x, h_prev):
h_next = torch.tanh(h_prev @ self.W_h + x @ self.W_x.T + self.b)
return h_next
def forward_sequence(self, X): # X: [seq_len, batch_size, input_size]
h = torch.zeros(X.size(1), self.hidden_size) # 初始状态
outputs = []
for t in range(X.size(0)):
h = self(X[t], h) # 每个时间步复用单元
outputs.append(h)
return torch.stack(outputs) # [seq_len, batch_size, hidden_size]
当相关输入间隔超过≈10步时,RNN表现显著下降:
∥ ∂ h t ∂ h k ∥ ∝ ∥ W h t − k ∥ → 0 或 ∞ \left\|\frac{\partial h_t}{\partial h_k}\right\| \propto \|W_h^{t-k}\| \rightarrow 0 \quad \text{或} \quad \infty ∂hk∂ht ∝∥Wht−k∥→0或∞
**长短期记忆网络(LSTM)**引入门控机制:
f t = σ ( W f [ h t − 1 , x t ] + b f ) (遗忘门) i t = σ ( W i [ h t − 1 , x t ] + b i ) (输入门) c ~ t = tanh ( W c [ h t − 1 , x t ] + b c ) c t = f t ⊙ c t − 1 + i t ⊙ c ~ t o t = σ ( W o [ h t − 1 , x t ] + b o ) h t = o t ⊙ tanh ( c t ) \begin{align*} f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) \quad \text{(遗忘门)} \\ i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) \quad \text{(输入门)} \\ \tilde{c}_t &= \tanh(W_c [h_{t-1}, x_t] + b_c) \\ c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\ o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o) \\ h_t &= o_t \odot \tanh(c_t) \end{align*} ftitc~tctotht=σ(Wf[ht−1,xt]+bf)(遗忘门)=σ(Wi[ht−1,xt]+bi)(输入门)=tanh(Wc[ht−1,xt]+bc)=ft⊙ct−1+it⊙c~t=σ(Wo[ht−1,xt]+bo)=ot⊙tanh(ct)
RNN凭借时序参数共享与显式状态传递,仍是序列建模的基石。未来的突破将来自算法-硬件协同设计,特别是面向稀疏与连续时间动态的优化。
【哈佛博后带小白玩转机器学习】