Skip to content

推理系统与量化部署

本章目标:掌握生产环境的吞吐优化与模型压缩


6.1 KV Cache 工程

6.1.1 KV Cache 显存计算

python
def calculate_kv_cache_memory(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    batch_size: int,
    max_seq_len: int,
    bytes_per_param: int = 2  # BF16 = 2 bytes
) -> str:
    """
    计算 KV Cache 显存占用
    
    每个 token 需要存储:
    - K: num_kv_heads × head_dim
    - V: num_kv_heads × head_dim
    - 共 2 × num_kv_heads × head_dim 个参数
    
    总显存 = 2 × batch × seq × num_kv_heads × head_dim × bytes
    """
    # 单 token 的 KV Cache 大小(bytes)
    per_token = 2 * num_kv_heads * head_dim * bytes_per_param
    
    # 总显存
    total_bytes = per_token * batch_size * max_seq_len
    
    # 转换为人类可读格式
    if total_bytes > 1e9:
        return f"{total_bytes / 1e9:.2f} GB"
    elif total_bytes > 1e6:
        return f"{total_bytes / 1e6:.2f} MB"
    else:
        return f"{total_bytes / 1e3:.2f} KB"


# LLaMA-2 7B 示例
print(calculate_kv_cache_memory(
    num_layers=32,
    num_kv_heads=32,  # MHA
    head_dim=128,
    batch_size=1,
    max_seq_len=4096
))
# 输出: 128.00 MB

# LLaMA-2 34B (GQA, num_kv_heads=8)
print(calculate_kv_cache_memory(
    num_layers=48,
    num_kv_heads=8,
    head_dim=128,
    batch_size=1,
    max_seq_len=4096
))
# 输出: 48.00 MB (减少了 4 倍!)

6.1.2 KV Cache 管理

python
class KVCache:
    """
    简单的 KV Cache 实现
    """
    
    def __init__(self, max_batch_size, max_seq_len, num_kv_heads, head_dim):
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        
        # 预分配 GPU 显存
        self.k_cache = torch.zeros(
            max_batch_size, max_seq_len, num_kv_heads, head_dim,
            dtype=torch.bfloat16, device='cuda'
        )
        self.v_cache = torch.zeros(
            max_batch_size, max_seq_len, num_kv_heads, head_dim,
            dtype=torch.bfloat16, device='cuda'
        )
        
        # 当前序列长度
        self.seq_lengths = [0] * max_batch_size
    
    def update(self, batch_idx, seq_len, k, v):
        """
        更新指定 batch 位置的 KV
        """
        self.k_cache[batch_idx, seq_len] = k
        self.v_cache[batch_idx, seq_len] = v
        self.seq_lengths[batch_idx] = seq_len + 1
    
    def get(self, batch_idx):
        """
        获取指定 batch 的完整 KV
        """
        seq_len = self.seq_lengths[batch_idx]
        return (
            self.k_cache[batch_idx, :seq_len],
            self.v_cache[batch_idx, :seq_len]
        )

6.1.3 Prefix Caching

问题:多个请求可能有相同的前缀(如 system prompt),重复计算浪费

python
class RadixTree:
    """
    前缀复用:使用 Radix Tree 管理 KV Cache
    
    相同前缀的请求可以复用已计算的 KV
    """
    
    def __init__(self):
        self.root = {'children': {}, 'kv_cache': None, 'ref_count': 0}
    
    def insert(self, tokens, kv_cache):
        """
        插入新序列
        """
        node = self.root
        node['ref_count'] += 1
        
        for token in tokens:
            if token not in node['children']:
                node['children'][token] = {
                    'children': {}, 
                    'kv_cache': None, 
                    'ref_count': 0
                }
            node = node['children'][token]
            node['ref_count'] += 1
        
        node['kv_cache'] = kv_cache
    
    def lookup(self, tokens):
        """
        查找可复用的前缀
        返回: (最长匹配长度, kv_cache)
        """
        node = self.root
        match_len = 0
        
        for i, token in enumerate(tokens):
            if token not in node['children']:
                break
            node = node['children'][token]
            if node['kv_cache'] is not None:
                match_len = i + 1
        
        return match_len, node['kv_cache']
    
    def evict_if_needed(self):
        """
        LRU 驱逐低引用计数的节点
        """
        # 简化实现:驱逐 ref_count == 0 的节点
        def _cleanup(node):
            to_delete = []
            for token, child in node['children'].items():
                if child['ref_count'] == 0 and child['kv_cache'] is not None:
                    to_delete.append(token)
                else:
                    _cleanup(child)
            
            for token in to_delete:
                del node['children'][token]
        
        _cleanup(self.root)

