推理系统与量化部署
本章目标:掌握生产环境的吞吐优化与模型压缩
6.1 KV Cache 工程
6.1.1 KV Cache 显存计算
python
def calculate_kv_cache_memory(
num_layers: int,
num_kv_heads: int,
head_dim: int,
batch_size: int,
max_seq_len: int,
bytes_per_param: int = 2 # BF16 = 2 bytes
) -> str:
"""
计算 KV Cache 显存占用
每个 token 需要存储:
- K: num_kv_heads × head_dim
- V: num_kv_heads × head_dim
- 共 2 × num_kv_heads × head_dim 个参数
总显存 = 2 × batch × seq × num_kv_heads × head_dim × bytes
"""
# 单 token 的 KV Cache 大小(bytes)
per_token = 2 * num_kv_heads * head_dim * bytes_per_param
# 总显存
total_bytes = per_token * batch_size * max_seq_len
# 转换为人类可读格式
if total_bytes > 1e9:
return f"{total_bytes / 1e9:.2f} GB"
elif total_bytes > 1e6:
return f"{total_bytes / 1e6:.2f} MB"
else:
return f"{total_bytes / 1e3:.2f} KB"
# LLaMA-2 7B 示例
print(calculate_kv_cache_memory(
num_layers=32,
num_kv_heads=32, # MHA
head_dim=128,
batch_size=1,
max_seq_len=4096
))
# 输出: 128.00 MB
# LLaMA-2 34B (GQA, num_kv_heads=8)
print(calculate_kv_cache_memory(
num_layers=48,
num_kv_heads=8,
head_dim=128,
batch_size=1,
max_seq_len=4096
))
# 输出: 48.00 MB (减少了 4 倍!)6.1.2 KV Cache 管理
python
class KVCache:
"""
简单的 KV Cache 实现
"""
def __init__(self, max_batch_size, max_seq_len, num_kv_heads, head_dim):
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
# 预分配 GPU 显存
self.k_cache = torch.zeros(
max_batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.bfloat16, device='cuda'
)
self.v_cache = torch.zeros(
max_batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.bfloat16, device='cuda'
)
# 当前序列长度
self.seq_lengths = [0] * max_batch_size
def update(self, batch_idx, seq_len, k, v):
"""
更新指定 batch 位置的 KV
"""
self.k_cache[batch_idx, seq_len] = k
self.v_cache[batch_idx, seq_len] = v
self.seq_lengths[batch_idx] = seq_len + 1
def get(self, batch_idx):
"""
获取指定 batch 的完整 KV
"""
seq_len = self.seq_lengths[batch_idx]
return (
self.k_cache[batch_idx, :seq_len],
self.v_cache[batch_idx, :seq_len]
)6.1.3 Prefix Caching
问题:多个请求可能有相同的前缀(如 system prompt),重复计算浪费
python
class RadixTree:
"""
前缀复用:使用 Radix Tree 管理 KV Cache
相同前缀的请求可以复用已计算的 KV
"""
def __init__(self):
self.root = {'children': {}, 'kv_cache': None, 'ref_count': 0}
def insert(self, tokens, kv_cache):
"""
插入新序列
"""
node = self.root
node['ref_count'] += 1
for token in tokens:
if token not in node['children']:
node['children'][token] = {
'children': {},
'kv_cache': None,
'ref_count': 0
}
node = node['children'][token]
node['ref_count'] += 1
node['kv_cache'] = kv_cache
def lookup(self, tokens):
"""
查找可复用的前缀
返回: (最长匹配长度, kv_cache)
"""
node = self.root
match_len = 0
for i, token in enumerate(tokens):
if token not in node['children']:
break
node = node['children'][token]
if node['kv_cache'] is not None:
match_len = i + 1
return match_len, node['kv_cache']
def evict_if_needed(self):
"""
LRU 驱逐低引用计数的节点
"""
# 简化实现:驱逐 ref_count == 0 的节点
def _cleanup(node):
to_delete = []
for token, child in node['children'].items():
if child['ref_count'] == 0 and child['kv_cache'] is not None:
to_delete.append(token)
else:
_cleanup(child)
for token in to_delete:
del node['children'][token]
_cleanup(self.root)6.2 解码优化策略
6.2.1 Speculative Decoding
核心思想:用小模型"猜"多个 token,大模型验证
传统 Decoding:
Step 1: 大模型生成 1 个 token
Step 2: 大模型生成 1 个 token
Step 3: ...
Speculative Decoding:
Step 1: 小模型连续生成 k=4 个 token: [a, b, c, d]
Step 2: 大模型并行验证这 4 个 token
如果都正确 → 一次输出 4 个
如果第 3 个错 → 只输出前 2 个,然后继续
速度提升: 通常 2-4xpython
def speculative_decoding(
draft_model, # 小模型(8B)
target_model, # 大模型(70B)
prompt,
k=4, # draft 一次猜几个
temperature=1.0
):
"""
投机解码
"""
# 1. Draft 模型生成 k 个 token
draft_tokens = []
current = prompt
for _ in range(k):
logits = draft_model(current)
next_token = sample(logits, temperature)
draft_tokens.append(next_token)
current = current + [next_token]
# 2. Target 模型并行验证
# 把 prompt + draft tokens 一起输入 target
target_logits = target_model(prompt + draft_tokens)
# 3. 逐个验证
accepted = []
for i, draft_tok in enumerate(draft_tokens):
target_prob = F.softmax(target_logits[i], dim=-1)[draft_tok]
draft_prob = F.softmax(draft_logits[i], dim=-1)[draft_tok]
# 如果 target 认为 draft token 概率高,接受
if target_prob >= draft_prob or random.random() < target_prob / draft_prob:
accepted.append(draft_tok)
else:
# 拒绝,用 target 的分布采样
new_tok = sample(target_logits[i], temperature)
accepted.append(new_tok)
break
return accepted
# 选择 draft 模型的原则
"""
1. 参数量: target 的 1/10 ~ 1/5
- target 70B → draft 7B 或 13B
2. 架构相似: 最好同家族
- target LLaMA-70B → draft LLaMA-7B
3. 质量差距不能太大
- 否则 rejection rate 太高
"""6.2.2 Medusa / Lookahead Decoding
核心思想:并行生成多个候选序列,选最优
python
class MedusaDecoding:
"""
Medusa: 同时预测多个位置的 token
示意图:
标准: [token₁] → [token₂] → [token₃]
↑
需要等 token₂
Medusa: [token₁]
/ \ \
↓ ↓ ↓
[t₂¹] [t₂²] [t₂³] ← 3 个候选
| | |
↓ ↓ ↓
[t₃¹] [t₃²] [t₃³] ← 3 个候选
| | |
└─────┴─────┘
↓
选择最优路径继续
"""
def __init__(self, model, num_heads=3):
self.model = model
self.num_heads = num_heads
# 添加额外的预测头
# 每个头预测未来第 N 个位置
self.medusa_heads = nn.ModuleList([
nn.Linear(model.hidden_dim, model.vocab_size)
for _ in range(num_heads)
])
def forward_with_medusa(self, hidden_states):
"""
hidden_states: 当前层的隐藏状态
返回: (主预测, medusa 预测列表)
"""
# 主预测
main_output = self.model.lm_head(hidden_states)
# Medusa 预测
medusa_outputs = []
for head in self.medusa_heads:
medusa_outputs.append(head(hidden_states))
return main_output, medusa_outputs6.2.3 Chunked Prefill
问题:长 prompt 的 prefill 阶段太慢,导致首个 token 延迟高
解决:把 prefill 分成多个 chunk
python
class ChunkedPrefill:
"""
Chunked Prefill: 把长 prompt 分块处理
优势:
- 减少首 token 延迟
- 让 prefill 和 decode 更好地 overlap
- 提高吞吐
"""
def __init__(self, model, chunk_size=512):
self.model = model
self.chunk_size = chunk_size
def prefill(self, input_ids, max_new_tokens=512):
"""
分块 prefill
"""
if len(input_ids) <= self.chunk_size:
# 短序列,直接处理
return self.model.forward(input_ids)
# 分块处理
for i in range(0, len(input_ids), self.chunk_size):
chunk = input_ids[i:i + self.chunk_size]
# 处理当前 chunk
logits = self.model.forward(chunk)
# 只在第一个 chunk 后等待,完成后立即开始生成
if i == 0:
first_logits = logits
# 后续 chunk 可以和 decode overlap
# ...
return first_logits6.3 量化技术详解
6.3.1 对称 vs 非对称量化
对称量化:
- 零点 = 0
- 公式: x_q = round(x / scale)
- scale = max(|x|) / (2^(n-1) - 1)
- 适用: weights(分布对称或接近零)
非对称量化:
- 零点 ≠ 0
- 公式: x_q = round(x / scale) + zero_point
- 需要额外存储 zero_point
- 适用: activations(分布常偏斜)python
def symmetric_quantize(tensor, num_bits=8):
"""
对称量化
"""
scale = tensor.abs().max() / (2**(num_bits - 1) - 1)
quantized = torch.round(tensor / scale).clamp(-2**(num_bits-1), 2**(num_bits-1)-1)
return quantized.to(torch.int8), scale
def asymmetric_quantize(tensor, num_bits=8):
"""
非对称量化
"""
min_val = tensor.min()
max_val = tensor.max()
scale = (max_val - min_val) / (2**num_bits - 1)
zero_point = torch.round(-min_val / scale).clamp(0, 2**num_bits - 1)
quantized = torch.round(tensor / scale + zero_point).clamp(0, 2**num_bits - 1)
return quantized.to(torch.int8), scale, zero_point
def dequantize(quantized, scale, zero_point=None):
"""
反量化
"""
if zero_point is None:
# 对称
return quantized.float() * scale
else:
# 非对称
return (quantized.float() - zero_point) * scale6.3.2 GPTQ 量化
核心思想:OBQ (Optimal Brain Quantization),逐层贪心量化
python
class GPTQQuantizer:
"""
GPTQ: One-Shot Quantization with Optimal Brain Surgeon
核心思想:
- 逐层处理
- 对每层,用 Hessian 矩阵选择最不影响重建的权重量化
"""
def __init__(self, model, bits=4, block_size=128):
self.model = model
self.bits = bits
self.block_size = block_size
def quantize_layer(self, layer):
"""
GPTQ 量化单层
"""
# 1. 获取层的权重
weight = layer.weight.data.clone()
original_shape = weight.shape
num_groups = weight.numel() // self.block_size
# 2. 准备量化误差补偿
weight_flatten = weight.flatten()
quant_weight = weight_flatten.clone()
error = torch.zeros_like(weight_flatten)
# 3. 逐块量化
for i in range(num_groups):
# 取当前块
start = i * self.block_size
end = start + self.block_size
block = weight_flatten[start:end]
# 计算 Hessian 近似(用于评估量化影响)
# 简化:用单位矩阵
hessian_inv = torch.eye(self.block_size)
# 4. 贪心选择量化权重
for j in range(self.block_size):
# 计算每个权重的量化误差
if weight_flatten[start + j].abs() < torch.abs(quant_weight[start:end] - block).mean():
# 保持原值
pass
else:
# 量化
scale = block.abs().max() / (2**(self.bits - 1) - 1)
quant_weight[start + j] = torch.round(block[j] / scale) * scale
# 5. 计算误差并累积
error[start:end] = block - quant_weight[start:end]
weight_flatten[start:end] = quant_weight[start:end]
# 6. 返回量化后的权重
return weight_flatten.reshape(original_shape)
def quantize(self):
"""
量化整个模型
"""
for name, module in self.model.named_modules():
if isinstance(module, nn.Linear):
module.weight.data = self.quantize_layer(module)6.3.3 AWQ 量化
核心思想:Activation-Aware Weight Quantization,保留重要权重
python
class AWQQuantizer:
"""
AWQ: Activation-Aware Weight Quantization
核心思想:
- 不是所有权重都同等重要
- 保留 1% 的显著权重(salient weights)为高精度
- 只量化剩余 99%
显著权重判断:看激活值大的位置(通常是对输出影响大的)
"""
def __init__(self, model, bits=4, percentile=0.99):
self.model = model
self.bits = bits
self.percentile = percentile
def find_salient_weights(self, weight, activations):
"""
找出显著权重
方法:权重 × 激活值的绝对值
"""
# importance = |W| × |X|
importance = weight.abs() * activations.abs().mean(dim=0)
# 找出 top 1% 的位置
threshold = torch.quantile(importance.flatten(), self.percentile)
salient_mask = importance > threshold
return salient_mask
def quantize_with_saliency(self, weight, salient_mask):
"""
对非显著权重进行量化,显著权重保持高精度
"""
# 分离显著和非显著权重
weight_quant = weight.clone()
weight_non_salient = weight.masked_fill(salient_mask, 0)
# 量化非显著部分
scale = weight_non_salient.abs().max() / (2**(self.bits - 1) - 1)
weight_quant[~salient_mask] = torch.round(
weight_non_salient[~salient_mask] / scale
) * scale
return weight_quant6.3.4 SmoothQuant
核心思想:把量化难度从 weights 转移到 activations
python
class SmoothQuant:
"""
SmoothQuant: 减小量化难度从 weights 到 activations
观察:
- Weights 的 outliers → 量化误差大
- 但 activations 通常更"正常"
方法:平滑权重分布
W_out = (W_in / s) @ diag(s)
其中 s 由 activations 的 per-channel 尺度决定
"""
def __init__(self, model, alpha=0.5):
self.model = model
self.alpha = alpha # 平滑因子
def smooth_linear_layer(self, layer, input_feat):
"""
对 Linear 层做 SmoothQuant
"""
# 1. 计算 per-channel 尺度
# 输入激活的 per-channel 尺度
channel_scales = input_feat.abs().mean(dim=0) # (out_features,)
# 2. 计算平滑因子 s
# s = max(输入尺度) ^ alpha × max(权重尺度) ^ (1-alpha)
weight_scales = layer.weight.abs().mean(dim=1) # (out_features,)
smooth_scale = (channel_scales ** self.alpha) * (weight_scales ** (1 - self.alpha))
# 3. 应用平滑
inv_scale = 1.0 / smooth_scale.clamp(min=1e-5)
layer.weight.data = layer.weight.data * inv_scale.unsqueeze(1)
return inv_scale
def quantize(self, model, calibration_data):
"""
完整的 SmoothQuant 流程
"""
# 1. 获取激活统计
input_feat = {}
def hook_fn(module, input, output):
input_feat['x'] = input[0].detach()
hooks = []
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and 'lm_head' not in name:
hooks.append(module.register_forward_hook(hook_fn))
with torch.no_grad():
for batch in calibration_data:
model(batch)
for h in hooks:
h.remove()
# 2. 应用平滑并量化
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
self.smooth_linear_layer(module, input_feat.get(name, None))
# 然后量化...6.3.5 FP8 量化
python
"""
FP8: 8-bit Floating Point
两种格式:
- E4M3: 1 sign + 4 exponent + 3 mantissa (精度高,范围小)
- E5M2: 1 sign + 5 exponent + 2 mantissa (范围大,精度低)
适用范围:
- E4M3: weights, activations
- E5M2: gradients (需要更大范围)
"""
# PyTorch FP8 支持 (H100+)
def fp8_quantize(tensor, format='E4M3'):
if format == 'E4M3':
# E4M3: 指数范围 [-6, 15]
# 需要先把 tensor 缩放到 [-448, 448]
scale = tensor.abs().max() / 448.0
scaled = tensor / scale
return torch.float8_e4m3fn(scaled), scale
elif format == 'E5M2':
# E5M2: 指数范围 [-15, 16]
scale = tensor.abs().max() / 57344.0
scaled = tensor / scale
return torch.float8_e5m2(scaled), scale6.4 推理框架对比
6.4.1 vLLM
python
# vLLM 核心特性
# 1. Paged Attention
# 内存碎片化 → 合并小blocks → 提高显存利用率
# 2. Continuous Batching
# 动态batch → 新请求可随时加入 → 提高吞吐
# 使用
from vllm import LLM, SamplingParams
llm = LLM(model="meta-llama/Llama-2-7b-hf")
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=256
)
outputs = llm.generate(["Hello, my name is", "I want to"], sampling_params)
for output in outputs:
print(output.outputs[0].text)vLLM 内部架构:
┌─────────────────────────────────────────────────────────┐
│ vLLM 调度器 │
│ │
│ Continuous Batching: │
│ ┌─────┬─────┬─────┬─────┐ │
│ │Req 1│Req 2│Req 3│Req 4│ ← 动态加入的新请求 │
│ └─────┴─────┴─────┴─────┘ │
│ │
│ Paged Attention: │
│ KV Cache 被分成 blocks,新 token 可直接追加 │
│ ┌────┬────┬────┬────┐ │
│ │ B0 │ B1 │ B2 │ B3 │ ← 按需分配,不浪费 │
│ └────┴────┴────┴────┘ │
│ │
│ 比 HuggingFace 提升 10-30x 吞吐 │
└─────────────────────────────────────────────────────────┘6.4.2 TensorRT-LLM
python
# TensorRT-LLM 核心特性
# 1. 算子融合 (Kernel Fusion)
# 把多个小算子合并成一个大 CUDA kernel
# 减少显存访问,提高计算密度
# 2. Inflight Batching
# 请求级并行,不等整个 batch 完成
# 3. INT8/FP8 支持
# 量化优化,硬件加速
# 使用
# 需要先用 trtllm-build 构建引擎
# trtllm-build --model_dir=./llama --quantization=fp8 --output=./engine
from tensorrt_llm import LLM
llm = LLM(engine_dir="./engine")
outputs = llm.generate(["Hello, world!"])6.4.3 SGLang
python
# SGLang: Structured Generation Language
# 核心特性:
# 1. RadixAttention: 前缀复用
# 2. 约束解码: 支持 JSON schema、正则等
# 3. 高效 beam search
from sglang import model_server, gen
# 装饰器方式
@model_server
def my_model(input_str):
return gen("my_model", input_str, max_tokens=100)
# 批量生成
results = gen_batch("my_model", ["input1", "input2", "input3"])
# 约束解码
result = gen("my_model", "Extract: ",
json_schema={
"name": str,
"age": int
})6.4.4 框架对比
| 框架 | 适用场景 | 吞吐 | 易用性 | 量化支持 |
|---|---|---|---|---|
| vLLM | 生产部署 | 最高 | 好 | INT4/8/FP16/BF16 |
| TensorRT-LLM | 生产部署 | 最高 | 中 | INT4/8/FP8/BF16 |
| SGLang | 复杂推理 | 高 | 好 | INT8 |
| llama.cpp | CPU/边缘 | 中 | 好 | INT4/8 |
| HF + bitsandbytes | 实验 | 低 | 最好 | INT8 |
6.5 本章小结
┌─────────────────────────────────────────────────────────────┐
│ 推理优化技术全景 │
├─────────────────────────────────────────────────────────────┤
│ KV Cache │
│ ├─ 显存计算 │ 2×b×s×h×d×bytes │
│ ├─ GQA 减少 │ num_kv_heads 决定 KV 大小 │
│ └─ 前缀复用 │ Radix Tree 管理相同前缀 │
├─────────────────────────────────────────────────────────────┤
│ 解码优化 │
│ ├─ 投机解码 │ 小模型猜,大模型验证,2-4x 加速 │
│ ├─ Medusa │ 多候选并行,选最优 │
│ └─ Chunked │ prefill 分块,减少首 token 延迟 │
├─────────────────────────────────────────────────────────────┤
│ 量化技术 │
│ ├─ GPTQ │ OBQ 逐层量化,需校准数据 │
│ ├─ AWQ │ 保护显著权重,只量化 99% │
│ ├─ SmoothQuant│ 平滑权重分布,减小量化难度 │
│ └─ FP8 │ H100+,E4M3/E5M2 │
├─────────────────────────────────────────────────────────────┤
│ 推理框架 │
│ ├─ vLLM │ PagedAttn + ContinuousBatching │
│ ├─ TensorRT │ 算子融合 + 引擎优化 │
│ └─ SGLang │ RadixAttn + 约束解码 │
└─────────────────────────────────────────────────────────────┘显存估算汇总
以 LLaMA-2 70B 为例:
┌─────────────────────────────────────────────────────────────┐
│ 显存占用对比 │
├────────────────────┬──────────┬──────────┬─────────────────┤
│ 精度 │ 模型 │ KV Cache│ 总计 (bs=1,4K) │
├────────────────────┼──────────┼──────────┼─────────────────┤
│ FP16 │ 140 GB │ 112 GB │ 252 GB │
│ INT4 (GPTQ) │ 35 GB │ 112 GB │ 147 GB │
│ INT4 + GQA (8 kv) │ 35 GB │ 28 GB │ 63 GB │
│ INT4 + GQA + QLoRA│ 35 GB │ 28 GB │ ~40 GB (推理) │
└────────────────────┴──────────┴──────────┴─────────────────┘
单卡 A100 (80GB): FP16 放不下 70B,但 INT4 + GQA 可以!面试高频问题
KV Cache 显存怎么算?
→ 2 × batch × seq_len × num_kv_heads × head_dim × bytes_per_param投机解码加速原理?
→ 小模型并行猜 k 个 token,大模型一次验证,通常 2-4x 加速GPTQ vs AWQ 区别?
→ GPTQ 逐层贪心量化;AWQ 保留显著权重为高精度Paged Attention 是什么?
→ 把 KV Cache 分成小块管理,减少碎片,提高显存利用率vLLM vs TensorRT-LLM 怎么选? → vLLM 易用性好,适合快速上线;TRT-LLM 性能最优,但需要构建引擎