现代大模型架构基础
本章目标:理解为什么大语言模型长现在这个样子,掌握核心组件的数学本质
1.1 位置编码深度解析:RoPE
1.1.1 绝对位置编码的问题
传统 Transformer 使用绝对位置编码(Sinusoidal 或 Learnable),将位置信息加到 Token 嵌入上:
Token_Embed + Position_Embed → 输入问题在哪?
- 无法表达相对位置:相对位置在 Attention 计算中很重要,但绝对编码很难隐式建模
- 训练外推差:训练最大 2048,推理想用 4096?效果崩
- 计算冗余:每个位置都要学习独立向量
1.1.2 RoPE 的核心思想
旋转位置编码(Rotary Position Embedding) 的核心是:不是把位置信息加到 embedding 上,而是通过旋转来编码位置。
不在 embedding 空间操作,而是在 Query 和 Key 做旋转!数学直觉:
二维平面上的旋转:
[cos(θ) -sin(θ)] [q_0] = 旋转后的 q
[sin(θ) cos(θ)] [q_1]
如果我们让 θ = m * θ_base,m 是位置
那么两个位置 m 和 n 的旋转角度差 = (m-n) * θ_base
旋转角度差只依赖于相对位置 (m-n)!1.1.3 RoPE 公式推导
二维情况:
对于 Query 向量 $\mathbf{q}$ 和 Key 向量 $\mathbf{k}$,它们的第 $(i, i+1)$ 个维度组成一个二维子空间。
RoPE 的做法是:把位置 $m$ 的 $\mathbf{q}$ 旋转 $m\theta$ 角度:
$$ \mathbf{q}'i = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix} \begin{bmatrix} q \ q_{2i+1} \end{bmatrix} $$
其中 $\theta_i = b^{-2i/d}$,$b$ 通常取 $10000$。
点积性质:
旋转后的 q 和 k 做点积:
⟨R_m(q), R_n(k)⟩ = ⟨q, k⟩ 旋转到 (m-n) 角度
Attention score = ⟨R_m(q), R_n(k)⟩
= f(q, k, m-n) ← 只依赖相对位置!这意味着:RoPE 天然编码了相对位置信息!
1.1.4 多头 RoPE 实现
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""
预计算旋转角度的复数形式
dim: head_dim
end: 最大序列长度
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(end)
freqs = torch.outer(t, freqs) # (seq_len, dim//2)
# 转为复数形式: e^(i*freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""
x: (batch, seq_len, num_heads, head_dim)
freqs_cis: (seq_len, head_dim//2) 复数
"""
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = x_complex * freqs_cis.unsqueeze(0).unsqueeze(2)
return torch.view_as_real(x_rotated).flatten(-2).type_as(x)
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len=4096, theta=10000.0):
super().__init__()
self.freqs_cis = precompute_freqs_cis(dim, max_seq_len, theta)
def forward(self, x):
# x: (batch, seq_len, num_heads, head_dim)
self.freqs_cis = self.freqs_cis.to(x.device)
return apply_rotary_emb(x, self.freqs_cis)1.1.5 RoPE 的远程衰减特性
RoPE 的一个神奇特性:位置越远,attention 分数衰减
远程衰减(Long-Range Decay):
Attention Score
│
│╲
│ ╲
│ ╲
│ ╲___ ← 逐渐衰减
│ ╲____
└──────────────────→ 相对距离为什么有用?
- 远处的 token 贡献自然减小
- 类似自然语言中"越远越不相关"的直觉
- 有助于处理超长上下文
1.1.6 与其他位置编码对比
| 特性 | RoPE (Rotary) | ALiBi | Sinusoidal | Learnable |
|---|---|---|---|---|
| 相对位置编码 | ✅ 天然 | ✅ 线性 | ❌ | ❌ |
| 外推能力 | ✅ 好 | ✅ 好 | ❌ 差 | ❌ 差 |
| 计算效率 | ✅ 高 | ✅ 高 | ❌ 需查表 | ✅ |
| 显存开销 | ✅ 小 | ✅ 小 | ❌ 大 | ❌ 大 |
| 适用场景 | LLaMA, GLM | BLOOM | 原始 Transformer | 短上下文 |
| 实现复杂度 | 中 | 低 | 低 | 低 |
1.2 注意力机制演进:MHA → GQA → MQA
1.2.1 标准 Multi-Head Attention (MHA)
每个 head 有独立的 Wq, Wk, Wv
Head_i: Q_i = X · Wq_i, K_i = X · Wk_i, V_i = X · Wv_i
Output_i = Attention(Q_i, K_i, V_i)
Final = Concat(Output_1, ..., Output_h) · Wo显存消耗分析:
对于每个 token,KV Cache 需要存储:
每个 head: K_i, V_i 各 (seq_len, head_dim)
总共 h 个 headKV Cache 显存:
KV_Cache = 2 × batch × seq_len × num_heads × head_dim × bytes_per_param以 LLaMA-7B 为例:
- num_heads = 32
- head_dim = 128
- BF16 = 2 bytes
单 token KV Cache = 2 × 32 × 128 × 2 = 16KB
65B 模型:num_heads = 64, head_dim = 128
单 token KV Cache = 2 × 64 × 128 × 2 = 32KB推理时,KV Cache 是显存的主要瓶颈!
1.2.2 Multi-Query Attention (MQA)
核心思想:所有 head 共享同一份 K 和 V
Q_1, Q_2, ..., Q_h ← 独立(每个 head 自己)
K, V ← 所有 head 共享一份!class MQA(nn.Module):
def __init__(self, num_heads, head_dim):
super().__init__()
self.num_heads = num_heads
# Q: 每个 head 独立
self.Wq = nn.Linear(hidden_dim, num_heads * head_dim)
# K, V: 所有 head 共享
self.Wk = nn.Linear(hidden_dim, head_dim) # 不是 num_heads * head_dim!
self.Wv = nn.Linear(hidden_dim, head_dim)
def forward(self, x):
q = self.Wq(x).view(batch, seq, self.num_heads, head_dim)
k = self.Wk(x).unsqueeze(1) # (batch, 1, 1, head_dim)
v = self.Wv(x).unsqueeze(1)
# Q 扩展,K/V 广播
q = q.transpose(1, 2) # (batch, num_heads, seq, head_dim)
k = k.expand(batch, self.num_heads, seq, head_dim) # 广播
v = v.expand(batch, self.num_heads, seq, head_dim) # 广播参数量变化:
MHA: K_params = num_heads × head_dim × embed_dim
MQA: K_params = 1 × head_dim × embed_dim
LLaMA-7B:
- MHA: 32 × 128 × 4096 = 16,777,216
- MQA: 1 × 128 × 4096 = 524,288
→ 减少 32 倍!1.2.3 Grouped-Query Attention (GQA)
核心思想:介于 MHA 和 MQA 之间,$g$ 个 head 共享一组 K/V
GQA (g=4, h=32):
- Q: 32 个独立的 head
- K, V: 分成 8 组,每组 4 个 head 共享一组LLaMA 2 用的就是 GQA:
| 模型 | num_heads | num_kv_heads | 每组 head 数 |
|---|---|---|---|
| LLaMA 2-7B | 32 | 32 | 1 (MHA) |
| LLaMA 2-13B | 40 | 40 | 1 (MHA) |
| LLaMA 2-34B | 48 | 8 | 6 (GQA) |
| LLaMA 3-8B | 32 | 8 | 4 (GQA) |
| LLaMA 3-70B | 64 | 8 | 8 (GQA) |
1.2.4 GQA/MQA 的实现细节
def repeat_kv(x: torch.Tensor, num_repeats: int) -> torch.Tensor:
"""
将 (batch, num_kv_heads, seq, head_dim)
扩展为 (batch, num_heads, seq, head_dim)
"""
batch, num_kv, seq, head_dim = x.shape
if num_repeats == 1:
return x
# (batch, num_kv, seq, 1, head_dim) → (batch, num_kv, seq, num_repeats, head_dim)
# → (batch, num_kv * num_repeats, seq, head_dim)
x = x[:, :, None, :, :].expand(batch, num_kv, num_repeats, seq, head_dim)
return x.reshape(batch, num_kv * num_repeats, seq, head_dim)
class GQAAttention(nn.Module):
def __init__(self, num_heads, num_kv_heads, head_dim):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.num_repeats = num_heads // num_kv_heads
self.Wq = nn.Linear(hidden, num_heads * head_dim)
self.Wk = nn.Linear(hidden, num_kv_heads * head_dim)
self.Wv = nn.Linear(hidden, num_kv_heads * head_dim)
def forward(self, x):
b, seq_len, _ = x.shape
q = self.Wq(x).view(b, seq_len, self.num_heads, self.head_dim)
k = self.Wk(x).view(b, seq_len, self.num_kv_heads, self.head_dim)
v = self.Wv(x).view(b, seq_len, self.num_kv_heads, self.head_dim)
# 扩展 K, V 到 num_heads 维度
k = repeat_kv(k.transpose(1, 2), self.num_repeats) # (b, h, seq, d)
v = repeat_kv(v.transpose(1, 2), self.num_repeats)
# 继续 attention 计算...1.2.5 显存 vs 质量权衡
显存占用
↑
│ MHA (最大)
│ GQA
│ MQA
│ (最小)
│__________________→ 质量
←―――――→
GQA 在质量和效率间平衡经验法则:
num_kv_heads = num_heads / 4到num_heads / 8是比较好的平衡点- 减少到 1 会损失性能(LLaMA 2 34B 用 8 而不是 1)
- GQA 相比 MHA,KV Cache 减少 4-8 倍
1.3 前馈网络设计:SwiGLU
1.3.1 标准 FFN
经典 FFN(两层全连接):
FFN(x) = max(0, xW₁)W₂ + b
等价于:
FFN(x) = GeLU(xW₁)W₂问题:ReLU 会"kill"一半的神经元(输出为0),对于深层网络可能导致信息丢失。
1.3.2 SwiGLU 的 Gated 机制
Swish 函数:
swish(x) = x · sigmoid(x) = x / (1 + e^(-x))SwiGLU(Swish-Gated Linear Unit):
SwiGLU(x) = Swish₁(xW₁) ⊗ (xV)
= (xW₁ · sigmoid(xW₁)) ⊗ (xV)class SwiGLU(nn.Module):
def __init__(self, hidden_dim, intermediate_size):
super().__init__()
self.w1 = nn.Linear(hidden_dim, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, hidden_dim, bias=False)
self.w3 = nn.Linear(hidden_dim, intermediate_size, bias=False) # 额外的门控
def forward(self, x):
# SwiGLU: SiLU(xW₁) ⊗ (xW₃)
return self.w2(F.silu(self.w1(x)) * self.w3(x))1.3.3 为什么 SwiGLU 更好
ReLU: f(x) = max(0, x) → 梯度要么是 0 要么是 1
Swish: f(x) = x·sigmoid(x) → 平滑的梯度流
ReLU Swish
│ │
┌─────┴────┐ ┌────┴────┐
│ │ │ ╱╲ │
│ │ │ ╱ ╲ │
─────┴──────────┘──┴────────────
0 x 0 x优势:
- 平滑梯度:Swish 在 0 附近有非零梯度,避免神经元"死亡"
- 自适应门控:门控值由输入本身决定,不是固定的
- 信息流动:两层变三层,但表达能力更强
1.3.4 与其他激活函数对比
| 激活函数 | 公式 | 特点 | 使用场景 |
|---|---|---|---|
| ReLU | max(0, x) | 简单,但有 dead relu | 经典 CNN |
| GELU | x·Φ(x) | 平滑,Transformer 默认 | BERT |
| Swish | x·sigmoid(x) | 自门控,平滑 | 炼丹经验 |
| SwiGLU | Swish₁(xW₁)⊗xV | Swish + 门控 | LLaMA, GLM |
| GeGLU | GeLU(xW₁)⊗xV | GELU + 门控 | - |
1.4 归一化技术:RMSNorm 与 DeepNorm
1.4.1 LayerNorm 的问题
标准 LayerNorm:
def layer_norm(x, gamma, beta, eps=1e-6):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + eps)
return gamma * x_norm + beta计算开销:
- 需要计算均值
mean和方差var - 两个 reduce 操作(求和、求平方和)
- 对于超长序列,这是不小的开销
1.4.2 RMSNorm:只算 RMS
核心观察:LayerNorm 的 mean 项其实没那么重要!
RMSNorm(Root Mean Square Norm):
$$ \overline{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} * \gamma_i, \quad \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_i a_i^2} $$
def rms_norm(x, gamma, eps=1e-6):
rms = x.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
return (x / rms) * gamma省去了 mean 的计算!
1.4.3 RMSNorm vs LayerNorm 对比
LayerNorm: 需要计算 mean + var
RMSNorm: 只计算 RMS (var 的平方根)
速度提升: ~10-30%(取决于序列长度)1.4.4 DeepNorm:深层网络稳定性
对于极深(如 100+ 层)的 Transformer,标准 LayerNorm 可能不够稳定。
DeepNorm(微软 Phi 模型、GLM 系列使用):
$$ \text{DeepNorm}(x) = \alpha \cdot \text{LayerNorm}(x + f(x)) $$
其中 $f(x)$ 是残差分支(如 Multi-Head Attention 或 FFN)。
稳定性分析:
标准 Pre-LN: LayerNorm(x + SubLayer(x))
DeepNorm: α · LayerNorm(x + SubLayer(x))
α 通常取 0.8~1.0(取决于模型深度)为什么有效:
- 残差分支的输出被缩放
- 防止深层网络的方差爆炸
- 结合 Post-LN 的稳定性和 Pre-LN 的效率
1.5 MoE 架构专题
1.5.1 MoE 核心思想
Sparse MoE:不是所有 token 都激活所有参数,而是只激活"专家"的一部分。
标准 Dense: 每个 token 通过所有 FFN
MoE: 每个 token 只通过 Top-K 个专家 Token 输入
│
▼
[Router] ──→ 决定激活哪些专家
│
┌─────┼─────┐
▼ ▼ ▼
┌───┐ ┌───┐ ┌───┐
│ E │ │ E │ │ E │ 专家 FFN
│ 1 │ │ 2 │ │ 3 │
└───┘ └───┘ └───┘
│ │ │
└─────┼─────┘
│
▼
输出加权1.5.2 Top-K Router
class MoELayer(nn.Module):
def __init__(self, num_experts, top_k):
super().__init__()
self.router = nn.Linear(hidden_dim, num_experts)
self.top_k = top_k
def forward(self, x):
# x: (batch, seq, hidden)
b, seq, h = x.shape
# Router 计算每个专家的分数
router_logits = self.router(x) # (b, seq, num_experts)
weights = F.softmax(router_logits, dim=-1)
# Top-K 选择
top_weights, top_idx = torch.topk(weights, self.top_k, dim=-1)
top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True) # 归一化
# 对每个 token,只更新被选中的专家
output = torch.zeros(b, seq, h, device=x.device, dtype=x.dtype)
for i in range(self.top_k):
expert = self.experts[top_idx[:, :, i]] # 选择专家
output += top_weights[:, :, i:i+1] * expert(x)
return output1.5.3 负载均衡 (Load Balancing)
问题:Router 可能总是选同样的专家,导致负载不均。
Aux Loss 解决方案:
def load_balancing_loss(router_probs, top_k_ids, num_experts):
"""
router_probs: 每个 expert 被选中的概率
top_k_ids: 每个 token 选中的 top-k experts
"""
# 1. 计算每个 expert 被选中的频率
expert_counts = torch.zeros(num_experts, device=router_probs.device)
for i in range(top_k_ids.shape[2]):
expert_counts.scatter_add_(0, top_k_ids[:, :, i].flatten(),
torch.ones_like(top_k_ids[:, :, i].flatten().float()))
expert_freq = expert_counts / (top_k_ids.shape[0] * top_k_ids.shape[1] * top_k_ids.shape[2])
# 2. 计算 Router 概率均值
router_mean = router_probs.mean(dim=[0, 1])
# 3. 辅助损失 = sum(router_mean * expert_freq)
# 最小化这个损失 = 均匀分配
return num_experts * (router_mean * expert_freq).sum()1.5.4 Expert Parallelism (EP)
Standard: 每个 GPU 有完整模型
Expert Parallel: 每个 GPU 只有部分专家
4 GPUs, 8 Experts:
GPU 0: Expert 0, 1
GPU 1: Expert 2, 3
GPU 2: Expert 4, 5
GPU 3: Expert 6, 7
Token routing 到不同 GPU → All-to-All 通信1.5.5 DeepSeek-MoE 的细粒度设计
DeepSeek-V2 的创新:
- 细粒度专家分割:不是 8 个大专家,而是拆成 64 个小专家
- 共享专家:从所有专家中选出 top-k,但有 2 个"共享专家"总是被激活
- 设备限制路由:减少跨设备通信
传统 MoE: [E₁][E₂][E₃][E₄] Top-2 = 激活 2 个
DeepSeek: [e][e][e][e][e][e]... Top-8 = 激活 8 个(更细粒度)1.6 本章小结
┌─────────────────────────────────────────────────────────────┐
│ Transformer 核心组件 │
├─────────────────────────────────────────────────────────────┤
│ 位置编码 │ RoPE │ 旋转编码,相对位置,外推好 │
├─────────────────────────────────────────────────────────────┤
│ 注意力 │ GQA │ 减少 KV Cache,平衡质量与效率 │
├─────────────────────────────────────────────────────────────┤
│ FFN │ SwiGLU │ 门控机制,平滑梯度流 │
├─────────────────────────────────────────────────────────────┤
│ 归一化 │ RMSNorm│ 只算 RMS,省去 mean 计算 │
├─────────────────────────────────────────────────────────────┤
│ 超深网络 │ DeepNorm│ 残差缩放,稳定性保证 │
├─────────────────────────────────────────────────────────────┤
│ 超大模型 │ MoE │ 稀疏激活,专家路由 │
└─────────────────────────────────────────────────────────────┘关键公式速查
| 技术 | 公式 |
|---|---|
| RoPE 旋转 | $R_{m,d} = \begin{bmatrix} \cos(m\theta_d) & -\sin(m\theta_d) \ \sin(m\theta_d) & \cos(m\theta_d) \end{bmatrix}$ |
| RMSNorm | $\bar{a}_i = \frac{a_i}{\sqrt{\frac{1}{n}\sum_i a_i^2}} \cdot \gamma_i$ |
| SwiGLU | $\text{SwiGLU}(x) = \text{Swish}_1(xW_1) \otimes (xW_3)$ |
| DeepNorm | $\text{DeepNorm}(x) = \alpha \cdot \text{LN}(x + f(x))$ |
面试高频问题
RoPE 为什么能外推?
→ 旋转矩阵编码相对位置,通过插值或 NTK 方法处理超长序列GQA 省多少显存?
→ KV Cache 减少 $num_heads / num_kv_heads$ 倍SwiGLU 为什么比 ReLU 好?
→ 平滑梯度流 + 自适应门控MoE 的负载均衡问题怎么解决?
→ Auxiliary Loss 强制均匀分配