6.2 解码优化策略

6.2.1 Speculative Decoding

核心思想:用小模型"猜"多个 token,大模型验证

传统 Decoding:
Step 1: 大模型生成 1 个 token
Step 2: 大模型生成 1 个 token
Step 3: ...

Speculative Decoding:
Step 1: 小模型连续生成 k=4 个 token: [a, b, c, d]
Step 2: 大模型并行验证这 4 个 token
        如果都正确 → 一次输出 4 个
        如果第 3 个错 → 只输出前 2 个,然后继续

速度提升: 通常 2-4x
python
def speculative_decoding(
    draft_model,  # 小模型(8B)
    target_model,  # 大模型(70B)
    prompt,
    k=4,  # draft 一次猜几个
    temperature=1.0
):
    """
    投机解码
    """
    # 1. Draft 模型生成 k 个 token
    draft_tokens = []
    current = prompt
    
    for _ in range(k):
        logits = draft_model(current)
        next_token = sample(logits, temperature)
        draft_tokens.append(next_token)
        current = current + [next_token]
    
    # 2. Target 模型并行验证
    # 把 prompt + draft tokens 一起输入 target
    target_logits = target_model(prompt + draft_tokens)
    
    # 3. 逐个验证
    accepted = []
    for i, draft_tok in enumerate(draft_tokens):
        target_prob = F.softmax(target_logits[i], dim=-1)[draft_tok]
        draft_prob = F.softmax(draft_logits[i], dim=-1)[draft_tok]
        
        # 如果 target 认为 draft token 概率高,接受
        if target_prob >= draft_prob or random.random() < target_prob / draft_prob:
            accepted.append(draft_tok)
        else:
            # 拒绝,用 target 的分布采样
            new_tok = sample(target_logits[i], temperature)
            accepted.append(new_tok)
            break
    
    return accepted


# 选择 draft 模型的原则
"""
1. 参数量: target 的 1/10 ~ 1/5
   - target 70B → draft 7B 或 13B

2. 架构相似: 最好同家族
   - target LLaMA-70B → draft LLaMA-7B

3. 质量差距不能太大
   - 否则 rejection rate 太高
"""

6.2.2 Medusa / Lookahead Decoding

核心思想:并行生成多个候选序列,选最优

python
class MedusaDecoding:
    """
    Medusa: 同时预测多个位置的 token
    
    示意图:
    
    标准:  [token₁] → [token₂] → [token₃]

                          需要等 token₂
    
    Medusa:  [token₁] 
             /    \    \
            ↓     ↓     ↓
          [t₂¹] [t₂²] [t₂³]  ← 3 个候选
            |     |     |
            ↓     ↓     ↓
          [t₃¹] [t₃²] [t₃³]  ← 3 个候选
            |     |     |
            └─────┴─────┘

            选择最优路径继续
    """
    
    def __init__(self, model, num_heads=3):
        self.model = model
        self.num_heads = num_heads
        
        # 添加额外的预测头
        # 每个头预测未来第 N 个位置
        self.medusa_heads = nn.ModuleList([
            nn.Linear(model.hidden_dim, model.vocab_size)
            for _ in range(num_heads)
        ])
    
    def forward_with_medusa(self, hidden_states):
        """
        hidden_states: 当前层的隐藏状态
        返回: (主预测, medusa 预测列表)
        """
        # 主预测
        main_output = self.model.lm_head(hidden_states)
        
        # Medusa 预测
        medusa_outputs = []
        for head in self.medusa_heads:
            medusa_outputs.append(head(hidden_states))
        
        return main_output, medusa_outputs

