Skip to content

现代大模型架构基础

本章目标:理解为什么大语言模型长现在这个样子,掌握核心组件的数学本质


1.1 位置编码深度解析:RoPE

1.1.1 绝对位置编码的问题

传统 Transformer 使用绝对位置编码(Sinusoidal 或 Learnable),将位置信息加到 Token 嵌入上:

Token_Embed + Position_Embed → 输入

问题在哪?

  1. 无法表达相对位置:相对位置在 Attention 计算中很重要,但绝对编码很难隐式建模
  2. 训练外推差:训练最大 2048,推理想用 4096?效果崩
  3. 计算冗余:每个位置都要学习独立向量

1.1.2 RoPE 的核心思想

旋转位置编码(Rotary Position Embedding) 的核心是:不是把位置信息加到 embedding 上,而是通过旋转来编码位置

不在 embedding 空间操作,而是在 Query 和 Key 做旋转!

数学直觉

二维平面上的旋转:

[cos(θ)  -sin(θ)] [q_0]   = 旋转后的 q
[sin(θ)   cos(θ)] [q_1]

如果我们让 θ = m * θ_base,m 是位置
那么两个位置 m 和 n 的旋转角度差 = (m-n) * θ_base
旋转角度差只依赖于相对位置 (m-n)!

1.1.3 RoPE 公式推导

二维情况

对于 Query 向量 $\mathbf{q}$ 和 Key 向量 $\mathbf{k}$,它们的第 $(i, i+1)$ 个维度组成一个二维子空间。

RoPE 的做法是:把位置 $m$ 的 $\mathbf{q}$ 旋转 $m\theta$ 角度:

$$ \mathbf{q}'i = \begin{bmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \ \sin(m\theta_i) & \cos(m\theta_i) \end{bmatrix} \begin{bmatrix} q \ q_{2i+1} \end{bmatrix} $$

其中 $\theta_i = b^{-2i/d}$,$b$ 通常取 $10000$。

点积性质

旋转后的 q 和 k 做点积:

⟨R_m(q), R_n(k)⟩ = ⟨q, k⟩ 旋转到 (m-n) 角度

Attention score = ⟨R_m(q), R_n(k)⟩ 
                 = f(q, k, m-n)  ← 只依赖相对位置!

这意味着:RoPE 天然编码了相对位置信息!

1.1.4 多头 RoPE 实现

