
Transformer and Attention

The Attention Mechanism

Core Idea

The Seq2Seq problem: a single fixed-length context vector is an information bottleneck

The solution: let every output position attend to all input positions

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Query: what am I looking for?
Key: what features do I expose to be matched against?
Value: the actual content that gets aggregated
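
Putting the formula together as a standalone function — a minimal sketch (the function name and the optional mask argument are illustrative additions, not from the original notes):

python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q, K, V: (..., seq_len, d_k); mask is True where attention is NOT allowed
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)       # each row sums to 1
    return torch.matmul(weights, V), weights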

Self-Attention

python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.scale = math.sqrt(d_model)
    
    def forward(self, x):
        # x: (batch, seq_len, d_model)
        
        Q = self.W_q(x)  # (batch, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # scores: (batch, seq_len, seq_len)
        
        # Normalize over the key dimension
        attn_weights = F.softmax(scores, dim=-1)
        
        # Weighted sum of the values
        output = torch.matmul(attn_weights, V)
        # output: (batch, seq_len, d_model)
        
        return output, attn_weights

Multi-Head Attention

python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, q, k=None, v=None, mask=None):
        # Self-attention when only q is given; cross-attention when k (and v) come from elsewhere
        k = q if k is None else k
        v = k if v is None else v
        B, L_q, _ = q.shape
        L_k = k.size(1)
        
        # Project, then split into heads: (B, num_heads, seq_len, d_k)
        Q = self.W_q(q).view(B, L_q, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(k).view(B, L_k, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(v).view(B, L_k, self.num_heads, self.d_k).transpose(1, 2)
        
        # Scaled dot-product attention in all heads in parallel
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        heads = torch.matmul(attn_weights, V)   # (B, num_heads, L_q, d_k)
        
        # Concatenate the heads and project back to d_model
        concat = heads.transpose(1, 2).contiguous().view(B, L_q, self.d_model)
        output = self.W_o(concat)
        
        return output, attn_weights

Why multi-head attention:

  • Different heads can attend to different positions/patterns
  • Some heads focus on syntax, others on semantics
  • It increases the model's representational capacity

Scaled Dot-Product Attention

The problem with a raw QK^T:
- when d_k is large, the variance of the dot products grows: if the components of q and k have zero mean and unit variance, q·k has variance ≈ d_k
- softmax then saturates and the gradients become very small

Fix: divide the scores by √d_k, which brings the variance back to ≈ 1

scale = √d_k
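
A quick numerical check of this argument (a throwaway sketch; the dimension and sample count are arbitrary):

python
import torch

d_k = 512
q = torch.randn(10000, d_k)   # unit-variance query components
k = torch.randn(10000, d_k)   # unit-variance key components

dots = (q * k).sum(dim=-1)               # 10000 dot products
print(dots.var())                        # ≈ d_k = 512
print((dots / d_k ** 0.5).var())         # ≈ 1 after scaling by √d_k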

Transformer Architecture

Overall Architecture

Input  ──→ Embedding + Positional Encoding ──→ Encoder (×N stacked) ──→ Enc Output
                                                                             │
Target ──→ Embedding + Positional Encoding ──→ Decoder (×N stacked) ←───────┘
                                                        │
                                                        ↓
                                               Linear + Softmax ──→ Output

Positional Encoding

python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                             (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]

Why positional encoding is needed:

  • Self-Attention by itself contains no position information (it is permutation-equivariant — see the check below)
  • Word-order information has to be injected separately
  • Sine/cosine functions of the position are used as the encoding
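
A small check of the first bullet (illustrative only; the dimensions are arbitrary): without positional encoding, permuting the input tokens merely permutes the self-attention output, so the model cannot distinguish word order. Adding the positional encoding breaks that symmetry.

python
import torch

torch.manual_seed(0)
attn = SelfAttention(d_model=16)          # classes defined above
pe = PositionalEncoding(d_model=16)

x = torch.randn(1, 5, 16)
perm = torch.tensor([4, 3, 2, 1, 0])

out, _ = attn(x)
out_perm, _ = attn(x[:, perm])
# Permuting the inputs only permutes the outputs: no positional information is used
print(torch.allclose(out[:, perm], out_perm, atol=1e-6))   # True

out_pe, _ = attn(pe(x))
out_pe_perm, _ = attn(pe(x[:, perm]))
# With positional encoding the outputs differ beyond a simple permutation
print(torch.allclose(out_pe[:, perm], out_pe_perm, atol=1e-6))  # False (in general)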

Encoder

python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # Self Attention + Residual + Norm
        attn_output, _ = self.self_attn(x)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Feed Forward + Residual + Norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        
        return x

class Encoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

Decoder

python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
    
    def forward(self, x, enc_output, tgt_mask=None):
        # Masked self-attention (prevents attending to future positions)
        attn1, _ = self.self_attn(x, mask=tgt_mask)
        x = self.norm1(x + attn1)
        
        # Cross-attention (queries from the decoder, keys/values from the encoder output)
        attn2, _ = self.cross_attn(x, enc_output)
        x = self.norm2(x + attn2)
        
        # Position-wise feed-forward network
        ffn_output = self.ffn(x)
        x = self.norm3(x + ffn_output)
        
        return x
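
For symmetry with the Encoder class above, a decoder stack might look as follows (a sketch; the original notes only show the single layer):

python
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, num_layers, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
    
    def forward(self, x, enc_output, tgt_mask=None):
        for layer in self.layers:
            x = layer(x, enc_output, tgt_mask)
        return x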

Masking

python
# 1. Padding mask (for variable-length sequences)
def create_padding_mask(seq):
    # mask = True means the position is masked out (it is padding)
    mask = (seq == 0)  # assume token id 0 is the padding token
    return mask.unsqueeze(1).unsqueeze(2)

# 2. Look-ahead mask (decoder self-attention must not see future positions)
def create_lookahead_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask
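
How these masks are consumed by attention (an illustrative check; it relies on the optional mask argument of the MultiHeadAttention above): positions where the mask is True have their scores set to -inf before the softmax, so they receive zero weight.

python
import torch

tgt = torch.tensor([[5, 7, 9, 0, 0]])             # one sequence; the last two tokens are padding
pad_mask = create_padding_mask(tgt)                # (1, 1, 1, 5)
look_ahead = create_lookahead_mask(tgt.size(1))    # (5, 5)

# A position is masked if it is padding OR lies in the future
dec_mask = pad_mask | look_ahead                   # broadcasts to (1, 1, 5, 5)

mha = MultiHeadAttention(d_model=16, num_heads=4)
x = torch.randn(1, 5, 16)
out, weights = mha(x, mask=dec_mask)
print(weights[0, 0])   # row i has zero weight above the diagonal and on the padded keys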

BERT (Bidirectional Encoder Representations from Transformers)

Pre-training Tasks

python
# 1. MLM (Masked Language Model)
# Randomly mask 15% of the tokens and predict the masked tokens
# Input:  The [MASK] cat sat on the [MASK]
# Target: black, mat

# 2. NSP (Next Sentence Prediction)
# Predict whether sentence B actually follows sentence A
# Used to learn sentence-level relationships
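
A rough sketch of how MLM inputs and labels could be built (simplified: real BERT additionally replaces 10% of the selected tokens with random tokens and keeps 10% unchanged; the helper name and token ids are illustrative):

python
import torch

def make_mlm_batch(input_ids, mask_token_id, mask_prob=0.15, ignore_index=-100):
    labels = input_ids.clone()
    # Choose roughly 15% of the positions to predict
    selected = torch.rand(input_ids.shape) < mask_prob
    # Only the selected positions contribute to the loss
    labels[~selected] = ignore_index
    # Replace the selected positions with [MASK] (the 80/10/10 split is omitted here)
    corrupted = input_ids.clone()
    corrupted[selected] = mask_token_id
    return corrupted, labels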

BERT Input Representation

python
[CLS] Sentence A [SEP] Sentence B [SEP]

Token Embeddings + Segment Embeddings + Position Embeddings
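
In code, the input representation is simply the element-wise sum of these three embedding tables (a minimal sketch with BERT-base-like sizes; real BERT also applies LayerNorm and dropout on the sum):

python
import torch
import torch.nn as nn

class BertEmbeddings(nn.Module):
    def __init__(self, vocab_size=30522, max_len=512, d_model=768):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.segment_emb = nn.Embedding(2, d_model)         # sentence A = 0, sentence B = 1
        self.position_emb = nn.Embedding(max_len, d_model)  # learned, not sinusoidal
    
    def forward(self, input_ids, segment_ids):
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        return (self.token_emb(input_ids)
                + self.segment_emb(segment_ids)
                + self.position_emb(positions))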

Using BERT

python
import torch.nn as nn
from transformers import BertModel, BertTokenizer

# Load the pretrained tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Tokenize
inputs = tokenizer("Hello, how are you?", return_tensors='pt')

# Forward pass
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # (1, seq_len, 768)
pooler_output = outputs.pooler_output          # (1, 768), pooled [CLS] representation

# Text classification
class BertClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.classifier = nn.Linear(768, num_classes)
    
    def forward(self, x):
        outputs = self.bert(**x)
        return self.classifier(outputs.pooler_output)
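
Hypothetical usage of this classifier (the texts and the number of classes are made up):

python
texts = ["great movie!", "terrible plot"]
batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

clf = BertClassifier(num_classes=2)
logits = clf(batch)              # (2, 2)
preds = logits.argmax(dim=-1)    # predicted class index per sentence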

GPT (Generative Pre-trained Transformer)

GPT Pre-training

python
# Unidirectional (causal) language model: predict the next token
# Input:  The cat sat on the
# Target: mat

# Causal mask
def causal_mask(size):
    # Upper triangle is True (those positions are masked out)
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask
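
The training objective in code form (a sketch; `logits` would come from a decoder-only Transformer, which is not shown here): each position predicts the token one step to its right, so the labels are simply the inputs shifted by one.

python
import torch.nn.functional as F

def causal_lm_loss(logits, input_ids):
    # logits: (B, L, vocab_size) from a decoder-only Transformer
    shift_logits = logits[:, :-1, :]        # predictions for positions 0 .. L-2
    shift_labels = input_ids[:, 1:]         # the "next token" at each of those positions
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                           shift_labels.reshape(-1))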

GPT Versions

| Version | Parameters  | Notes                    |
|---------|-------------|--------------------------|
| GPT-1   | 117M        | Pioneering work          |
| GPT-2   | 1.5B        | Open weights, zero-shot  |
| GPT-3   | 175B        | Few-shot                 |
| GPT-4   | Undisclosed | Multimodal               |
| GPT-4o  | -           | Natively multimodal      |

InstructGPT / ChatGPT

python
# RLHF (Reinforcement Learning from Human Feedback)
# 1. SFT: supervised fine-tuning on human demonstrations
# 2. Reward Model: train a reward model on human preference comparisons
# 3. PPO: optimize the policy with reinforcement learning against the reward model
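
Step 2 in a nutshell: the reward model is trained on pairs of responses where humans preferred one over the other, with a pairwise ranking loss (a sketch; `r_chosen` / `r_rejected` stand for scalar reward scores produced by the model):

python
import torch.nn.functional as F

def reward_ranking_loss(r_chosen, r_rejected):
    # Maximize the margin between the preferred and the rejected response:
    # loss = -log(sigmoid(r_chosen - r_rejected))
    return -F.logsigmoid(r_chosen - r_rejected).mean()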

T5 (Text-to-Text Transfer Transformer)

Unified Framework

python
# Every task is cast as text-to-text
input: "translate English to German: Hello"
output: "Hallo"

input: "summarize: The text to summarize..."
output: "Summary text"
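
Running such a prompt with the Hugging Face implementation (illustrative; 't5-small' is simply the smallest public checkpoint):

python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

inputs = tokenizer("translate English to German: Hello", return_tensors='pt')
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))  # e.g. "Hallo"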

Model Comparison

| Model | Architecture    | Pre-training Task | Strengths                     |
|-------|-----------------|-------------------|-------------------------------|
| BERT  | Encoder-only    | MLM + NSP         | Strong at understanding tasks |
| GPT   | Decoder-only    | CLM               | Strong at generation          |
| T5    | Encoder-Decoder | Span corruption   | General-purpose               |

ViT (Vision Transformer)

Image Patching

python
# Image → 16×16 patches → flatten → linear projection

class ViT(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_classes=1000):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2
        
        # Patch + position embedding
        self.proj = nn.Conv2d(3, 768, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, 768))
        
        # Transformer encoder (the Encoder stack defined above; ViT-Base uses d_ff = 3072)
        self.encoder = Encoder(num_layers=12, d_model=768, num_heads=12, d_ff=3072)
        
        # Classification head
        self.head = nn.Linear(768, num_classes)
    
    def forward(self, x):
        # x: (B, 3, 224, 224)
        x = self.proj(x)  # (B, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)  # (B, 196, 768)
        
        # Prepend the [CLS] token
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, 197, 768)
        
        # Add position embeddings
        x = x + self.pos_embed
        
        # Transformer encoding
        x = self.encoder(x)
        
        # Take the [CLS] token output
        cls_output = x[:, 0]
        
        return self.head(cls_output)
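
Quick shape check of the model above (illustrative):

python
import torch

vit = ViT(image_size=224, patch_size=16, num_classes=1000)
images = torch.randn(2, 3, 224, 224)
print(vit(images).shape)   # torch.Size([2, 1000])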

Transformer Variants

Efficient Attention

python
# 1. FlashAttention (IO-aware: avoids materializing the full attention matrix)
from flash_attn import flash_attn_func

# 2. Linear attention (linear complexity in sequence length)
# Approximates the softmax with kernel feature maps

# 3. Sparse attention
# Computes only a subset of the attention scores
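
A practical way to get optimized kernels without an extra dependency is PyTorch 2.x's built-in fused attention, which dispatches to FlashAttention-style kernels when they are available (a sketch; the shapes are arbitrary):

python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)   # (batch, heads, seq_len, d_k)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# Fused scaled dot-product attention; is_causal=True applies a causal mask internally
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)   # torch.Size([2, 8, 128, 64])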

Positional Encoding Variants

python
# 1. RoPE (Rotary Position Embedding) - used by LLaMA
# 2. ALiBi (Attention with Linear Biases) - used by BLOOM
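
The idea behind RoPE, as a rough sketch (simplified: real implementations cache the angles and apply the rotation to Q and K inside every attention layer; the function below is only illustrative): each (even, odd) pair of feature dimensions is rotated by an angle proportional to the position, so the dot product of two rotated vectors depends only on their relative distance.

python
import torch

def apply_rope(x, base=10000.0):
    # x: (batch, seq_len, d) with d even
    _, L, d = x.shape
    pos = torch.arange(L, dtype=torch.float32).unsqueeze(1)                # (L, 1)
    freqs = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)      # (d/2,)
    angles = pos * freqs                                                   # (L, d/2)
    cos, sin = angles.cos(), angles.sin()

    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out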

Training Tricks

python
# 1. Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 2. Learning-rate schedule (warmup + cosine decay)
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=100000
)

# 3. Mixed precision
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    with autocast():
        outputs = model(**batch)
        loss = criterion(outputs, labels)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# 4. Multi-GPU training (DataParallel; DistributedDataParallel is preferred in practice)
from torch.nn.parallel import DataParallel
model = DataParallel(model)

Interview Key Points

1. The attention formula
   Attention(Q,K,V) = softmax(QK^T/√d_k)V

2. Self-attention vs cross-attention
   - Self: Q, K, V all come from the same sequence
   - Cross: Q comes from the decoder, K and V come from the encoder

3. What multi-head attention buys you
   - Different heads attend to different subspaces
   - It increases the model's representational capacity
   - Some heads learn syntax-like patterns, others more semantic ones

4. Transformer vs RNN
   - Parallel computation (much faster training)
   - Long-range dependencies (any two positions are directly connected through attention)
   - Fixed maximum context window, but attention is global within it

5. BERT vs GPT
   - BERT: encoder-only, bidirectional, MLM pre-training
   - GPT: decoder-only, unidirectional, CLM (next-token) pre-training

6. Why sinusoidal positional encodings
   - Relative positions are easy to represent: PE(pos+k) is a linear function of PE(pos)
   - They can, in principle, extrapolate to sequences longer than those seen during training

7. Transformer complexity
   - Self-attention: O(n²·d) in the sequence length n
   - Mitigations: FlashAttention, linear attention

8. LayerNorm vs BatchNorm
   - LN: normalizes over the feature dimension per token; works with variable-length sequences
   - BN: normalizes over the batch dimension; common in CV with large batches