6.2.3 Chunked Prefill

问题:长 prompt 的 prefill 阶段太慢,导致首个 token 延迟高

解决:把 prefill 分成多个 chunk

python
class ChunkedPrefill:
    """
    Chunked Prefill: 把长 prompt 分块处理
    
    优势:
    - 减少首 token 延迟
    - 让 prefill 和 decode 更好地 overlap
    - 提高吞吐
    """
    
    def __init__(self, model, chunk_size=512):
        self.model = model
        self.chunk_size = chunk_size
    
    def prefill(self, input_ids, max_new_tokens=512):
        """
        分块 prefill
        """
        if len(input_ids) <= self.chunk_size:
            # 短序列,直接处理
            return self.model.forward(input_ids)
        
        # 分块处理
        for i in range(0, len(input_ids), self.chunk_size):
            chunk = input_ids[i:i + self.chunk_size]
            
            # 处理当前 chunk
            logits = self.model.forward(chunk)
            
            # 只在第一个 chunk 后等待,完成后立即开始生成
            if i == 0:
                first_logits = logits
            
            # 后续 chunk 可以和 decode overlap
            # ...
        
        return first_logits

6.3 量化技术详解

6.3.1 对称 vs 非对称量化

对称量化:
- 零点 = 0
- 公式: x_q = round(x / scale)
- scale = max(|x|) / (2^(n-1) - 1)
- 适用: weights(分布对称或接近零)

非对称量化:
- 零点 ≠ 0
- 公式: x_q = round(x / scale) + zero_point
- 需要额外存储 zero_point
- 适用: activations(分布常偏斜)
python
def symmetric_quantize(tensor, num_bits=8):
    """
    对称量化
    """
    scale = tensor.abs().max() / (2**(num_bits - 1) - 1)
    quantized = torch.round(tensor / scale).clamp(-2**(num_bits-1), 2**(num_bits-1)-1)
    return quantized.to(torch.int8), scale


def asymmetric_quantize(tensor, num_bits=8):
    """
    非对称量化
    """
    min_val = tensor.min()
    max_val = tensor.max()
    
    scale = (max_val - min_val) / (2**num_bits - 1)
    zero_point = torch.round(-min_val / scale).clamp(0, 2**num_bits - 1)
    
    quantized = torch.round(tensor / scale + zero_point).clamp(0, 2**num_bits - 1)
    return quantized.to(torch.int8), scale, zero_point


def dequantize(quantized, scale, zero_point=None):
    """
    反量化
    """
    if zero_point is None:
        # 对称
        return quantized.float() * scale
    else:
        # 非对称
        return (quantized.float() - zero_point) * scale

6.3.2 GPTQ 量化

核心思想:OBQ (Optimal Brain Quantization),逐层贪心量化

