预训练工程与稳定性
本章目标:掌握让模型"从头训练不崩"的工程技巧
2.1 数据工程基础
2.1.1 Tokenization:BPE 与 SentencePiece
Byte Pair Encoding (BPE):一种子词分词算法
原始语料: "hug", "pug", "pun", "bug", "bun", "hugs"
↓
BPE 词表构建:
- 统计所有相邻字节对频率
- 合并最频繁的对
- 重复直到达到目标词表大小
典型 BPE 过程:
"low" = "l" "o" "w"
"low" 经过 BPE 后可能是 "lo" "w" 或 "l" "ow"# 使用 SentencePiece (LLaMA 采用)
import sentencepiece as spm
# 训练 BPE 模型
spm.SentencePieceTrainer.train(
input='corpus.txt',
model_prefix='tokenizer',
vocab_size=32000,
model_type='bpe',
character_coverage=1.0,
pad_id=0,
unk_id=1,
bos_id=2,
eos_id=3
)
# 使用
sp = sp.load('tokenizer.model')
tokens = sp.encode("Hello, world!", out_type=int) # [1234, 45, 6789]
text = sp.decode(tokens) # "Hello, world!"为什么用 BPE 而不是字级别?
| 分词方式 | 词表大小 | OOV 处理 | 编码效率 |
|---|---|---|---|
| 字级别 | ~50K | 差 | 低 |
| 词级别 | ~500K | 好 | 高 |
| BPE | ~32K | 好 | 中 |
2.1.2 DoReMi:动态数据重采样
问题:不同文档的"有用程度"不同,均匀采样浪费算力。
DoReMi (Domain Reweighting with Minimax):
核心思想:用小模型学各领域的重要性,然后重采样训练大模型。
def doremi_resample(documents, domain_weights, alpha=0.5):
"""
documents: [(text, domain), ...]
domain_weights: 每个领域的采样权重
alpha: 温度参数
"""
# 计算采样概率
probs = np.array([domain_weights[d] for _, d in documents])
probs = probs ** alpha
probs = probs / probs.sum()
# 采样
indices = np.random.choice(len(documents), size=len(documents),
replace=True, p=probs)
return [documents[i] for i in indices]2.1.3 长文本语料构造
代码 vs 书籍的配比:
LLaMA 训练数据分布:
- CommonCrawl (67%)
- C4 (15%)
- GitHub (4.5%)
- Wikipedia (4.5%)
- Books (4.5%)
- ArXiv (2.5%)
- Stack Exchange (2%)长文本处理策略:
def pack_documents(documents, max_seq_len, tokenizer):
"""
将文档打包成固定长度序列
"""
packed = []
current = []
current_len = 0
for doc in documents:
doc_tokens = tokenizer.encode(doc)
if current_len + len(doc_tokens) + 1 <= max_seq_len:
current.append(doc_tokens)
current_len += len(doc_tokens) + 1 # +1 for separator
else:
# 填充或截断当前序列
packed.append(pad_or_truncate(current, max_seq_len))
current = [doc_tokens]
current_len = len(doc_tokens)
return packed2.2 训练稳定性三板斧
2.2.1 混合精度训练:BF16 vs FP16
BF16 (Brain Float 16) vs FP16 (Float 16):
FP16: 1 sign | 5 exponent | 10 mantissa = 16 bits
BF16: 1 sign | 8 exponent | 7 mantissa = 16 bits
FP16: 指数小 → 数值范围小,但精度高
BF16: 指数大 → 数值范围大(≈ FP32),精度略低
BF16 的优势:
- 更大的动态范围,不容易 overflow
- 几乎和 FP32 一样的数值范围
- LLM 训练推荐用 BF16# PyTorch 混合精度训练
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for batch in dataloader:
with autocast(dtype=torch.bfloat16):
outputs = model(**batch)
loss = criterion(outputs.logits, labels)
# 缩放 loss,防止下溢
scaled_loss = scaler.scale(loss)
scaled_loss.backward()
# 梯度裁剪
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()Dynamic Loss Scale:
# 如果梯度全是 inf/NaN,降低 loss scale
# 如果连续 N 步没有 inf/NaN,尝试提高 loss scale
class DynamicLossScaler:
def __init__(self, init_scale=2**16, growth_factor=2, backoff_factor=0.5):
self.scale = init_scale
self.growth_factor = growth_factor
self.backoff_factor = backoff_factor
self._growth_steps = 0
self._backoff_steps = 0
def scale(self, loss):
return loss * self.scale
def update(self, found_inf):
if found_inf:
# 减小 scale
self.scale *= self.backoff_factor
self._backoff_steps += 1
else:
self._growth_steps += 1
if self._growth_steps >= 2000:
self.scale *= self.growth_factor
self._growth_steps = 02.2.2 梯度裁剪:Global Norm
问题:早期训练时梯度可能爆炸
def clip_grad_norm_(parameters, max_norm, eps=1e-6):
"""
梯度裁剪:所有梯度的 L2 norm 超过 max_norm 时等比例缩放
"""
# 计算全局梯度范数
total_norm = 0.0
for p in parameters:
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
# 裁剪系数
clip_coef = max_norm / (total_norm + eps)
if clip_coef < 1:
for p in parameters:
if p.grad is not None:
p.grad.data.mul_(clip_coef)
return total_norm使用:
# 通常设置 max_norm = 1.0 或更小
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)2.2.3 学习率 Warmup
为什么需要 Warmup?
问题:模型参数是随机初始化的,早期梯度可能很大
预热期间:让参数从随机状态逐步稳定,避免大幅震荡线性 Warmup + Cosine Decay:
class WarmupCosineScheduler:
def __init__(self, optimizer, warmup_steps, total_steps,
min_lr_ratio=0.1, base_lr=1e-3):
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.min_lr = base_lr * min_lr_ratio
self.base_lr = base_lr
self.current_step = 0
def step(self):
self.current_step += 1
if self.current_step <= self.warmup_steps:
# 线性 warmup
lr = self.base_lr * (self.current_step / self.warmup_steps)
else:
# Cosine decay
progress = (self.current_step - self.warmup_steps) / \
(self.total_steps - self.warmup_steps)
lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * \
(1 + math.cos(math.pi * progress))
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
# 训练循环
scheduler = WarmupCosineScheduler(
optimizer,
warmup_steps=2000,
total_steps=100000,
min_lr_ratio=0.1,
base_lr=1e-3
)
for step in range(total_steps):
loss = train_step()
loss.backward()
optimizer.step()
scheduler.step()可视化:
学习率
↑
1e-3 │╲
│ ╲___
│ ╲
│ ╲____
│ ╲________
│ ╲____
min │ ╲____
└──────────────────────────────────→ Steps
0 warmup decay2.3 内存与速度优化
2.3.1 Gradient Checkpointing(梯度检查点)
核心思想:用时间换空间。不保存所有中间激活值,而是在反向传播时重新计算。
标准前向: O(N) 内存 (N = 层数)
Checkpointing: O(√N) 最优策略
策略:每隔 √N 层保存一个 checkpoint
反向时在这些点重新计算# PyTorch 实现
from torch.utils.checkpoint import checkpoint_sequential
# 方式1: 模块列表
model = nn.Sequential(*layers)
# 前向时只保存部分激活
output = checkpoint_sequential(model, 5, input) # 每5层一个checkpoint
# 方式2: 手动指定
class CheckpointedBlock(nn.Module):
def forward(self, x):
return checkpoint(self._forward, x)
def _forward(self, x):
return self.norm(self.attention(x))2.3.2 ZeRO:分布式训练中的状态分片
ZeRO Stage 1/2/3:
Stage 0: 无分片(标准 DDP)
所有参数、梯度、优化器状态都在每个 GPU 上
Stage 1: 分片优化器状态
每个 GPU 只保存 1/N 的优化器状态
Stage 2: 分片优化器状态 + 梯度
梯度也分片存储
Stage 3: 分片所有状态
参数、梯度、优化器状态全分片
通信量最大,但内存最少# DeepSpeed ZeRO 配置
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu", # 把优化器状态卸载到 CPU
"pin_memory": true
},
"offload_param": {
"device": "cpu"
}
}
}2.3.3 FSDP (Fully Sharded Data Parallel)
# PyTorch FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
# 训练
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # 对应 ZeRO-3
device_id=torch.cuda.current_device()
)2.4 长上下文预训练
2.4.1 Flash Attention 演进
Flash Attention 1: 分块计算 + IO 优化
标准 Attention: O(N²) 显存 (存储整个 N×N 注意力矩阵)
Flash Attention: O(N) 显存 (分块计算)
核心思想:
1. 把 Q, K, V 分成小块 (block_size = 64~128)
2. 每次只加载一块 K, V 到 SRAM
3. 累加计算 softmax (在线算法)# Flash Attention 核心思想
def flash_attention(Q, K, V, block_size=64):
"""
Q, K, V: (seq_len, head_dim)
"""
seq_len = Q.shape[0]
scale = 1 / math.sqrt(Q.shape[1])
# 初始化
m = torch.full((seq_len,), -torch.inf) # max scores
l = torch.zeros((seq_len,)) # sum of exp
O = torch.zeros_like(Q) # output
# 分块计算
for i in range(0, seq_len, block_size):
# 加载 Q 的一块
q_block = Q[i:i+block_size]
# 初始化这一块的累加器
m_block = torch.full((block_size,), -torch.inf)
l_block = torch.zeros((block_size,))
O_block = torch.zeros_like(q_block)
for j in range(0, seq_len, block_size):
# 加载 K, V 的一块
k_block = K[j:j+block_size]
v_block = V[j:j+block_size]
# 计算 attention block
s_block = q_block @ k_block.T * scale # (block, block)
# 更新 max 和 sum
m_new = torch.maximum(m_block, s_block.max(1, keepdim=True).values)
l_block = l_block * torch.exp(m_block - m_new) + \
torch.exp(s_block - m_new).sum(1)
O_block = O_block * torch.exp(m_block - m_new).unsqueeze(1) + \
(torch.exp(s_block - m_new) @ v_block).unsqueeze(1)
m_block = m_new
# 归一化
O_block = O_block / l_block.unsqueeze(1)
# 合并到最终输出
O[i:i+block_size] = O_block * torch.exp(m[i:i+block_size] - m_block) + \
O[i:i+block_size] * l[i:i+block_size]
l[i:i+block_size] = l[i:i+block_size] * torch.exp(m[i:i+block_size] - m_block) + l_block
# 最终归一化
O = O / l.unsqueeze(1)
return OFlash Attention 2: 支持 sequence parallelism
Flash 1: 单 GPU 内部分块
Flash 2: 支持多 GPU 之间的 sequence 并行
沿 sequence 维度切分
每 GPU 计算一部分 attention,最后 allgatherFlash Attention 3: Warp-specialized
Flash 3: 利用 NVIDIA H100 的新特性
- Tensor Memory Accelerator (TMA)
- Warp-level matrix multiply
- 进一步提升约 2 倍2.4.2 Ring Attention:多卡分布式 Attention
4 GPUs 处理 seq_len=8192:
GPU 0: Q[0:2048], K/V[0:2048]
GPU 1: Q[2048:4096], K/V[2048:4096]
...
Ring Attention:
1. 每 GPU 计算本地 Q × 本地 K^T
2. 沿 ring 传递 K, V 块,累加 attention
3. 最终每 GPU 有完整的 O[local]def ring_attention_forward(Q, K, V, num_gpus):
"""
Q, K, V 已按 sequence 维度分片到各 GPU
"""
rank = get_rank()
seq_per_gpu = Q.shape[1]
O_local = torch.zeros_like(Q)
l_local = torch.zeros(Q.shape[:2])
m_local = torch.full(Q.shape[:2], -torch.inf)
for step in range(num_gpus):
# 当前处理第 (rank + step) % num_gpus 块的 K, V
k_offset = ((rank + step) % num_gpus) * seq_per_gpu
k_slice = K[:, k_offset:k_offset+seq_per_gpu]
v_slice = V[:, k_offset:k_offset+seq_per_gpu]
# 计算本地 attention
s = torch.matmul(Q, k_slice.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
# 在线 softmax 更新
m_new, _ = torch.maximum(m_local, s.max(-1, keepdim=True).values).max(-2, keepdim=True)
# ... (同 Flash Attention 的在线 softmax)
# Ring 通信:传递 K, V 给下一个 GPU
if step < num_gpus - 1:
send(k_slice, (rank + 1) % num_gpus)
recv_k = recv((rank - 1) % num_gpus)
K = recv_k2.4.3 位置外推四件套
问题:训练时 seq_len=4096,推理想用 8192 怎么办?
方法1: Position Interpolation (PI)
线性插值:把 8192 个位置"压缩"到 4096 范围内
def position_interpolation(positions, scale):
"""
positions: [0, 1, 2, ..., 8191]
scale = 4096 / 8192 = 0.5
"""
return positions * scale # [0, 0.5, 1, ..., 4095.5]方法2: NTK-aware Scaling
非均匀缩放:高频位置精细缩放,低频位置粗略缩放
def ntk_scaling(freqs, scale, dim):
"""
NTK-aware: 不是线性缩放,而是按频率调整
"""
# 重建 freqs
base = 10000
freqs_base = 1 / (base ** (torch.arange(0, dim, 2).float() / dim))
# 非均匀缩放
freqs_low = freqs_base[:len(freqs)//4] / scale # 低频缩放少
freqs_high = freqs_base[len(freqs)//4:] / (scale ** 0.75) # 高频缩放多
return torch.cat([freqs_low, freqs_high])方法3: YaRN (Yet another RoPE extensioN)
def yarn_scaling(seq_len, original_seq_len, rope_dim, alpha=32, beta=32):
"""
YaRN: 温度缩放 + 注意力修正
"""
scale = original_seq_len / seq_len
# 温度缩放
def rope_scaling(self, dim, max_position, base=10000, alpha=32, beta=32):
# 扩展上下文长度
pos_freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# YaRN 缩放
low_freq_wavelen = self.base * alpha / (2 * math.pi)
high_freq_wavelen = 2 * math.pi / (beta * self.base)
# 缩放低频部分
# ...方法4: Dynamic Scaling
动态缩放因子:根据实际位置动态调整
def dynamic_position_scaling(positions, base_scale=2.0):
"""
Dynamic Scaling: 不同位置段用不同缩放因子
"""
scales = []
for pos in positions:
if pos < 4096:
scales.append(1.0)
elif pos < 8192:
scales.append(1.5)
else:
scales.append(base_scale)
return torch.tensor(scales)2.4.4 长上下文训练配置
# 训练长上下文模型的配置建议
config = {
# 位置编码
"rope_theta": 500000, # 更大的 base,RoPE 外推更好
"rope_scaling": {
"type": "yarn",
"alpha": 32,
"beta": 32
},
# Flash Attention
"use_flash_attn": True,
"attn_impl": "flash", # 或 "triton" (Fused Flash Attention)
# 梯度检查点
"gradient_checkpointing": True,
"checkpoint_every_n": 1, # 每层都 checkpoint
# 学习率
"learning_rate": 1e-5, # 长文本训练用更小的 LR
"warmup_steps": 200,
"min_lr_ratio": 0.1,
# 序列长度
"max_seq_len": 32768,
"sequence_parallel": True, # 如果多 GPU
}2.5 本章小结
┌─────────────────────────────────────────────────────────────┐
│ 预训练稳定性三板斧 │
├─────────────────────────────────────────────────────────────┤
│ 1. BF16 混合精度 │ 更大的动态范围,避免 overflow │
│ 2. GradClip │ 防止梯度爆炸,global norm 裁剪 │
│ 3. Warmup + Cosine │ 避免早期震荡,稳定收敛 │
├─────────────────────────────────────────────────────────────┤
│ 显存优化技术 │
├─────────────────────────────────────────────────────────────┤
│ Gradient Checkpoint │ O(N) → O(√N),用时间换空间 │
│ ZeRO-1/2/3 │ 分布式训练状态分片 │
│ FSDP │ PyTorch 原生全分片数据并行 │
├─────────────────────────────────────────────────────────────┤
│ 长上下文技术 │
├─────────────────────────────────────────────────────────────┤
│ Flash Attention │ IO 优化,分块计算,O(N²)→O(N) │
│ Ring Attention │ 多卡分布式 Attention │
│ Position Interp. │ 线性/非线性位置缩放 │
└─────────────────────────────────────────────────────────────┘显存估算公式
模型参数: 参数量 × 2 bytes (BF16)
优化器状态: 参数量 × 12 bytes (AdamW, 2 states + 1 momentum + 1 variance)
梯度: 参数量 × 2 bytes (BF16)
激活值: batch × seq_len × hidden_dim × layers × bytes_per_param
KV Cache: 2 × batch × seq_len × num_kv_heads × head_dim × bytes
= 2 × batch × seq_len × 2 × 128/head_dim × bytes (近似)面试高频问题
BF16 vs FP16 区别?
→ BF16 指数位更多,动态范围更大,不容易 overflow为什么需要 Warmup?
→ 随机初始化时梯度可能很大,warmup 让参数稳定后再大学习率Flash Attention 怎么省显存的?
→ 分块计算 + 在线 softmax,不需要存储 N×N 注意力矩阵Position Interpolation 原理?
→ 把长位置缩放到训练范围内,牺牲高频精度换取外推能力