RNN 循环神经网络
RNN 核心概念
为什么需要 RNN
全连接网络处理序列问题:
- 输入输出长度固定
- 无法处理变长序列
- 不共享序列不同位置的信息
RNN 设计目标:
- 处理变长序列
- 记住之前的信息
- 共享权重(参数效率)RNN 结构
x₁ ──→ [RNN Cell] ──→ h₁ ──→ [RNN Cell] ──→ h₂ ──→ ...
↑ ↑
└───────────────────┘
(隐藏状态传递)python
# 基础 RNN
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
def forward(self, x, h=None):
# x: (batch, seq_len, input_size)
# h: (num_layers, batch, hidden_size)
output, h = self.rnn(x, h)
return output, hRNN 公式
hₜ = tanh(W·xₜ + U·hₜ₋₁ + b)
其中:
- xₜ: t 时刻的输入
- hₜ₋₁: t-1 时刻的隐藏状态
- W: 输入到隐藏的权重
- U: 隐藏到隐藏的权重(循环)RNN 的问题
梯度消失/爆炸
反向传播时间步:
∂L/∂W = Σ ∂L/∂hₜ × ∂hₜ/∂W
梯度包含 W 的连乘:
∂hₜ/∂hₜ₋₁ = tanh' × U
如果 |U·tanh'| > 1 → 梯度爆炸
如果 |U·tanh'| < 1 → 梯度消失长序列依赖问题:
- 早期信息在反向传播时被稀释
- t=1000 时刻的梯度 ≈ 0
- 网络难以记住长期依赖
解决方案
1. LSTM (Long Short-Term Memory)
- 门控机制,选择性记住/遗忘
2. GRU (Gated Recurrent Unit)
- 简化的 LSTM
- 更新门 + 重置门
3. Gradient Clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)LSTM (Long Short-Term Memory)
LSTM 结构
Cell State (细胞状态):
┌──────────────────────────────────────┐
│ │
↓ │
┌───────┐ ┌─────────┐ ┌─────────┐ │
│ 遗忘门 │ │ 输入门 │ │ 输出门 │ │
│Forget │ │ Input │ │ Output │ │
└───────┘ └─────────┘ └─────────┘ │
│ │ │ │
↓ ↓ ↓ │三个门
python
# 遗忘门:决定丢弃什么信息
f = σ(W_f · [h_{t-1}, x_t] + b_f)
# 输入门:决定存储什么信息
i = σ(W_i · [h_{t-1}, x_t] + b_i)
g = tanh(W_g · [h_{t-1}, x_t] + b_g)
# 输出门:决定输出什么
o = σ(W_o · [h_{t-1}, x_t] + b_o)
h_t = o * tanh(C_t)PyTorch LSTM
python
# 单层 LSTM
lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=2, batch_first=True)
# 前向传播
# x: (batch, seq_len, input_size)
# h0, c0: (num_layers, batch, hidden_size)
output, (h_n, c_n) = lstm(x, (h0, c0))
# output: (batch, seq_len, hidden_size) - 每个时间步的输出
# h_n: (num_layers, batch, hidden_size) - 最后一个隐藏状态LSTM 解决梯度消失
Cell State 更新:
C_t = f * C_{t-1} + i * g
反向传播:
∂C_t/∂C_{t-1} = f + ... (不是连乘!)
f ∈ [0,1] 接近 1 时,梯度几乎无损传递
门控机制让网络自己学习保留/遗忘GRU (Gated Recurrent Unit)
GRU 结构
python
# 更新门:决定保留多少过去的信息
z = σ(W_z · [h_{t-1}, x_t])
# 重置门:决定忘记多少过去的信息
r = σ(W_r · [h_{t-1}, x_t])
# 候选隐藏状态
h_tilde = tanh(W · [r * h_{t-1}, x_t])
# 最终隐藏状态
h_t = (1 - z) * h_{t-1} + z * h_tildeLSTM vs GRU
| 特性 | LSTM | GRU |
|---|---|---|
| 门数量 | 3个(遗忘、输入、输出) | 2个(更新、重置) |
| 参数 | 较多 | 较少 |
| 效果 | 通常更好 | 相当,训练更快 |
| 选择 | 复杂任务 | 简单任务/资源有限 |
序列到序列 (Seq2Seq)
编码器-解码器结构
编码器 (Encoder):
h₁, h₂, ..., hₙ = Encoder(x₁, x₂, ..., xₙ)
c = hₙ # 上下文向量
解码器 (Decoder):
h'₁ = Decoder(c)
h'₂ = Decoder(c, h'₁)
...
y₁, y₂, ... = Decoder(h'₁, h'₂, ...)机器翻译示例
python
class Encoder(nn.Module):
def __init__(self, vocab_size, embed_size, hidden_size):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_size)
self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
def forward(self, x):
# x: (batch, seq_len)
embedded = self.embedding(x) # (batch, seq_len, embed_size)
outputs, (h, c) = self.lstm(embedded)
return h, c # 最终隐藏状态作为上下文
class Decoder(nn.Module):
def forward(self, x, h, c):
embedded = self.embedding(x)
outputs, (h, c) = self.lstm(embedded, (h, c))
# outputs: (batch, seq_len, hidden_size)
logits = self.fc(outputs)
return logits, h, cAttention 机制
问题:固定上下文向量是瓶颈
解决:每个解码步关注编码器的不同部分
Attention(Q, K, V) = softmax(QK^T / √d_k) V
- Q: 解码器当前隐藏状态
- K, V: 编码器所有隐藏状态
- 输出: 加权上下文向量python
class AttentionDecoder(nn.Module):
def forward(self, dec_h, enc_outputs):
# dec_h: (batch, hidden_size)
# enc_outputs: (batch, src_len, hidden_size)
# 计算注意力分数
scores = torch.matmul(enc_outputs, dec_h.unsqueeze(2)) # (batch, src_len, 1)
attn_weights = F.softmax(scores, dim=1)
# 加权求和
context = torch.sum(attn_weights * enc_outputs, dim=1) # (batch, hidden_size)
return context, attn_weights双向 RNN
python
# 双向 LSTM
bilstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=True)
# 合并双向输出
# 方法1: concatenate
combined = torch.cat([forward_h, backward_h], dim=2)
# 方法2: sum
combined = forward_h + backward_h
# 方法3: average
combined = (forward_h + backward_h) / 2序列填充与掩码
python
# 填充序列长度一致
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
# 填充
padded = pad_sequence([seq1, seq2, seq3], batch_first=True, padding_value=0)
# 记录实际长度
lengths = torch.tensor([len(seq1), len(seq2), len(seq3)])
# 打包(用于 RNN)
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
outputs, _ = self.lstm(packed)
# 解包
unpacked, _ = pad_packed_sequence(outputs, batch_first=True)实战:文本分类
python
class TextClassifier(nn.Module):
def __init__(self, vocab_size, embed_size, hidden_size, num_classes):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
self.lstm = nn.LSTM(embed_size, hidden_size,
num_layers=2,
batch_first=True,
dropout=0.3,
bidirectional=True)
self.fc = nn.Linear(hidden_size * 2, num_classes)
def forward(self, x):
# x: (batch, seq_len)
embedded = self.embedding(x) # (batch, seq_len, embed_size)
# 双向 LSTM
outputs, (h, c) = self.lstm(embedded)
# 取最后一层的双向隐藏状态拼接
hidden = torch.cat([h[-2], h[-1]], dim=1) # (batch, hidden_size*2)
return self.fc(hidden) # (batch, num_classes)RNN 变体与比较
| 模型 | 特点 | 适用场景 |
|---|---|---|
| RNN | 基础,梯度消失 | 简单短序列 |
| LSTM | 门控,长期记忆 | 长序列、NLP |
| GRU | 简化 LSTM | 长序列、资源有限 |
| BiLSTM | 双向信息 | 序列标注、分类 |
| Stacked LSTM | 多层抽象 | 复杂任务 |
面试要点
1. 为什么普通 RNN 梯度消失
- 反向传播连乘导致
- 远处梯度指数衰减
2. LSTM 如何解决
- 门控机制
- Cell State 路径梯度近乎无损
3. LSTM vs GRU 区别
- GRU 更简单,参数少
- LSTM 多一个输出门
4. RNN vs CNN 处理序列
- RNN:变长、共享权重、顺序依赖
- CNN:固定窗口、局部感受野
5. Attention 为什么有效
- 解决固定上下文瓶颈
- 允许解码器动态关注源序列不同部分
6. 双向 RNN 适用场景
- 需要看到完整上下文
- 分类、序列标注(NER)
- 不适用于实时预测(需要未来信息)