Skip to content

预训练工程与稳定性

本章目标:掌握让模型"从头训练不崩"的工程技巧


2.1 数据工程基础

2.1.1 Tokenization:BPE 与 SentencePiece

Byte Pair Encoding (BPE):一种子词分词算法

原始语料: "hug", "pug", "pun", "bug", "bun", "hugs"

BPE 词表构建:
- 统计所有相邻字节对频率
- 合并最频繁的对
- 重复直到达到目标词表大小

典型 BPE 过程:
"low" = "l" "o" "w"
"low" 经过 BPE 后可能是 "lo" "w" 或 "l" "ow"
python
# 使用 SentencePiece (LLaMA 采用)
import sentencepiece as spm

# 训练 BPE 模型
spm.SentencePieceTrainer.train(
    input='corpus.txt',
    model_prefix='tokenizer',
    vocab_size=32000,
    model_type='bpe',
    character_coverage=1.0,
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3
)

# 使用
sp = sp.load('tokenizer.model')
tokens = sp.encode("Hello, world!", out_type=int)  # [1234, 45, 6789]
text = sp.decode(tokens)  # "Hello, world!"

为什么用 BPE 而不是字级别?

分词方式词表大小OOV 处理编码效率
字级别~50K
词级别~500K
BPE~32K

2.1.2 DoReMi:动态数据重采样

问题:不同文档的"有用程度"不同,均匀采样浪费算力。

DoReMi (Domain Reweighting with Minimax)

核心思想:用小模型学各领域的重要性,然后重采样训练大模型。

python
def doremi_resample(documents, domain_weights, alpha=0.5):
    """
    documents: [(text, domain), ...]
    domain_weights: 每个领域的采样权重
    alpha: 温度参数
    """
    # 计算采样概率
    probs = np.array([domain_weights[d] for _, d in documents])
    probs = probs ** alpha
    probs = probs / probs.sum()
    
    # 采样
    indices = np.random.choice(len(documents), size=len(documents), 
                              replace=True, p=probs)
    return [documents[i] for i in indices]

2.1.3 长文本语料构造

代码 vs 书籍的配比

LLaMA 训练数据分布:
- CommonCrawl (67%)
- C4 (15%)
- GitHub (4.5%)
- Wikipedia (4.5%)
- Books (4.5%)
- ArXiv (2.5%)
- Stack Exchange (2%)

长文本处理策略

python
def pack_documents(documents, max_seq_len, tokenizer):
    """
    将文档打包成固定长度序列
    """
    packed = []
    current = []
    current_len = 0
    
    for doc in documents:
        doc_tokens = tokenizer.encode(doc)
        
        if current_len + len(doc_tokens) + 1 <= max_seq_len:
            current.append(doc_tokens)
            current_len += len(doc_tokens) + 1  # +1 for separator
        else:
            # 填充或截断当前序列
            packed.append(pad_or_truncate(current, max_seq_len))
            current = [doc_tokens]
            current_len = len(doc_tokens)
    
    return packed

2.2 训练稳定性三板斧

2.2.1 混合精度训练:BF16 vs FP16

BF16 (Brain Float 16) vs FP16 (Float 16)

FP16:  1 sign | 5 exponent | 10 mantissa  = 16 bits
BF16:  1 sign | 8 exponent | 7 mantissa  = 16 bits

FP16:  指数小 → 数值范围小,但精度高
BF16:  指数大 → 数值范围大(≈ FP32),精度略低

BF16 的优势:
- 更大的动态范围,不容易 overflow
- 几乎和 FP32 一样的数值范围
- LLM 训练推荐用 BF16
python
# PyTorch 混合精度训练
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    with autocast(dtype=torch.bfloat16):
        outputs = model(**batch)
        loss = criterion(outputs.logits, labels)
    
    # 缩放 loss,防止下溢
    scaled_loss = scaler.scale(loss)
    scaled_loss.backward()
    
    # 梯度裁剪
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    scaler.step(optimizer)
    scaler.update()

Dynamic Loss Scale

python
# 如果梯度全是 inf/NaN,降低 loss scale
# 如果连续 N 步没有 inf/NaN,尝试提高 loss scale
class DynamicLossScaler:
    def __init__(self, init_scale=2**16, growth_factor=2, backoff_factor=0.5):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self._growth_steps = 0
        self._backoff_steps = 0
    
    def scale(self, loss):
        return loss * self.scale
    
    def update(self, found_inf):
        if found_inf:
            # 减小 scale
            self.scale *= self.backoff_factor
            self._backoff_steps += 1
        else:
            self._growth_steps += 1
            if self._growth_steps >= 2000:
                self.scale *= self.growth_factor
                self._growth_steps = 0

