长上下文专项技术
本章目标:突破 128K+ 上下文的技术栈(横跨训练与推理)
5.1 长文本微调技术
5.1.1 LongLoRA: Shift-Short Attention
核心思想:用近似 attention 降低微调时的计算量
标准 Self-Attention: O(n²) 显存和计算
S²-Attn: 近似 O(n) 显存,适合微调
核心:将 token 分组,在组内做 attention,减少通信量python
class ShiftedShortAttention(nn.Module):
"""
S²-Attn: 把序列分成若干组,组内做 attention
减少计算量的同时保持长距离依赖
示意图:
标准 Attention:
┌─────────────────────────┐
│ q₀ q₁ q₂ q₃ q₄ q₅ q₆ │ ← 所有 q 关注所有 k,v
│ ↖ ↑ ↑ ↑ ↑ ↗ │
└─────────────────────────┘
S²-Attn:
┌───┬───┬───┐
│ q₀│ q₁│ q₂│ ← 组内 attention
│ k₀│ k₁│ k₂│
├───┼───┼───┤
│ q₃│ q₄│ q₅│ ← 相邻组有 overlap
│ k₃│ k₄│ k₅│
└───┴───┴───┘
"""
def __init__(self, num_heads, head_dim, shift_ratio=0.25):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.shift_ratio = shift_ratio
def forward(self, q, k, v, group_size):
"""
q, k, v: (batch, seq_len, num_heads, head_dim)
group_size: 注意力窗口大小
"""
B, T, H, D = q.shape
# 1. 把序列分成组
num_groups = T // group_size
# 2. 对 Q,K,V 做 shift(在特征维度上偏移)
shift = int(group_size * self.shift_ratio)
# Pad 后 shift
q_shifted = F.pad(q, (0, 0, 0, 0, shift, shift))
k_shifted = F.pad(k, (0, 0, 0, 0, shift, shift))
v_shifted = F.pad(v, (0, 0, 0, 0, shift, shift))
# 3. Reshape 成组
q_groups = q_shifted.view(B, num_groups, group_size + 2*shift, H, D)
k_groups = k_shifted.view(B, num_groups, group_size + 2*shift, H, D)
v_groups = v_shifted.view(B, num_groups, group_size + 2*shift, H, D)
# 4. 组内做 attention
# 取中间的有效部分
q_valid = q_groups[:, :, shift:shift+group_size]
# 5. 计算 attention(简化版,实际实现更复杂)
attn_weights = torch.matmul(q_valid, k_groups.transpose(-2, -1)) / (D ** 0.5)
attn_weights = F.softmax(attn_weights, dim=-1)
output = torch.matmul(attn_weights, v_groups)
return output.view(B, T, H, D)5.1.2 LISA: Layerwise Importance Sampling
核心思想:不是所有层都同等重要,优先微调重要的层
python
class LISALoRA(nn.Module):
"""
LISA: Layerwise Importance Sampling for LoRA
核心思想:
- 深层的 attention 对齐更重要
- 浅层的 FFN 对知识保留更重要
- 按重要性分配 LoRA rank
"""
def __init__(self, model, base_rank=8, importance_scores=None):
super().__init__()
self.model = model
self.importance_scores = importance_scores or self._compute_importance()
# 根据重要性分配 rank
for i, layer in enumerate(model.transformer.layers):
importance = self.importance_scores[i]
rank = max(2, int(base_rank * importance))
# 深层多加 LoRA
if i > len(model.transformer.layers) // 2:
rank = int(rank * 1.5)
self._apply_lora_to_layer(layer, rank)
def _compute_importance(self):
"""
计算每层的重要性分数
方法:梯度范数、激活值方差、attention 熵等
"""
# 简化的重要性计算
num_layers = len(self.model.transformer.layers)
scores = []
for i in range(num_layers):
# 越深层越重要
scores.append((i + 1) / num_layers)
return scores5.1.3 长文本 SFT 数据构造
python
def construct_long_context_data(dataset, max_len=32768):
"""
长上下文 SFT 数据构造策略
"""
long_data = []
for doc in dataset:
# 方法1: 直接使用长文档
if len(doc) > max_len // 4:
# 截取中间部分(避免总是从开头开始)
start = len(doc) // 4
chunk = doc[start:start + max_len]
long_data.append(chunk)
# 方法2: 拼接多个短文档
else:
combined = []
current_len = 0
while current_len < max_len and len(combined) < len(dataset):
next_doc = dataset[(len(combined) % len(dataset))]
if current_len + len(next_doc) < max_len:
combined.append(next_doc)
current_len += len(next_doc)
long_data.append('\n\n'.join(combined))
# 方法3: 重复上下文模式
# QA 对,Q 在开头,A 在后面
# 强制模型在长上下文中找到答案
return long_data
# 数据配比建议
"""
YaRN 论文建议的长文本数据配比:
- 原始短文本数据: 70%
- 长文本数据 (8K+): 20%
- 合成长文本数据: 10%
"""5.2 推理阶段压缩:H2O / SnapKV / StreamingLLM
5.2.1 Heavy Hitter Oracle (H2O)
核心思想:不是所有 KV 都重要,识别并保留最重要的
问题:长上下文 KV Cache 太大
解决:驱逐不重要的 KV,保留重要的
H2O 假设:某些 token(如 "the", "is")贡献小,某些(如 "not", "but")贡献大python
class H2OKVCache:
"""
H2O: Heavy Hitter Oracle - 重要性驱动的 KV 驱逐
"""
def __init__(self, max_cache_size):
self.max_cache_size = max_cache_size
self.cache = {} # token_id -> importance_score
self.kv_cache = {}
def compute_importance(self, query, key, value, token_id):
"""
计算当前 token 的重要性
公式:importance = |Q·K| / sqrt(d)
即:当前 query 和这个 key 的注意力分数
"""
# Q @ K^T / sqrt(d)
score = torch.matmul(query, key.transpose(-2, -1)).abs().mean()
return score.item()
def update(self, token_id, key, value, query_for_next):
"""
更新 KV Cache,驱逐低重要性 token
"""
# 计算当前 token 的重要性
importance = self.compute_importance(
query_for_next, key, value, token_id
)
# 如果 cache 满了,驱逐最不重要的
if len(self.kv_cache) >= self.max_cache_size:
# 驱逐重要性最低的
min_importance_token = min(self.cache.items(), key=lambda x: x[1])
self.evict(min_importance_token[0])
# 存入新的
self.cache[token_id] = importance
self.kv_cache[token_id] = (key, value)
def evict(self, token_id):
"""驱逐低重要性 token"""
self.cache.pop(token_id, None)
self.kv_cache.pop(token_id, None)5.2.2 SnapKV
核心思想:先用观测窗口决定保留哪些 KV,再推理
python
class SnapKV:
"""
SnapKV: 通过观察窗口预判重要性
两阶段:
1. Prefill 阶段:用小窗口观察哪些 KV 重要
2. Decode 阶段:只保留重要的 KV
"""
def __init__(self, observe_window=64, max_cache=4096):
self.observe_window = observe_window
self.max_cache = max_cache
self.important_positions = []
def prefill_and_observe(self, hidden_states):
"""
第一阶段:用小窗口观察,统计每个位置的累计 attention
"""
# 只用前 observe_window 个 token 做 attention
prefix = hidden_states[:self.observe_window]
# 统计每个位置被注意到的次数
attention_counts = torch.zeros(hidden_states.shape[0])
for i in range(self.observe_window, hidden_states.shape[0]):
# 简化的 attention 计算
query = hidden_states[i]
# 计算与 prefix 的 attention
attn = F.softmax((query @ prefix.transpose(-2, -1)) / (hidden_states.shape[-1] ** 0.5))
# 累加到每个 prefix 位置
attention_counts[:self.observe_window] += attn[0]
# 找出最重要的位置
self.important_positions = torch.topk(
attention_counts,
min(self.max_cache, self.observe_window)
).indices.tolist()
def decode_with_cache(self, new_hidden):
"""
第二阶段:只解码重要位置
"""
# 只对重要位置做计算
important_kv = self.kv_cache[:, self.important_positions]
# 计算 attention
attn = torch.matmul(new_hidden, important_kv.transpose(-2, -1))
# ...5.2.3 StreamingLLM: Attention Sink
现象:LLM 对某些"锚点"token(如句首的 [CLS])有异常高的 attention
python
class StreamingLLM:
"""
StreamingLLM: 利用 Attention Sink 现象
观察:
- LLM 对第一个 token 有异常高的 attention(sink)
- 最近的几个 token 也有较高 attention(recent)
- 中间的 token attention 很低
所以:
- 只缓存 [sink tokens] + [recent tokens]
- 中间的直接丢弃
"""
def __init__(self, sink_tokens=4, recent_tokens=64):
self.sink_tokens = sink_tokens
self.recent_tokens = recent_tokens
# 固定的前缀(sink)
self.sink_cache = []
# 最近的 tokens
self.recent_cache = []
def update(self, token_id, kv):
"""
更新 rolling cache
"""
# Sink tokens 固定不变
if len(self.sink_cache) < self.sink_tokens:
self.sink_cache.append((token_id, kv))
# Recent tokens 用 buffer
self.recent_cache.append((token_id, kv))
if len(self.recent_cache) > self.recent_tokens:
self.recent_cache.pop(0)
def get_cached_kv(self):
"""
获取拼接后的缓存
"""
all_kv = []
for token_id, kv in self.sink_cache:
all_kv.append(kv)
for token_id, kv in self.recent_cache:
all_kv.append(kv)
return torch.cat(all_kv, dim=2) # 沿 seq 维度拼接可视化:
Attention Pattern:
Sink Recent
↓ ↓
Token: [S][S][S][S] ... ... ... [R][R][R][R]
↑______________________________↑
被忽略的中间部分
为什么有效?
- Sink: 提供"起点"信息,帮助理解上下文框架
- Recent: 最近的 tokens 包含即时信息
- 中间部分:信息已被"蒸馏"到 recent 里5.3 短训长推:SelfExtend / LM-Infinite
5.3.1 SelfExtend
核心思想:不需要重新训练,只需修改 attention mask
python
class SelfExtendAttention:
"""
SelfExtend: Grouped Attention,不改变模型权重
核心思想:
- 近距离 tokens: 标准 local attention
- 远距离 tokens: 分组 group attention
"""
def __init__(self, local_window=256, group_size=16):
self.local_window = local_window
self.group_size = group_size
def get_attention_mask(self, seq_len):
"""
生成 attention mask
"""
mask = torch.zeros(seq_len, seq_len)
for i in range(seq_len):
for j in range(seq_len):
distance = i - j
if distance <= 0:
# 未来 token,mask 掉
mask[i, j] = float('-inf')
elif distance <= self.local_window:
# 局部窗口,标准 attention
mask[i, j] = 0
else:
# 远距离,分组
group_id = distance // self.group_size
mask[i, j] = 0 # 允许关注同组的
return mask
def forward(self, q, k, v, seq_len):
"""
带 mask 的 attention
"""
# 计算 attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
# 应用 mask
mask = self.get_attention_mask(seq_len).to(q.device)
scores = scores + mask
# softmax
attn = F.softmax(scores, dim=-1)
return torch.matmul(attn, v)5.3.2 LM-Infinite
核心思想:对距离超过 Λ 的 token 应用 uniform attention
python
class LMInfiniteAttention:
"""
LM-Infinite: ∞-transformer 的简化实现
核心思想:
- Λ 之内的 token: 正常 local attention
- Λ 之外的 token: 假设均匀分布,用 O(1) 表示
"""
def __init__(self, lambda_param=4096, local_window=2048):
self.lambda_param = lambda_param # 有效距离阈值
self.local_window = local_window
def forward(self, q, k, v, position_ids):
"""
position_ids: 每个 token 的位置
"""
B, T, H, D = q.shape
# 1. 计算每个 token 的有效距离
distances = position_ids.unsqueeze(-1) - position_ids.unsqueeze(-2)
# 2. 分离 local 和 global
local_mask = distances.abs() <= self.local_window
global_mask = (distances > 0) & (distances > self.local_window) & (distances <= self.lambda_param)
out_of_range_mask = distances > self.lambda_param
# 3. Local attention
local_scores = torch.matmul(q, k.transpose(-2, -1))
local_scores = local_scores.masked_fill(~local_mask, float('-inf'))
# 4. Global attention (uniform)
# 对超过 lambda 的 token,用平均 attention
global_k = k.masked_fill(out_of_range_mask.unsqueeze(-2), 0)
global_k_sum = global_k.sum(dim=2, keepdim=True)
global_count = (~out_of_range_mask).sum(dim=2, keepdim=True).clamp(min=1)
global_k_avg = global_k_sum / global_count
global_scores = torch.matmul(q, global_k_avg.transpose(-2, -1))
# 5. 合并
# ... (省略细节)
return attn @ v5.4 上下文检索增强
5.4.1 RAG vs 长上下文的边界
RAG 适用场景:
- 知识库检索(事实性知识)
- 需要精确检索的场景
- 知识会更新的场景
长上下文适用场景:
- 需要整体理解长文档
- 复杂推理(需要在全文中跳转)
- 代码库理解
- 多文档摘要5.4.2 上下文压缩
python
class ContextCompressor:
"""
用小型模型压缩长上下文
保留关键信息,减少 token 数量
"""
def __init__(self, compressor_model):
self.compressor = compressor_model
def compress(self, context, max_tokens=1024):
"""
压缩长上下文到固定长度
"""
prompt = f"""压缩以下文本,保留关键信息:
{context}
压缩后的摘要(不超过 {max_tokens} tokens):"""
compressed = self.compressor.generate(prompt, max_tokens=max_tokens)
return compressed
def hierarchical_compress(self, chunks, level=2):
"""
层级压缩:
L1: 每个 chunk → 摘要
L2: L1 摘要 → 全局摘要
"""
# Level 1: chunk → summary
level1_summaries = [self.compress(chunk) for chunk in chunks]
if level >= 2:
# Level 2: summaries → global summary
global_summary = self.compress('\n'.join(level1_summaries))
return global_summary, level1_summaries
else:
return level1_summaries5.4.3 选择性上下文
python
class SelectiveContext:
"""
选择性保留上下文:只保留对当前查询重要的部分
步骤:
1. 给每个 context chunk 打"重要性分数"
2. 只保留分数超过阈值的前 N 个 chunk
"""
def __init__(self, importance_model):
self.model = importance_model
def score_chunks(self, query, context_chunks):
"""
给每个 chunk 打分
"""
scores = []
for chunk in context_chunks:
# 简单方法:算 query 和 chunk 的重叠度
query_tokens = set(query.lower().split())
chunk_tokens = set(chunk.lower().split())
overlap = len(query_tokens & chunk_tokens)
scores.append(overlap)
return scores
def select(self, query, context_chunks, top_k=5):
"""
选择最重要的 chunks
"""
scores = self.score_chunks(query, context_chunks)
# 取 top-k
top_indices = sorted(range(len(scores)),
key=lambda i: scores[i],
reverse=True)[:top_k]
return [context_chunks[i] for i in sorted(top_indices)]5.5 本章小结
┌─────────────────────────────────────────────────────────────┐
│ 长上下文技术全景图 │
├─────────────────────────────────────────────────────────────┤
│ 微调阶段 │
│ ├─ LongLoRA │ S²-Attn 近似,组内 attention,省计算 │
│ ├─ LISA │ 按层重要性分配 LoRA rank │
│ └─ 长文本数据 │ 截取/拼接/合成 多种构造方式 │
├─────────────────────────────────────────────────────────────┤
│ 推理阶段 │
│ ├─ H2O │ 驱逐低重要性 KV,保留 Heavy Hitters │
│ ├─ SnapKV │ 观测窗口预判,只保留关键 KV │
│ ├─ StreamingLLM│ 利用 Attention Sink + Recent Buffer │
│ └─ SelfExtend │ Grouped Attention,无需重训练 │
├─────────────────────────────────────────────────────────────┤
│ 短训长推 │
│ ├─ LM-Infinite │ Λ之外 uniform attention │
│ └─ 位置插值 │ Position Interpolation │
└─────────────────────────────────────────────────────────────┘Attention Sink 可视化
典型 LLM 的 Attention Pattern:
位置 0 1 2 3 4 ... 100 101 102
┌────┬────┬────┬────┬────┬ ┌────┬────┬────┐
Token │ CLS│ the│ cat│ sat│ on│ ... │ the│ mat│ . │
├────┼────┼────┼────┼────┤ ├────┼────┼────┤
Attn │████████│▒▒▒▒│▒▒▒▒│▒▒▒▒│ │▒▒▒▒│████│▒▒▒▒│
│ Sink │ │ │ │ │ │Recent│ │
└────┴────┴────┴────┴────┘ └────┴────┴────┘
██ = 高 attention (Sink + Recent)
▒▒ = 低 attention (可丢弃)面试高频问题
Attention Sink 是什么?为什么有用? → LLM 对句首 token 有异常高 attention,作为"锚点"帮助理解;中间 token 可以丢弃
H2O 和 SnapKV 的区别? → H2O 是动态驱逐;SnapKV 是预判重要位置
LongLoRA 为什么能省计算? → S²-Attn 把长序列分成组,组内做 attention,减少 O(n²) 到 O(n)
RAG 和长上下文的取舍? → 知识检索用 RAG,复杂推理用长上下文