Transformer and Attention
Attention Mechanism
Core Idea
Seq2Seq problem: a single fixed-length context vector is the bottleneck
Solution: let every output position attend to all input positions
Attention(Q, K, V) = softmax(QK^T / √d_k) V
Query: what I am looking for
Key: what features I offer for matching
Value: the actual content that gets aggregated
Self-Attention
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.scale = math.sqrt(d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        Q = self.W_q(x)  # (batch, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # scores: (batch, seq_len, seq_len)
        # Normalize into attention weights
        attn_weights = F.softmax(scores, dim=-1)
        # Weighted sum of the values
        output = torch.matmul(attn_weights, V)
        # output: (batch, seq_len, d_model)
        return output, attn_weights
```

Multi-Head Attention
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, context=None, mask=None):
        # Self-attention by default; pass `context` (e.g. the encoder output) for cross-attention
        if context is None:
            context = x
        B = x.size(0)
        # Project, then split into heads: (batch, num_heads, seq_len, d_k)
        Q = self.W_q(x).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(context).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(context).view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask == True marks positions that must not be attended to
            scores = scores.masked_fill(mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        # Merge the heads back to (batch, seq_len, d_model), then apply the output projection
        output = torch.matmul(attn_weights, V).transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_o(output), attn_weights
```

Why Multi-Head:
- Different heads can attend to different positions/patterns (see the shape check below)
- Some heads capture syntax, others semantics
- Increases the model's expressive power
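A quick shape check of the module above (tensor sizes are illustrative assumptions): each head gets its own (seq_len × seq_len) attention map, which is what lets different heads specialize.

```python
x = torch.randn(2, 10, 512)                        # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=512, num_heads=8)
out, attn = mha(x)
print(out.shape)    # torch.Size([2, 10, 512])
print(attn.shape)   # torch.Size([2, 8, 10, 10])  -- one attention map per head
```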
Scaled Dot-Product Attention
Problem with raw QK^T:
- When d_k is large, the variance of the dot products grows with d_k
- Softmax then saturates and its gradients become very small
Fix: divide the scores by √d_k (scale = √d_k)
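A quick numeric check of this argument (a minimal sketch, not from the original notes): for random vectors with unit-variance entries, the dot product has variance ≈ d_k, and dividing by √d_k brings it back to roughly 1.

```python
torch.manual_seed(0)
for d_k in [16, 64, 256, 1024]:
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    scores = (q * k).sum(dim=-1)        # 10000 sample dot products
    scaled = scores / d_k ** 0.5
    print(f"d_k={d_k:5d}  var(q.k)={scores.var().item():8.1f}  var(scaled)={scaled.var().item():.2f}")
```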
Transformer Architecture
Overall Architecture
Input → Embedding + Positional Encoding → Encoder (N× stacked layers) → Enc Output
Enc Output + target embeddings (+ Positional Encoding) → Decoder (N× stacked layers) → Linear + Softmax → Output
Positional Encoding
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sin
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cos
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]
```

Why positional encoding is needed:
- Self-attention is permutation-invariant and carries no position information by itself
- Word-order information therefore has to be injected separately
- Sine/cosine functions of different frequencies encode each position (see the sketch below)
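A minimal wiring sketch (vocabulary size and dimensions are assumptions for illustration): token embeddings plus the sinusoidal encoding defined above.

```python
embed = nn.Embedding(10000, 512)
pos_enc = PositionalEncoding(d_model=512)
tokens = torch.randint(0, 10000, (2, 20))   # (batch, seq_len) token ids
x = pos_enc(embed(tokens))                  # (2, 20, 512), position-aware embeddings
```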
Encoder
```python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention + residual + norm
        attn_output, _ = self.self_attn(x)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward + residual + norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        return x


class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```

Decoder
```python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None):
        # Masked self-attention (prevents attending to future positions)
        attn1, _ = self.self_attn(x, mask=tgt_mask)
        x = self.norm1(x + self.dropout(attn1))
        # Cross-attention (queries from the decoder, keys/values from the encoder output)
        attn2, _ = self.cross_attn(x, context=enc_output)
        x = self.norm2(x + self.dropout(attn2))
        # Feed-forward
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_output))
        return x
```

Masking
```python
# 1. Padding mask (handles variable-length sequences)
def create_padding_mask(seq):
    # mask == True marks padding positions to be ignored
    mask = (seq == 0)  # assuming token id 0 is the padding token
    return mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)

# 2. Look-ahead mask (decoder self-attention, prevents seeing the future)
def create_lookahead_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask
```
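A short usage sketch tying the masks to the decoder layer above (batch size, lengths, and dimensions are assumptions for illustration):

```python
batch, src_len, tgt_len, d_model = 2, 7, 5, 512
enc_output = torch.randn(batch, src_len, d_model)   # stand-in for the encoder output
tgt = torch.randn(batch, tgt_len, d_model)          # embedded target tokens

tgt_mask = create_lookahead_mask(tgt_len)           # (tgt_len, tgt_len), True above the diagonal
layer = DecoderLayer(d_model=d_model, num_heads=8, d_ff=2048)
out = layer(tgt, enc_output, tgt_mask=tgt_mask)     # (batch, tgt_len, d_model)
```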
BERT (Bidirectional Encoder Representations from Transformers)
Pretraining Tasks
```python
# 1. MLM (Masked Language Model)
# Randomly mask 15% of the tokens and predict them
# Input:  The [MASK] sat on the [MASK]
# Target: cat, mat

# 2. NSP (Next Sentence Prediction)
# Predict whether sentence B actually follows sentence A
# Teaches the model sentence-level relationships
```
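A quick MLM demo via the Hugging Face fill-mask pipeline (standard `bert-base-uncased` checkpoint; the printed predictions are illustrative, not from the original notes):

```python
from transformers import pipeline

fill = pipeline("fill-mask", model="bert-base-uncased")
for pred in fill("The cat sat on the [MASK].")[:3]:
    print(pred["token_str"], round(pred["score"], 3))
```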
BERT Input Representation
Input layout: [CLS] Sentence A [SEP] Sentence B [SEP]
Each token's representation = Token Embedding + Segment Embedding + Position Embedding
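Encoding a sentence pair makes the [CLS]/[SEP] layout and the segment ids visible (a sketch using the standard tokenizer; the printed tokens are illustrative):

```python
from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")
enc = tok("How are you?", "I am fine.")
print(tok.convert_ids_to_tokens(enc["input_ids"]))
# e.g. ['[CLS]', 'how', 'are', 'you', '?', '[SEP]', 'i', 'am', 'fine', '.', '[SEP]']
print(enc["token_type_ids"])  # 0 for sentence-A tokens, 1 for sentence-B tokens
```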
Using BERT
```python
from transformers import BertModel, BertTokenizer

# Load the pretrained model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Tokenize
inputs = tokenizer("Hello, how are you?", return_tensors='pt')

# Forward pass
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # (1, seq_len, 768)
pooler_output = outputs.pooler_output          # (1, 768), derived from the [CLS] token

# Text classification
class BertClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        # x is the dict returned by the tokenizer (input_ids, attention_mask, ...)
        outputs = self.bert(**x)
        return self.classifier(outputs.pooler_output)
```

GPT (Generative Pre-trained Transformer)
GPT Pretraining
```python
# Unidirectional (causal) language model: predict the next token
# Input:  The cat sat on the
# Target: mat

# Causal mask
def causal_mask(size):
    # Upper triangle is True (positions to be masked out)
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask
```
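The same objective with the released GPT-2 checkpoint (a sketch using the standard transformers API; passing the inputs as labels yields the shifted next-token cross-entropy loss):

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
enc = tok("The cat sat on the mat", return_tensors="pt")
out = gpt2(**enc, labels=enc["input_ids"])   # labels = inputs -> causal LM loss
print(out.loss)
```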
GPT Versions
| Version | Parameters | Notes |
|---|---|---|
| GPT-1 | 117M | First of the series |
| GPT-2 | 1.5B | Open weights, zero-shot |
| GPT-3 | 175B | Few-shot in-context learning |
| GPT-4 | undisclosed | Multimodal |
| GPT-4o | - | Natively multimodal |
InstructGPT / ChatGPT
```python
# RLHF (Reinforcement Learning from Human Feedback)
# 1. SFT: supervised fine-tuning on demonstration data
# 2. Reward Model: train a reward model on human preference comparisons
# 3. PPO: optimize the policy against the reward model with reinforcement learning
```

T5 (Text-to-Text Transfer Transformer)
Unified Framework
```python
# Every task is cast as text-to-text
# input:  "translate English to German: Hello"
# output: "Hallo"

# input:  "summarize: The text to summarize..."
# output: "Summary text"
```
Model Comparison
| Model | Architecture | Pretraining task | Strengths |
|---|---|---|---|
| BERT | Encoder-only | MLM + NSP | Strong at understanding tasks |
| GPT | Decoder-only | CLM | Strong at generation tasks |
| T5 | Encoder-Decoder | Span corruption | General-purpose text-to-text |
ViT (Vision Transformer)
Image Patching
```python
# Image → 16×16 patches → flatten → linear projection
class ViT(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_classes=1000):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2
        # Patch + position embedding
        self.proj = nn.Conv2d(3, 768, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, 768))
        # Transformer encoder (reusing the Encoder defined above; d_ff=3072 as in ViT-Base)
        self.encoder = Encoder(num_layers=12, d_model=768, num_heads=12, d_ff=3072)
        # Classification head
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        # x: (B, 3, 224, 224)
        x = self.proj(x)                   # (B, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)   # (B, 196, 768)
        # Prepend the [CLS] token
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 197, 768)
        # Add position embeddings
        x = x + self.pos_embed
        # Transformer encoding
        x = self.encoder(x)
        # Classify from the [CLS] token output
        cls_output = x[:, 0]
        return self.head(cls_output)
```
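Shape check for the sketch above (random input, untrained weights):

```python
vit = ViT(image_size=224, patch_size=16, num_classes=1000)
logits = vit(torch.randn(2, 3, 224, 224))
print(logits.shape)   # torch.Size([2, 1000])
```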
Transformer Variants
Efficient Attention
```python
# 1. FlashAttention (IO-aware exact attention)
from flash_attn import flash_attn_func

# 2. Linear attention (linear complexity in sequence length)
#    Approximates softmax with kernel feature maps (see the sketch after this block)

# 3. Sparse attention
#    Only computes a subset of the attention scores
```
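A minimal linear-attention sketch (assuming the φ(x) = elu(x) + 1 feature map from "Transformers are RNNs"; not part of the original notes). Associativity lets K^T V be computed first, so the cost is O(n·d²) instead of O(n²·d):

```python
def linear_attention(Q, K, V, eps=1e-6):
    # Q, K, V: (batch, seq_len, d)
    Q, K = F.elu(Q) + 1, F.elu(K) + 1
    KV = torch.einsum('bnd,bne->bde', K, V)                         # d × d summary of keys/values
    Z = 1.0 / (torch.einsum('bnd,bd->bn', Q, K.sum(dim=1)) + eps)   # per-query normalization
    return torch.einsum('bnd,bde,bn->bne', Q, KV, Z)
```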
Positional Encoding Variants
```python
# 1. RoPE (Rotary Position Embedding) - used by LLaMA
# 2. ALiBi (Attention with Linear Biases) - used by BLOOM
```

Training Tricks
```python
# 1. Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 2. Learning-rate schedule (warmup + cosine decay)
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=100000
)

# 3. Mixed precision
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for batch in dataloader:
    optimizer.zero_grad()
    with autocast():
        outputs = model(**batch)
        loss = criterion(outputs, batch["labels"])  # assuming the batch dict carries labels
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

# 4. Multi-GPU training
from torch.nn.parallel import DataParallel
model = DataParallel(model)  # simplest single-node option; DistributedDataParallel scales better
```

Interview Key Points
1. Attention formula
Attention(Q,K,V) = softmax(QK^T/√d_k)V
2. Self-attention vs cross-attention
- Self: Q, K, V all come from the same sequence
- Cross: Q comes from the decoder, K and V from the encoder
3. Role of multi-head attention
- Different heads attend to different representation subspaces
- Increases the model's expressive power
- Some heads learn syntactic patterns, others semantic ones
4. Transformer vs RNN
- Parallel computation across positions (faster training)
- Long-range dependencies (attention connects any two positions directly)
- No fixed-size recurrent state: attention gives a global view of the whole sequence
5. BERT vs GPT
- BERT: encoder-only, bidirectional, MLM pretraining
- GPT: decoder-only, unidirectional, CLM pretraining
6. Why sinusoidal positional encodings
- Relative positions can be expressed as linear functions of the encodings
- They can extrapolate to sequences longer than those seen during training
7. Transformer complexity
- Self-attention: O(n²·d) in sequence length n
- Mitigations: FlashAttention, linear attention
8. LayerNorm vs BatchNorm
- LN: normalizes over the feature dimension per token, works with variable-length sequences
- BN: normalizes over the batch dimension, common in CV with large batches