2.2.2 梯度裁剪:Global Norm

问题:早期训练时梯度可能爆炸

python
def clip_grad_norm_(parameters, max_norm, eps=1e-6):
    """
    梯度裁剪:所有梯度的 L2 norm 超过 max_norm 时等比例缩放
    """
    # 计算全局梯度范数
    total_norm = 0.0
    for p in parameters:
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    
    # 裁剪系数
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1:
        for p in parameters:
            if p.grad is not None:
                p.grad.data.mul_(clip_coef)
    
    return total_norm

使用

python
# 通常设置 max_norm = 1.0 或更小
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

2.2.3 学习率 Warmup

为什么需要 Warmup?

问题:模型参数是随机初始化的,早期梯度可能很大
      预热期间:让参数从随机状态逐步稳定,避免大幅震荡

线性 Warmup + Cosine Decay

python
class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps, total_steps, 
                 min_lr_ratio=0.1, base_lr=1e-3):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = base_lr * min_lr_ratio
        self.base_lr = base_lr
        self.current_step = 0
    
    def step(self):
        self.current_step += 1
        
        if self.current_step <= self.warmup_steps:
            # 线性 warmup
            lr = self.base_lr * (self.current_step / self.warmup_steps)
        else:
            # Cosine decay
            progress = (self.current_step - self.warmup_steps) / \
                      (self.total_steps - self.warmup_steps)
            lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * \
                 (1 + math.cos(math.pi * progress))
        
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        
        return lr


# 训练循环
scheduler = WarmupCosineScheduler(
    optimizer,
    warmup_steps=2000,
    total_steps=100000,
    min_lr_ratio=0.1,
    base_lr=1e-3
)

for step in range(total_steps):
    loss = train_step()
    loss.backward()
    optimizer.step()
    scheduler.step()

可视化

学习率

1e-3 │╲
     │  ╲___
     │      ╲
     │       ╲____
     │            ╲________
     │                      ╲____
min  │                          ╲____
     └──────────────────────────────────→ Steps
     0  warmup      decay

2.3 内存与速度优化

2.3.1 Gradient Checkpointing(梯度检查点)

核心思想:用时间换空间。不保存所有中间激活值,而是在反向传播时重新计算。

标准前向: O(N) 内存 (N = 层数)
Checkpointing: O(√N) 最优策略

策略:每隔 √N 层保存一个 checkpoint
      反向时在这些点重新计算
python
# PyTorch 实现
from torch.utils.checkpoint import checkpoint_sequential

# 方式1: 模块列表
model = nn.Sequential(*layers)
# 前向时只保存部分激活
output = checkpoint_sequential(model, 5, input)  # 每5层一个checkpoint

# 方式2: 手动指定
class CheckpointedBlock(nn.Module):
    def forward(self, x):
        return checkpoint(self._forward, x)
    
    def _forward(self, x):
        return self.norm(self.attention(x))

2.3.2 ZeRO:分布式训练中的状态分片

ZeRO Stage 1/2/3

Stage 0: 无分片(标准 DDP)
         所有参数、梯度、优化器状态都在每个 GPU 上

Stage 1: 分片优化器状态
         每个 GPU 只保存 1/N 的优化器状态

Stage 2: 分片优化器状态 + 梯度
         梯度也分片存储

Stage 3: 分片所有状态
         参数、梯度、优化器状态全分片
         通信量最大,但内存最少
python
# DeepSpeed ZeRO 配置
{
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",  # 把优化器状态卸载到 CPU
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu"
        }
    }
}

2.3.3 FSDP (Fully Sharded Data Parallel)

python
# PyTorch FSDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# 训练
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # 对应 ZeRO-3
    device_id=torch.cuda.current_device()
)

2.4 长上下文预训练

2.4.1 Flash Attention 演进

Flash Attention 1: 分块计算 + IO 优化

标准 Attention: O(N²) 显存 (存储整个 N×N 注意力矩阵)
Flash Attention: O(N) 显存 (分块计算)