python
class GPTQQuantizer:
    """
    GPTQ: One-Shot Quantization with Optimal Brain Surgeon
    
    核心思想:
    - 逐层处理
    - 对每层,用 Hessian 矩阵选择最不影响重建的权重量化
    """
    
    def __init__(self, model, bits=4, block_size=128):
        self.model = model
        self.bits = bits
        self.block_size = block_size
    
    def quantize_layer(self, layer):
        """
        GPTQ 量化单层
        """
        # 1. 获取层的权重
        weight = layer.weight.data.clone()
        original_shape = weight.shape
        num_groups = weight.numel() // self.block_size
        
        # 2. 准备量化误差补偿
        weight_flatten = weight.flatten()
        quant_weight = weight_flatten.clone()
        error = torch.zeros_like(weight_flatten)
        
        # 3. 逐块量化
        for i in range(num_groups):
            # 取当前块
            start = i * self.block_size
            end = start + self.block_size
            block = weight_flatten[start:end]
            
            # 计算 Hessian 近似(用于评估量化影响)
            # 简化:用单位矩阵
            hessian_inv = torch.eye(self.block_size)
            
            # 4. 贪心选择量化权重
            for j in range(self.block_size):
                # 计算每个权重的量化误差
                if weight_flatten[start + j].abs() < torch.abs(quant_weight[start:end] - block).mean():
                    # 保持原值
                    pass
                else:
                    # 量化
                    scale = block.abs().max() / (2**(self.bits - 1) - 1)
                    quant_weight[start + j] = torch.round(block[j] / scale) * scale
            
            # 5. 计算误差并累积
            error[start:end] = block - quant_weight[start:end]
            weight_flatten[start:end] = quant_weight[start:end]
        
        # 6. 返回量化后的权重
        return weight_flatten.reshape(original_shape)
    
    def quantize(self):
        """
        量化整个模型
        """
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                module.weight.data = self.quantize_layer(module)

6.3.3 AWQ 量化

核心思想:Activation-Aware Weight Quantization,保留重要权重

python
class AWQQuantizer:
    """
    AWQ: Activation-Aware Weight Quantization
    
    核心思想:
    - 不是所有权重都同等重要
    - 保留 1% 的显著权重(salient weights)为高精度
    - 只量化剩余 99%
    
    显著权重判断:看激活值大的位置(通常是对输出影响大的)
    """
    
    def __init__(self, model, bits=4, percentile=0.99):
        self.model = model
        self.bits = bits
        self.percentile = percentile
    
    def find_salient_weights(self, weight, activations):
        """
        找出显著权重
        方法:权重 × 激活值的绝对值
        """
        # importance = |W| × |X|
        importance = weight.abs() * activations.abs().mean(dim=0)
        
        # 找出 top 1% 的位置
        threshold = torch.quantile(importance.flatten(), self.percentile)
        salient_mask = importance > threshold
        
        return salient_mask
    
    def quantize_with_saliency(self, weight, salient_mask):
        """
        对非显著权重进行量化,显著权重保持高精度
        """
        # 分离显著和非显著权重
        weight_quant = weight.clone()
        weight_non_salient = weight.masked_fill(salient_mask, 0)
        
        # 量化非显著部分
        scale = weight_non_salient.abs().max() / (2**(self.bits - 1) - 1)
        weight_quant[~salient_mask] = torch.round(
            weight_non_salient[~salient_mask] / scale
        ) * scale
        
        return weight_quant

6.3.4 SmoothQuant

核心思想:把量化难度从 weights 转移到 activations

python
class SmoothQuant:
    """
    SmoothQuant: 减小量化难度从 weights 到 activations
    
    观察:
    - Weights 的 outliers → 量化误差大
    - 但 activations 通常更"正常"
    
    方法:平滑权重分布
    W_out = (W_in / s) @ diag(s)
    
    其中 s 由 activations 的 per-channel 尺度决定
    """
    
    def __init__(self, model, alpha=0.5):
        self.model = model
        self.alpha = alpha  # 平滑因子
    
    def smooth_linear_layer(self, layer, input_feat):
        """
        对 Linear 层做 SmoothQuant
        """
        # 1. 计算 per-channel 尺度
        # 输入激活的 per-channel 尺度
        channel_scales = input_feat.abs().mean(dim=0)  # (out_features,)
        
        # 2. 计算平滑因子 s
        # s = max(输入尺度) ^ alpha × max(权重尺度) ^ (1-alpha)
        weight_scales = layer.weight.abs().mean(dim=1)  # (out_features,)
        smooth_scale = (channel_scales ** self.alpha) * (weight_scales ** (1 - self.alpha))
        
        # 3. 应用平滑
        inv_scale = 1.0 / smooth_scale.clamp(min=1e-5)
        layer.weight.data = layer.weight.data * inv_scale.unsqueeze(1)
        
        return inv_scale
    
    def quantize(self, model, calibration_data):
        """
        完整的 SmoothQuant 流程
        """
        # 1. 获取激活统计
        input_feat = {}
        def hook_fn(module, input, output):
            input_feat['x'] = input[0].detach()
        
        hooks = []
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear) and 'lm_head' not in name:
                hooks.append(module.register_forward_hook(hook_fn))
        
        with torch.no_grad():
            for batch in calibration_data:
                model(batch)
        
        for h in hooks:
            h.remove()
        
        # 2. 应用平滑并量化
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                self.smooth_linear_layer(module, input_feat.get(name, None))
                # 然后量化...

