Skip to content

RNN 循环神经网络

RNN 核心概念

为什么需要 RNN

全连接网络处理序列问题:
- 输入输出长度固定
- 无法处理变长序列
- 不共享序列不同位置的信息

RNN 设计目标:
- 处理变长序列
- 记住之前的信息
- 共享权重(参数效率)

RNN 结构

x₁ ──→ [RNN Cell] ──→ h₁ ──→ [RNN Cell] ──→ h₂ ──→ ...
      ↑                   ↑
      └───────────────────┘
           (隐藏状态传递)
python
# 基础 RNN
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
    
    def forward(self, x, h=None):
        # x: (batch, seq_len, input_size)
        # h: (num_layers, batch, hidden_size)
        output, h = self.rnn(x, h)
        return output, h

RNN 公式

hₜ = tanh(W·xₜ + U·hₜ₋₁ + b)

其中:
- xₜ: t 时刻的输入
- hₜ₋₁: t-1 时刻的隐藏状态
- W: 输入到隐藏的权重
- U: 隐藏到隐藏的权重(循环)

RNN 的问题

梯度消失/爆炸

反向传播时间步:
∂L/∂W = Σ ∂L/∂hₜ × ∂hₜ/∂W

梯度包含 W 的连乘:
∂hₜ/∂hₜ₋₁ = tanh' × U

如果 |U·tanh'| > 1 → 梯度爆炸
如果 |U·tanh'| < 1 → 梯度消失

长序列依赖问题

  • 早期信息在反向传播时被稀释
  • t=1000 时刻的梯度 ≈ 0
  • 网络难以记住长期依赖

解决方案

1. LSTM (Long Short-Term Memory)
   - 门控机制,选择性记住/遗忘
   
2. GRU (Gated Recurrent Unit)
   - 简化的 LSTM
   - 更新门 + 重置门
   
3. Gradient Clipping
   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

LSTM (Long Short-Term Memory)

LSTM 结构

Cell State (细胞状态):
      ┌──────────────────────────────────────┐
      │                                      │
      ↓                                      │
  ┌───────┐    ┌─────────┐    ┌─────────┐   │
  │ 遗忘门 │    │ 输入门  │    │ 输出门  │   │
  │Forget │    │ Input   │    │ Output  │   │
  └───────┘    └─────────┘    └─────────┘   │
      │            │            │          │
      ↓            ↓            ↓          │

三个门

python
# 遗忘门:决定丢弃什么信息
f = σ(W_f · [h_{t-1}, x_t] + b_f)

# 输入门:决定存储什么信息
i = σ(W_i · [h_{t-1}, x_t] + b_i)
g = tanh(W_g · [h_{t-1}, x_t] + b_g)

# 输出门:决定输出什么
o = σ(W_o · [h_{t-1}, x_t] + b_o)
h_t = o * tanh(C_t)

PyTorch LSTM

python
# 单层 LSTM
lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=2, batch_first=True)

# 前向传播
# x: (batch, seq_len, input_size)
# h0, c0: (num_layers, batch, hidden_size)
output, (h_n, c_n) = lstm(x, (h0, c0))

# output: (batch, seq_len, hidden_size) - 每个时间步的输出
# h_n: (num_layers, batch, hidden_size) - 最后一个隐藏状态

LSTM 解决梯度消失

Cell State 更新:
C_t = f * C_{t-1} + i * g

反向传播:
∂C_t/∂C_{t-1} = f + ... (不是连乘!)

f ∈ [0,1] 接近 1 时,梯度几乎无损传递
门控机制让网络自己学习保留/遗忘

GRU (Gated Recurrent Unit)

GRU 结构

python
# 更新门:决定保留多少过去的信息
z = σ(W_z · [h_{t-1}, x_t])

# 重置门:决定忘记多少过去的信息
r = σ(W_r · [h_{t-1}, x_t])

# 候选隐藏状态
h_tilde = tanh(W · [r * h_{t-1}, x_t])

# 最终隐藏状态
h_t = (1 - z) * h_{t-1} + z * h_tilde

LSTM vs GRU