核心思想:
1. 把 Q, K, V 分成小块 (block_size = 64~128)
2. 每次只加载一块 K, V 到 SRAM
3. 累加计算 softmax (在线算法)
python
# Flash Attention 核心思想
def flash_attention(Q, K, V, block_size=64):
    """
    Q, K, V: (seq_len, head_dim)
    """
    seq_len = Q.shape[0]
    scale = 1 / math.sqrt(Q.shape[1])
    
    # 初始化
    m = torch.full((seq_len,), -torch.inf)  # max scores
    l = torch.zeros((seq_len,))              # sum of exp
    O = torch.zeros_like(Q)                  # output
    
    # 分块计算
    for i in range(0, seq_len, block_size):
        # 加载 Q 的一块
        q_block = Q[i:i+block_size]
        
        # 初始化这一块的累加器
        m_block = torch.full((block_size,), -torch.inf)
        l_block = torch.zeros((block_size,))
        O_block = torch.zeros_like(q_block)
        
        for j in range(0, seq_len, block_size):
            # 加载 K, V 的一块
            k_block = K[j:j+block_size]
            v_block = V[j:j+block_size]
            
            # 计算 attention block
            s_block = q_block @ k_block.T * scale  # (block, block)
            
            # 更新 max 和 sum
            m_new = torch.maximum(m_block, s_block.max(1, keepdim=True).values)
            l_block = l_block * torch.exp(m_block - m_new) + \
                     torch.exp(s_block - m_new).sum(1)
            O_block = O_block * torch.exp(m_block - m_new).unsqueeze(1) + \
                     (torch.exp(s_block - m_new) @ v_block).unsqueeze(1)
            
            m_block = m_new
        
        # 归一化
        O_block = O_block / l_block.unsqueeze(1)
        
        # 合并到最终输出
        O[i:i+block_size] = O_block * torch.exp(m[i:i+block_size] - m_block) + \
                            O[i:i+block_size] * l[i:i+block_size]
        l[i:i+block_size] = l[i:i+block_size] * torch.exp(m[i:i+block_size] - m_block) + l_block
    
    # 最终归一化
    O = O / l.unsqueeze(1)
    return O

Flash Attention 2: 支持 sequence parallelism

Flash 1: 单 GPU 内部分块
Flash 2: 支持多 GPU 之间的 sequence 并行
         沿 sequence 维度切分
         每 GPU 计算一部分 attention,最后 allgather

Flash Attention 3: Warp-specialized

Flash 3: 利用 NVIDIA H100 的新特性
- Tensor Memory Accelerator (TMA)
- Warp-level matrix multiply
- 进一步提升约 2 倍

2.4.2 Ring Attention:多卡分布式 Attention

4 GPUs 处理 seq_len=8192:

GPU 0: Q[0:2048], K/V[0:2048]
GPU 1: Q[2048:4096], K/V[2048:4096]
...