6.3.5 FP8 量化

python
"""
FP8: 8-bit Floating Point

两种格式:
- E4M3: 1 sign + 4 exponent + 3 mantissa (精度高,范围小)
- E5M2: 1 sign + 5 exponent + 2 mantissa (范围大,精度低)

适用范围:
- E4M3: weights, activations
- E5M2: gradients (需要更大范围)
"""

# PyTorch FP8 支持 (H100+)
def fp8_quantize(tensor, format='E4M3'):
    if format == 'E4M3':
        # E4M3: 指数范围 [-6, 15]
        # 需要先把 tensor 缩放到 [-448, 448]
        scale = tensor.abs().max() / 448.0
        scaled = tensor / scale
        return torch.float8_e4m3fn(scaled), scale
    elif format == 'E5M2':
        # E5M2: 指数范围 [-15, 16]
        scale = tensor.abs().max() / 57344.0
        scaled = tensor / scale
        return torch.float8_e5m2(scaled), scale

6.4 推理框架对比

6.4.1 vLLM

python
# vLLM 核心特性

# 1. Paged Attention
# 内存碎片化 → 合并小blocks → 提高显存利用率

# 2. Continuous Batching
# 动态batch → 新请求可随时加入 → 提高吞吐

# 使用
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=256
)

outputs = llm.generate(["Hello, my name is", "I want to"], sampling_params)

for output in outputs:
    print(output.outputs[0].text)

vLLM 内部架构

┌─────────────────────────────────────────────────────────┐
│                      vLLM 调度器                          │
│                                                          │
│  Continuous Batching:                                    │
│  ┌─────┬─────┬─────┬─────┐                             │
│  │Req 1│Req 2│Req 3│Req 4│ ← 动态加入的新请求          │
│  └─────┴─────┴─────┴─────┘                             │
│                                                          │
│  Paged Attention:                                       │
│  KV Cache 被分成 blocks,新 token 可直接追加              │
│  ┌────┬────┬────┬────┐                                 │
│  │ B0 │ B1 │ B2 │ B3 │ ← 按需分配,不浪费             │
│  └────┴────┴────┴────┘                                 │
│                                                          │
│  比 HuggingFace 提升 10-30x 吞吐                        │
└─────────────────────────────────────────────────────────┘

6.4.2 TensorRT-LLM

python
# TensorRT-LLM 核心特性

# 1. 算子融合 (Kernel Fusion)
# 把多个小算子合并成一个大 CUDA kernel
# 减少显存访问,提高计算密度

# 2. Inflight Batching
# 请求级并行,不等整个 batch 完成

# 3. INT8/FP8 支持
# 量化优化,硬件加速

# 使用
# 需要先用 trtllm-build 构建引擎
# trtllm-build --model_dir=./llama --quantization=fp8 --output=./engine

from tensorrt_llm import LLM

llm = LLM(engine_dir="./engine")
outputs = llm.generate(["Hello, world!"])

6.4.3 SGLang

python
# SGLang: Structured Generation Language