特性LSTMGRU
门数量3个(遗忘、输入、输出)2个(更新、重置)
参数较多较少
效果通常更好相当,训练更快
选择复杂任务简单任务/资源有限

序列到序列 (Seq2Seq)

编码器-解码器结构

编码器 (Encoder):
h₁, h₂, ..., hₙ = Encoder(x₁, x₂, ..., xₙ)
c = hₙ  # 上下文向量

解码器 (Decoder):
h'₁ = Decoder(c)
h'₂ = Decoder(c, h'₁)
...
y₁, y₂, ... = Decoder(h'₁, h'₂, ...)

机器翻译示例

python
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
    
    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_size)
        outputs, (h, c) = self.lstm(embedded)
        return h, c  # 最终隐藏状态作为上下文

class Decoder(nn.Module):
    def forward(self, x, h, c):
        embedded = self.embedding(x)
        outputs, (h, c) = self.lstm(embedded, (h, c))
        # outputs: (batch, seq_len, hidden_size)
        logits = self.fc(outputs)
        return logits, h, c

Attention 机制

问题:固定上下文向量是瓶颈

解决:每个解码步关注编码器的不同部分

Attention(Q, K, V) = softmax(QK^T / √d_k) V

- Q: 解码器当前隐藏状态
- K, V: 编码器所有隐藏状态
- 输出: 加权上下文向量
python
class AttentionDecoder(nn.Module):
    def forward(self, dec_h, enc_outputs):
        # dec_h: (batch, hidden_size)
        # enc_outputs: (batch, src_len, hidden_size)
        
        # 计算注意力分数
        scores = torch.matmul(enc_outputs, dec_h.unsqueeze(2))  # (batch, src_len, 1)
        attn_weights = F.softmax(scores, dim=1)
        
        # 加权求和
        context = torch.sum(attn_weights * enc_outputs, dim=1)  # (batch, hidden_size)
        return context, attn_weights

双向 RNN

python
# 双向 LSTM
bilstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True, bidirectional=True)

# 合并双向输出
# 方法1: concatenate
combined = torch.cat([forward_h, backward_h], dim=2)

# 方法2: sum
combined = forward_h + backward_h

# 方法3: average
combined = (forward_h + backward_h) / 2

序列填充与掩码

python
# 填充序列长度一致
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

# 填充
padded = pad_sequence([seq1, seq2, seq3], batch_first=True, padding_value=0)

# 记录实际长度
lengths = torch.tensor([len(seq1), len(seq2), len(seq3)])

# 打包(用于 RNN)
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
outputs, _ = self.lstm(packed)

# 解包
unpacked, _ = pad_packed_sequence(outputs, batch_first=True)

实战:文本分类

python
class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.lstm = nn.LSTM(embed_size, hidden_size, 
                           num_layers=2, 
                           batch_first=True,
                           dropout=0.3,
                           bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)
    
    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_size)
        
        # 双向 LSTM
        outputs, (h, c) = self.lstm(embedded)
        
        # 取最后一层的双向隐藏状态拼接
        hidden = torch.cat([h[-2], h[-1]], dim=1)  # (batch, hidden_size*2)
        
        return self.fc(hidden)  # (batch, num_classes)

RNN 变体与比较

模型特点适用场景
RNN基础,梯度消失简单短序列
LSTM门控,长期记忆长序列、NLP
GRU简化 LSTM长序列、资源有限
BiLSTM双向信息序列标注、分类
Stacked LSTM多层抽象复杂任务

面试要点

1. 为什么普通 RNN 梯度消失
   - 反向传播连乘导致
   - 远处梯度指数衰减

2. LSTM 如何解决
   - 门控机制
   - Cell State 路径梯度近乎无损

3. LSTM vs GRU 区别
   - GRU 更简单,参数少
   - LSTM 多一个输出门

4. RNN vs CNN 处理序列
   - RNN:变长、共享权重、顺序依赖
   - CNN:固定窗口、局部感受野

5. Attention 为什么有效
   - 解决固定上下文瓶颈
   - 允许解码器动态关注源序列不同部分

6. 双向 RNN 适用场景
   - 需要看到完整上下文
   - 分类、序列标注(NER)
   - 不适用于实时预测(需要未来信息)

基于 MIT 许可发布