Ring Attention:
1. 每 GPU 计算本地 Q × 本地 K^T
2. 沿 ring 传递 K, V 块,累加 attention
3. 最终每 GPU 有完整的 O[local]
python
def ring_attention_forward(Q, K, V, num_gpus):
    """
    Q, K, V 已按 sequence 维度分片到各 GPU
    """
    rank = get_rank()
    seq_per_gpu = Q.shape[1]
    
    O_local = torch.zeros_like(Q)
    l_local = torch.zeros(Q.shape[:2])
    m_local = torch.full(Q.shape[:2], -torch.inf)
    
    for step in range(num_gpus):
        # 当前处理第 (rank + step) % num_gpus 块的 K, V
        k_offset = ((rank + step) % num_gpus) * seq_per_gpu
        k_slice = K[:, k_offset:k_offset+seq_per_gpu]
        v_slice = V[:, k_offset:k_offset+seq_per_gpu]
        
        # 计算本地 attention
        s = torch.matmul(Q, k_slice.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
        
        # 在线 softmax 更新
        m_new, _ = torch.maximum(m_local, s.max(-1, keepdim=True).values).max(-2, keepdim=True)
        # ... (同 Flash Attention 的在线 softmax)
        
        # Ring 通信:传递 K, V 给下一个 GPU
        if step < num_gpus - 1:
            send(k_slice, (rank + 1) % num_gpus)
            recv_k = recv((rank - 1) % num_gpus)
            K = recv_k

2.4.3 位置外推四件套

问题:训练时 seq_len=4096,推理想用 8192 怎么办?

方法1: Position Interpolation (PI)

线性插值:把 8192 个位置"压缩"到 4096 范围内

python
def position_interpolation(positions, scale):
    """
    positions: [0, 1, 2, ..., 8191]
    scale = 4096 / 8192 = 0.5
    """
    return positions * scale  # [0, 0.5, 1, ..., 4095.5]

方法2: NTK-aware Scaling

非均匀缩放:高频位置精细缩放,低频位置粗略缩放

python
def ntk_scaling(freqs, scale, dim):
    """
    NTK-aware: 不是线性缩放,而是按频率调整
    """
    # 重建 freqs
    base = 10000
    freqs_base = 1 / (base ** (torch.arange(0, dim, 2).float() / dim))
    
    # 非均匀缩放
    freqs_low = freqs_base[:len(freqs)//4] / scale  # 低频缩放少
    freqs_high = freqs_base[len(freqs)//4:] / (scale ** 0.75)  # 高频缩放多
    
    return torch.cat([freqs_low, freqs_high])

方法3: YaRN (Yet another RoPE extensioN)

python
def yarn_scaling(seq_len, original_seq_len, rope_dim, alpha=32, beta=32):
    """
    YaRN: 温度缩放 + 注意力修正
    """
    scale = original_seq_len / seq_len
    
    # 温度缩放
    def rope_scaling(self, dim, max_position, base=10000, alpha=32, beta=32):
        # 扩展上下文长度
        pos_freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        
        # YaRN 缩放
        low_freq_wavelen = self.base * alpha / (2 * math.pi)
        high_freq_wavelen = 2 * math.pi / (beta * self.base)
        
        # 缩放低频部分
        # ...

方法4: Dynamic Scaling

动态缩放因子:根据实际位置动态调整

python
def dynamic_position_scaling(positions, base_scale=2.0):
    """
    Dynamic Scaling: 不同位置段用不同缩放因子
    """
    scales = []
    for pos in positions:
        if pos < 4096:
            scales.append(1.0)
        elif pos < 8192:
            scales.append(1.5)
        else:
            scales.append(base_scale)
    return torch.tensor(scales)

2.4.4 长上下文训练配置

python
# 训练长上下文模型的配置建议
config = {
    # 位置编码
    "rope_theta": 500000,  # 更大的 base,RoPE 外推更好
    "rope_scaling": {
        "type": "yarn",
        "alpha": 32,
        "beta": 32
    },
    
    # Flash Attention
    "use_flash_attn": True,
    "attn_impl": "flash",  # 或 "triton" (Fused Flash Attention)
    
    # 梯度检查点
    "gradient_checkpointing": True,
    "checkpoint_every_n": 1,  # 每层都 checkpoint
    
    # 学习率
    "learning_rate": 1e-5,  # 长文本训练用更小的 LR
    "warmup_steps": 200,
    "min_lr_ratio": 0.1,
    
    # 序列长度
    "max_seq_len": 32768,
    "sequence_parallel": True,  # 如果多 GPU
}

2.5 本章小结

┌─────────────────────────────────────────────────────────────┐
│                    预训练稳定性三板斧                        │
├─────────────────────────────────────────────────────────────┤
│  1. BF16 混合精度    │  更大的动态范围,避免 overflow     │
│  2. GradClip         │  防止梯度爆炸,global norm 裁剪      │
│  3. Warmup + Cosine  │  避免早期震荡,稳定收敛              │
├─────────────────────────────────────────────────────────────┤
│                    显存优化技术                             │
├─────────────────────────────────────────────────────────────┤
│  Gradient Checkpoint │  O(N) → O(√N),用时间换空间        │
│  ZeRO-1/2/3         │  分布式训练状态分片                  │
│  FSDP               │  PyTorch 原生全分片数据并行           │
├─────────────────────────────────────────────────────────────┤
│                    长上下文技术                             │
├─────────────────────────────────────────────────────────────┤
│  Flash Attention    │  IO 优化,分块计算,O(N²)→O(N)     │
│  Ring Attention     │  多卡分布式 Attention                 │
│  Position Interp.  │  线性/非线性位置缩放                  │
└─────────────────────────────────────────────────────────────┘

显存估算公式

模型参数: 参数量 × 2 bytes (BF16)
优化器状态: 参数量 × 12 bytes (AdamW, 2 states + 1 momentum + 1 variance)
梯度: 参数量 × 2 bytes (BF16)
激活值: batch × seq_len × hidden_dim × layers × bytes_per_param

KV Cache: 2 × batch × seq_len × num_kv_heads × head_dim × bytes
       = 2 × batch × seq_len × 2 × 128/head_dim × bytes (近似)

面试高频问题

  1. BF16 vs FP16 区别?
    → BF16 指数位更多,动态范围更大,不容易 overflow

  2. 为什么需要 Warmup?
    → 随机初始化时梯度可能很大,warmup 让参数稳定后再大学习率

  3. Flash Attention 怎么省显存的?
    → 分块计算 + 在线 softmax,不需要存储 N×N 注意力矩阵

  4. Position Interpolation 原理?
    → 把长位置缩放到训练范围内,牺牲高频精度换取外推能力

基于 MIT 许可发布