# 核心特性:
# 1. RadixAttention: 前缀复用
# 2. 约束解码: 支持 JSON schema、正则等
# 3. 高效 beam search

from sglang import model_server, gen

# 装饰器方式
@model_server
def my_model(input_str):
    return gen("my_model", input_str, max_tokens=100)

# 批量生成
results = gen_batch("my_model", ["input1", "input2", "input3"])

# 约束解码
result = gen("my_model", "Extract: ", 
            json_schema={
                "name": str,
                "age": int
            })

6.4.4 框架对比

框架适用场景吞吐易用性量化支持
vLLM生产部署最高INT4/8/FP16/BF16
TensorRT-LLM生产部署最高INT4/8/FP8/BF16
SGLang复杂推理INT8
llama.cppCPU/边缘INT4/8
HF + bitsandbytes实验最好INT8

6.5 本章小结

┌─────────────────────────────────────────────────────────────┐
│                    推理优化技术全景                           │
├─────────────────────────────────────────────────────────────┤
│  KV Cache                                                 │
│  ├─ 显存计算  │  2×b×s×h×d×bytes                          │
│  ├─ GQA 减少  │  num_kv_heads 决定 KV 大小                 │
│  └─ 前缀复用  │  Radix Tree 管理相同前缀                   │
├─────────────────────────────────────────────────────────────┤
│  解码优化                                                  │
│  ├─ 投机解码  │  小模型猜,大模型验证,2-4x 加速            │
│  ├─ Medusa   │  多候选并行,选最优                         │
│  └─ Chunked  │  prefill 分块,减少首 token 延迟           │
├─────────────────────────────────────────────────────────────┤
│  量化技术                                                  │
│  ├─ GPTQ     │  OBQ 逐层量化,需校准数据                   │
│  ├─ AWQ      │  保护显著权重,只量化 99%                   │
│  ├─ SmoothQuant│  平滑权重分布,减小量化难度                 │
│  └─ FP8      │  H100+,E4M3/E5M2                         │
├─────────────────────────────────────────────────────────────┤
│  推理框架                                                  │
│  ├─ vLLM     │  PagedAttn + ContinuousBatching           │
│  ├─ TensorRT │  算子融合 + 引擎优化                        │
│  └─ SGLang   │  RadixAttn + 约束解码                      │
└─────────────────────────────────────────────────────────────┘

显存估算汇总

以 LLaMA-2 70B 为例:

┌─────────────────────────────────────────────────────────────┐
│                    显存占用对比                              │
├────────────────────┬──────────┬──────────┬─────────────────┤
│       精度         │  模型    │  KV Cache│  总计 (bs=1,4K) │
├────────────────────┼──────────┼──────────┼─────────────────┤
│  FP16              │  140 GB  │  112 GB  │  252 GB         │
│  INT4 (GPTQ)       │  35 GB   │  112 GB  │  147 GB         │
│  INT4 + GQA (8 kv) │  35 GB   │  28 GB   │  63 GB          │
│  INT4 + GQA + QLoRA│  35 GB   │  28 GB   │  ~40 GB (推理)  │
└────────────────────┴──────────┴──────────┴─────────────────┘

单卡 A100 (80GB): FP16 放不下 70B,但 INT4 + GQA 可以!

面试高频问题

  1. KV Cache 显存怎么算?
    → 2 × batch × seq_len × num_kv_heads × head_dim × bytes_per_param

  2. 投机解码加速原理?
    → 小模型并行猜 k 个 token,大模型一次验证,通常 2-4x 加速

  3. GPTQ vs AWQ 区别?
    → GPTQ 逐层贪心量化;AWQ 保留显著权重为高精度

  4. Paged Attention 是什么?
    → 把 KV Cache 分成小块管理,减少碎片,提高显存利用率

  5. vLLM vs TensorRT-LLM 怎么选? → vLLM 易用性好,适合快速上线;TRT-LLM 性能最优,但需要构建引擎

基于 MIT 许可发布