python
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    预计算旋转角度的复数形式
    dim: head_dim
    end: 最大序列长度
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end)
    freqs = torch.outer(t, freqs)  # (seq_len, dim//2)
    
    # 转为复数形式: e^(i*freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    x: (batch, seq_len, num_heads, head_dim)
    freqs_cis: (seq_len, head_dim//2) 复数
    """
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = x_complex * freqs_cis.unsqueeze(0).unsqueeze(2)
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=4096, theta=10000.0):
        super().__init__()
        self.freqs_cis = precompute_freqs_cis(dim, max_seq_len, theta)
    
    def forward(self, x):
        # x: (batch, seq_len, num_heads, head_dim)
        self.freqs_cis = self.freqs_cis.to(x.device)
        return apply_rotary_emb(x, self.freqs_cis)

1.1.5 RoPE 的远程衰减特性

RoPE 的一个神奇特性:位置越远,attention 分数衰减

远程衰减(Long-Range Decay):

Attention Score

    │╲
    │  ╲
    │    ╲
    │     ╲___  ← 逐渐衰减
    │        ╲____
    └──────────────────→ 相对距离

为什么有用?

  • 远处的 token 贡献自然减小
  • 类似自然语言中"越远越不相关"的直觉
  • 有助于处理超长上下文

1.1.6 与其他位置编码对比

特性RoPE (Rotary)ALiBiSinusoidalLearnable
相对位置编码✅ 天然✅ 线性
外推能力✅ 好✅ 好❌ 差❌ 差
计算效率✅ 高✅ 高❌ 需查表
显存开销✅ 小✅ 小❌ 大❌ 大
适用场景LLaMA, GLMBLOOM原始 Transformer短上下文
实现复杂度

1.2 注意力机制演进:MHA → GQA → MQA

1.2.1 标准 Multi-Head Attention (MHA)

每个 head 有独立的 Wq, Wk, Wv

Head_i: Q_i = X · Wq_i, K_i = X · Wk_i, V_i = X · Wv_i
Output_i = Attention(Q_i, K_i, V_i)
Final = Concat(Output_1, ..., Output_h) · Wo

显存消耗分析

对于每个 token,KV Cache 需要存储:

每个 head: K_i, V_i 各 (seq_len, head_dim)
总共 h 个 head

KV Cache 显存

KV_Cache = 2 × batch × seq_len × num_heads × head_dim × bytes_per_param

以 LLaMA-7B 为例:

  • num_heads = 32
  • head_dim = 128
  • BF16 = 2 bytes
单 token KV Cache = 2 × 32 × 128 × 2 = 16KB
65B 模型:num_heads = 64, head_dim = 128
单 token KV Cache = 2 × 64 × 128 × 2 = 32KB

推理时,KV Cache 是显存的主要瓶颈!

1.2.2 Multi-Query Attention (MQA)

核心思想:所有 head 共享同一份 K 和 V

Q_1, Q_2, ..., Q_h ← 独立(每个 head 自己)
K, V ← 所有 head 共享一份!
python
class MQA(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        # Q: 每个 head 独立
        self.Wq = nn.Linear(hidden_dim, num_heads * head_dim)
        # K, V: 所有 head 共享
        self.Wk = nn.Linear(hidden_dim, head_dim)  # 不是 num_heads * head_dim!
        self.Wv = nn.Linear(hidden_dim, head_dim)
        
    def forward(self, x):
        q = self.Wq(x).view(batch, seq, self.num_heads, head_dim)
        k = self.Wk(x).unsqueeze(1)  # (batch, 1, 1, head_dim)
        v = self.Wv(x).unsqueeze(1)
        
        # Q 扩展,K/V 广播
        q = q.transpose(1, 2)  # (batch, num_heads, seq, head_dim)
        k = k.expand(batch, self.num_heads, seq, head_dim)  # 广播
        v = v.expand(batch, self.num_heads, seq, head_dim)  # 广播

参数量变化

MHA:  K_params = num_heads × head_dim × embed_dim
MQA:  K_params = 1 × head_dim × embed_dim

LLaMA-7B:
- MHA: 32 × 128 × 4096 = 16,777,216
- MQA: 1 × 128 × 4096 = 524,288
→ 减少 32 倍!

1.2.3 Grouped-Query Attention (GQA)

核心思想:介于 MHA 和 MQA 之间,$g$ 个 head 共享一组 K/V

GQA (g=4, h=32):
- Q: 32 个独立的 head
- K, V: 分成 8 组,每组 4 个 head 共享一组

LLaMA 2 用的就是 GQA

模型num_headsnum_kv_heads每组 head 数
LLaMA 2-7B32321 (MHA)
LLaMA 2-13B40401 (MHA)
LLaMA 2-34B4886 (GQA)
LLaMA 3-8B3284 (GQA)
LLaMA 3-70B6488 (GQA)

1.2.4 GQA/MQA 的实现细节

python
def repeat_kv(x: torch.Tensor, num_repeats: int) -> torch.Tensor:
    """
    将 (batch, num_kv_heads, seq, head_dim)
    扩展为 (batch, num_heads, seq, head_dim)
    """
    batch, num_kv, seq, head_dim = x.shape
    if num_repeats == 1:
        return x
    
    # (batch, num_kv, seq, 1, head_dim) → (batch, num_kv, seq, num_repeats, head_dim)
    # → (batch, num_kv * num_repeats, seq, head_dim)
    x = x[:, :, None, :, :].expand(batch, num_kv, num_repeats, seq, head_dim)
    return x.reshape(batch, num_kv * num_repeats, seq, head_dim)


class GQAAttention(nn.Module):
    def __init__(self, num_heads, num_kv_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_repeats = num_heads // num_kv_heads
        
        self.Wq = nn.Linear(hidden, num_heads * head_dim)
        self.Wk = nn.Linear(hidden, num_kv_heads * head_dim)
        self.Wv = nn.Linear(hidden, num_kv_heads * head_dim)
    
    def forward(self, x):
        b, seq_len, _ = x.shape
        
        q = self.Wq(x).view(b, seq_len, self.num_heads, self.head_dim)
        k = self.Wk(x).view(b, seq_len, self.num_kv_heads, self.head_dim)
        v = self.Wv(x).view(b, seq_len, self.num_kv_heads, self.head_dim)
        
        # 扩展 K, V 到 num_heads 维度
        k = repeat_kv(k.transpose(1, 2), self.num_repeats)  # (b, h, seq, d)
        v = repeat_kv(v.transpose(1, 2), self.num_repeats)
        
        # 继续 attention 计算...

1.2.5 显存 vs 质量权衡

                    显存占用

                      │          MHA (最大)
                      │     GQA
                      │         MQA
                      │              (最小)
                      │__________________→ 质量
                              ←―――――→
                             GQA 在质量和效率间平衡

经验法则

  • num_kv_heads = num_heads / 4num_heads / 8 是比较好的平衡点
  • 减少到 1 会损失性能(LLaMA 2 34B 用 8 而不是 1)
  • GQA 相比 MHA,KV Cache 减少 4-8 倍

1.3 前馈网络设计:SwiGLU

1.3.1 标准 FFN

经典 FFN(两层全连接):

FFN(x) = max(0, xW₁)W₂ + b

等价于:
FFN(x) = GeLU(xW₁)W₂

问题:ReLU 会"kill"一半的神经元(输出为0),对于深层网络可能导致信息丢失。

1.3.2 SwiGLU 的 Gated 机制

Swish 函数

swish(x) = x · sigmoid(x) = x / (1 + e^(-x))

SwiGLU(Swish-Gated Linear Unit):

SwiGLU(x) = Swish₁(xW₁) ⊗ (xV)
          = (xW₁ · sigmoid(xW₁)) ⊗ (xV)
python
class SwiGLU(nn.Module):
    def __init__(self, hidden_dim, intermediate_size):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, intermediate_size, bias=False)  # 额外的门控
        
    def forward(self, x):
        # SwiGLU: SiLU(xW₁) ⊗ (xW₃)
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

1.3.3 为什么 SwiGLU 更好

ReLU:     f(x) = max(0, x)     → 梯度要么是 0 要么是 1
Swish:    f(x) = x·sigmoid(x)  → 平滑的梯度流

           ReLU        Swish
            │            │
      ┌─────┴────┐  ┌────┴────┐
      │          │  │  ╱╲    │
      │          │  │ ╱  ╲   │
─────┴──────────┘──┴────────────
 0    x            0    x

优势

  1. 平滑梯度:Swish 在 0 附近有非零梯度,避免神经元"死亡"
  2. 自适应门控:门控值由输入本身决定,不是固定的
  3. 信息流动:两层变三层,但表达能力更强

1.3.4 与其他激活函数对比

激活函数公式特点使用场景
ReLUmax(0, x)简单,但有 dead relu经典 CNN
GELUx·Φ(x)平滑,Transformer 默认BERT
Swishx·sigmoid(x)自门控,平滑炼丹经验
SwiGLUSwish₁(xW₁)⊗xVSwish + 门控LLaMA, GLM
GeGLUGeLU(xW₁)⊗xVGELU + 门控-

1.4 归一化技术:RMSNorm 与 DeepNorm

1.4.1 LayerNorm 的问题

标准 LayerNorm:

python
def layer_norm(x, gamma, beta, eps=1e-6):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_norm = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_norm + beta

计算开销

  • 需要计算均值 mean 和方差 var
  • 两个 reduce 操作(求和、求平方和)
  • 对于超长序列,这是不小的开销

1.4.2 RMSNorm:只算 RMS

核心观察:LayerNorm 的 mean 项其实没那么重要!

RMSNorm(Root Mean Square Norm):

$$ \overline{a}_i = \frac{a_i}{\text{RMS}(\mathbf{a})} * \gamma_i, \quad \text{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_i a_i^2} $$

python
def rms_norm(x, gamma, eps=1e-6):
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
    return (x / rms) * gamma

省去了 mean 的计算!

1.4.3 RMSNorm vs LayerNorm 对比

LayerNorm:  需要计算 mean + var
RMSNorm:    只计算 RMS (var 的平方根)

速度提升: ~10-30%(取决于序列长度)

1.4.4 DeepNorm:深层网络稳定性

对于极深(如 100+ 层)的 Transformer,标准 LayerNorm 可能不够稳定。

DeepNorm(微软 Phi 模型、GLM 系列使用):

$$ \text{DeepNorm}(x) = \alpha \cdot \text{LayerNorm}(x + f(x)) $$

其中 $f(x)$ 是残差分支(如 Multi-Head Attention 或 FFN)。

稳定性分析

标准 Pre-LN: LayerNorm(x + SubLayer(x))
DeepNorm:    α · LayerNorm(x + SubLayer(x))

α 通常取 0.8~1.0(取决于模型深度)

为什么有效

  • 残差分支的输出被缩放
  • 防止深层网络的方差爆炸
  • 结合 Post-LN 的稳定性和 Pre-LN 的效率

1.5 MoE 架构专题

1.5.1 MoE 核心思想

Sparse MoE:不是所有 token 都激活所有参数,而是只激活"专家"的一部分。

标准 Dense:  每个 token 通过所有 FFN
MoE:         每个 token 只通过 Top-K 个专家
        Token 输入


        [Router] ──→ 决定激活哪些专家

      ┌─────┼─────┐
      ▼     ▼     ▼
    ┌───┐ ┌───┐ ┌───┐
    │ E │ │ E │ │ E │  专家 FFN
    │ 1 │ │ 2 │ │ 3 │
    └───┘ └───┘ └───┘
      │     │     │
      └─────┼─────┘


        输出加权

1.5.2 Top-K Router

python
class MoELayer(nn.Module):
    def __init__(self, num_experts, top_k):
        super().__init__()
        self.router = nn.Linear(hidden_dim, num_experts)
        self.top_k = top_k
        
    def forward(self, x):
        # x: (batch, seq, hidden)
        b, seq, h = x.shape
        
        # Router 计算每个专家的分数
        router_logits = self.router(x)  # (b, seq, num_experts)
        weights = F.softmax(router_logits, dim=-1)
        
        # Top-K 选择
        top_weights, top_idx = torch.topk(weights, self.top_k, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)  # 归一化
        
        # 对每个 token,只更新被选中的专家
        output = torch.zeros(b, seq, h, device=x.device, dtype=x.dtype)
        
        for i in range(self.top_k):
            expert = self.experts[top_idx[:, :, i]]  # 选择专家
            output += top_weights[:, :, i:i+1] * expert(x)
        
        return output

1.5.3 负载均衡 (Load Balancing)

问题:Router 可能总是选同样的专家,导致负载不均。

Aux Loss 解决方案

python
def load_balancing_loss(router_probs, top_k_ids, num_experts):
    """
    router_probs: 每个 expert 被选中的概率
    top_k_ids: 每个 token 选中的 top-k experts
    """
    # 1. 计算每个 expert 被选中的频率
    expert_counts = torch.zeros(num_experts, device=router_probs.device)
    for i in range(top_k_ids.shape[2]):
        expert_counts.scatter_add_(0, top_k_ids[:, :, i].flatten(), 
                                   torch.ones_like(top_k_ids[:, :, i].flatten().float()))
    expert_freq = expert_counts / (top_k_ids.shape[0] * top_k_ids.shape[1] * top_k_ids.shape[2])
    
    # 2. 计算 Router 概率均值
    router_mean = router_probs.mean(dim=[0, 1])
    
    # 3. 辅助损失 = sum(router_mean * expert_freq)
    # 最小化这个损失 = 均匀分配
    return num_experts * (router_mean * expert_freq).sum()

1.5.4 Expert Parallelism (EP)

Standard: 每个 GPU 有完整模型
Expert Parallel: 每个 GPU 只有部分专家

4 GPUs, 8 Experts:
GPU 0: Expert 0, 1
GPU 1: Expert 2, 3
GPU 2: Expert 4, 5
GPU 3: Expert 6, 7

Token routing 到不同 GPU → All-to-All 通信

1.5.5 DeepSeek-MoE 的细粒度设计

DeepSeek-V2 的创新

  1. 细粒度专家分割:不是 8 个大专家,而是拆成 64 个小专家
  2. 共享专家:从所有专家中选出 top-k,但有 2 个"共享专家"总是被激活
  3. 设备限制路由:减少跨设备通信
传统 MoE:    [E₁][E₂][E₃][E₄]     Top-2 = 激活 2 个
DeepSeek:    [e][e][e][e][e][e]...  Top-8 = 激活 8 个(更细粒度)

1.6 本章小结

┌─────────────────────────────────────────────────────────────┐
│                    Transformer 核心组件                        │
├─────────────────────────────────────────────────────────────┤
│  位置编码        │  RoPE  │  旋转编码,相对位置,外推好        │
├─────────────────────────────────────────────────────────────┤
│  注意力         │  GQA   │  减少 KV Cache,平衡质量与效率    │
├─────────────────────────────────────────────────────────────┤
│  FFN           │ SwiGLU │  门控机制,平滑梯度流             │
├─────────────────────────────────────────────────────────────┤
│  归一化         │ RMSNorm│  只算 RMS,省去 mean 计算         │
├─────────────────────────────────────────────────────────────┤
│  超深网络       │ DeepNorm│  残差缩放,稳定性保证             │
├─────────────────────────────────────────────────────────────┤
│  超大模型       │   MoE  │  稀疏激活,专家路由               │
└─────────────────────────────────────────────────────────────┘

关键公式速查

技术公式
RoPE 旋转$R_{m,d} = \begin{bmatrix} \cos(m\theta_d) & -\sin(m\theta_d) \ \sin(m\theta_d) & \cos(m\theta_d) \end{bmatrix}$
RMSNorm$\bar{a}_i = \frac{a_i}{\sqrt{\frac{1}{n}\sum_i a_i^2}} \cdot \gamma_i$
SwiGLU$\text{SwiGLU}(x) = \text{Swish}_1(xW_1) \otimes (xW_3)$
DeepNorm$\text{DeepNorm}(x) = \alpha \cdot \text{LN}(x + f(x))$

面试高频问题

  1. RoPE 为什么能外推?
    → 旋转矩阵编码相对位置,通过插值或 NTK 方法处理超长序列

  2. GQA 省多少显存?
    → KV Cache 减少 $num_heads / num_kv_heads$ 倍

  3. SwiGLU 为什么比 ReLU 好?
    → 平滑梯度流 + 自适应门控

  4. MoE 的负载均衡问题怎么解决?
    → Auxiliary Loss 强制均匀分配

基于 MIT 许可发布