Skip to content

长上下文专项技术

本章目标:突破 128K+ 上下文的技术栈(横跨训练与推理)


5.1 长文本微调技术

5.1.1 LongLoRA: Shift-Short Attention

核心思想:用近似 attention 降低微调时的计算量

标准 Self-Attention: O(n²) 显存和计算
S²-Attn: 近似 O(n) 显存,适合微调

核心:将 token 分组,在组内做 attention,减少通信量
python
class ShiftedShortAttention(nn.Module):
    """
    S²-Attn: 把序列分成若干组,组内做 attention
    减少计算量的同时保持长距离依赖
    
    示意图:
    
    标准 Attention:
    ┌─────────────────────────┐
    │  q₀ q₁ q₂ q₃ q₄ q₅ q₆ │  ← 所有 q 关注所有 k,v
    │    ↖  ↑  ↑  ↑  ↑  ↗   │
    └─────────────────────────┘
    
    S²-Attn:
    ┌───┬───┬───┐
    │ q₀│ q₁│ q₂│  ← 组内 attention
    │ k₀│ k₁│ k₂│
    ├───┼───┼───┤
    │ q₃│ q₄│ q₅│  ← 相邻组有 overlap
    │ k₃│ k₄│ k₅│
    └───┴───┴───┘
    """
    
    def __init__(self, num_heads, head_dim, shift_ratio=0.25):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.shift_ratio = shift_ratio
    
    def forward(self, q, k, v, group_size):
        """
        q, k, v: (batch, seq_len, num_heads, head_dim)
        group_size: 注意力窗口大小
        """
        B, T, H, D = q.shape
        
        # 1. 把序列分成组
        num_groups = T // group_size
        
        # 2. 对 Q,K,V 做 shift(在特征维度上偏移)
        shift = int(group_size * self.shift_ratio)
        
        # Pad 后 shift
        q_shifted = F.pad(q, (0, 0, 0, 0, shift, shift))
        k_shifted = F.pad(k, (0, 0, 0, 0, shift, shift))
        v_shifted = F.pad(v, (0, 0, 0, 0, shift, shift))
        
        # 3. Reshape 成组
        q_groups = q_shifted.view(B, num_groups, group_size + 2*shift, H, D)
        k_groups = k_shifted.view(B, num_groups, group_size + 2*shift, H, D)
        v_groups = v_shifted.view(B, num_groups, group_size + 2*shift, H, D)
        
        # 4. 组内做 attention
        # 取中间的有效部分
        q_valid = q_groups[:, :, shift:shift+group_size]
        
        # 5. 计算 attention(简化版,实际实现更复杂)
        attn_weights = torch.matmul(q_valid, k_groups.transpose(-2, -1)) / (D ** 0.5)
        attn_weights = F.softmax(attn_weights, dim=-1)
        output = torch.matmul(attn_weights, v_groups)
        
        return output.view(B, T, H, D)

5.1.2 LISA: Layerwise Importance Sampling

核心思想:不是所有层都同等重要,优先微调重要的层

python
class LISALoRA(nn.Module):
    """
    LISA: Layerwise Importance Sampling for LoRA
    
    核心思想:
    - 深层的 attention 对齐更重要
    - 浅层的 FFN 对知识保留更重要
    - 按重要性分配 LoRA rank
    """
    
    def __init__(self, model, base_rank=8, importance_scores=None):
        super().__init__()
        self.model = model
        self.importance_scores = importance_scores or self._compute_importance()
        
        # 根据重要性分配 rank
        for i, layer in enumerate(model.transformer.layers):
            importance = self.importance_scores[i]
            rank = max(2, int(base_rank * importance))
            
            # 深层多加 LoRA
            if i > len(model.transformer.layers) // 2:
                rank = int(rank * 1.5)
            
            self._apply_lora_to_layer(layer, rank)
    
    def _compute_importance(self):
        """
        计算每层的重要性分数
        方法:梯度范数、激活值方差、attention 熵等
        """
        # 简化的重要性计算
        num_layers = len(self.model.transformer.layers)
        scores = []
        for i in range(num_layers):
            # 越深层越重要
            scores.append((i + 1) / num_layers)
        return scores

5.1.3 长文本 SFT 数据构造

python
def construct_long_context_data(dataset, max_len=32768):
    """
    长上下文 SFT 数据构造策略
    """
    long_data = []
    
    for doc in dataset:
        # 方法1: 直接使用长文档
        if len(doc) > max_len // 4:
            # 截取中间部分(避免总是从开头开始)
            start = len(doc) // 4
            chunk = doc[start:start + max_len]
            long_data.append(chunk)
        
        # 方法2: 拼接多个短文档
        else:
            combined = []
            current_len = 0
            while current_len < max_len and len(combined) < len(dataset):
                next_doc = dataset[(len(combined) % len(dataset))]
                if current_len + len(next_doc) < max_len:
                    combined.append(next_doc)
                    current_len += len(next_doc)
            
            long_data.append('\n\n'.join(combined))
        
        # 方法3: 重复上下文模式
        # QA 对,Q 在开头,A 在后面
        # 强制模型在长上下文中找到答案
    
    return long_data


# 数据配比建议
"""
YaRN 论文建议的长文本数据配比:
- 原始短文本数据: 70%
- 长文本数据 (8K+): 20%
- 合成长文本数据: 10%
"""

5.2 推理阶段压缩:H2O / SnapKV / StreamingLLM

5.2.1 Heavy Hitter Oracle (H2O)

核心思想:不是所有 KV 都重要,识别并保留最重要的

问题:长上下文 KV Cache 太大
解决:驱逐不重要的 KV,保留重要的

H2O 假设:某些 token(如 "the", "is")贡献小,某些(如 "not", "but")贡献大
python
class H2OKVCache:
    """
    H2O: Heavy Hitter Oracle - 重要性驱动的 KV 驱逐
    """
    
    def __init__(self, max_cache_size):
        self.max_cache_size = max_cache_size
        self.cache = {}  # token_id -> importance_score
        self.kv_cache = {}
    
    def compute_importance(self, query, key, value, token_id):
        """
        计算当前 token 的重要性
        公式:importance = |Q·K| / sqrt(d)
        即:当前 query 和这个 key 的注意力分数
        """
        # Q @ K^T / sqrt(d)
        score = torch.matmul(query, key.transpose(-2, -1)).abs().mean()
        return score.item()
    
    def update(self, token_id, key, value, query_for_next):
        """
        更新 KV Cache,驱逐低重要性 token
        """
        # 计算当前 token 的重要性
        importance = self.compute_importance(
            query_for_next, key, value, token_id
        )
        
        # 如果 cache 满了,驱逐最不重要的
        if len(self.kv_cache) >= self.max_cache_size:
            # 驱逐重要性最低的
            min_importance_token = min(self.cache.items(), key=lambda x: x[1])
            self.evict(min_importance_token[0])
        
        # 存入新的
        self.cache[token_id] = importance
        self.kv_cache[token_id] = (key, value)
    
    def evict(self, token_id):
        """驱逐低重要性 token"""
        self.cache.pop(token_id, None)
        self.kv_cache.pop(token_id, None)

5.2.2 SnapKV

核心思想:先用观测窗口决定保留哪些 KV,再推理

python
class SnapKV:
    """
    SnapKV: 通过观察窗口预判重要性
    
    两阶段:
    1. Prefill 阶段:用小窗口观察哪些 KV 重要
    2. Decode 阶段:只保留重要的 KV
    """
    
    def __init__(self, observe_window=64, max_cache=4096):
        self.observe_window = observe_window
        self.max_cache = max_cache
        self.important_positions = []
    
    def prefill_and_observe(self, hidden_states):
        """
        第一阶段:用小窗口观察,统计每个位置的累计 attention
        """
        # 只用前 observe_window 个 token 做 attention
        prefix = hidden_states[:self.observe_window]
        
        # 统计每个位置被注意到的次数
        attention_counts = torch.zeros(hidden_states.shape[0])
        
        for i in range(self.observe_window, hidden_states.shape[0]):
            # 简化的 attention 计算
            query = hidden_states[i]
            # 计算与 prefix 的 attention
            attn = F.softmax((query @ prefix.transpose(-2, -1)) / (hidden_states.shape[-1] ** 0.5))
            # 累加到每个 prefix 位置
            attention_counts[:self.observe_window] += attn[0]
        
        # 找出最重要的位置
        self.important_positions = torch.topk(
            attention_counts, 
            min(self.max_cache, self.observe_window)
        ).indices.tolist()
    
    def decode_with_cache(self, new_hidden):
        """
        第二阶段:只解码重要位置
        """
        # 只对重要位置做计算
        important_kv = self.kv_cache[:, self.important_positions]
        
        # 计算 attention
        attn = torch.matmul(new_hidden, important_kv.transpose(-2, -1))
        # ...

5.2.3 StreamingLLM: Attention Sink

现象:LLM 对某些"锚点"token(如句首的 [CLS])有异常高的 attention

python
class StreamingLLM:
    """
    StreamingLLM: 利用 Attention Sink 现象
    
    观察:
    - LLM 对第一个 token 有异常高的 attention(sink)
    - 最近的几个 token 也有较高 attention(recent)
    - 中间的 token attention 很低
    
    所以:
    - 只缓存 [sink tokens] + [recent tokens]
    - 中间的直接丢弃
    """
    
    def __init__(self, sink_tokens=4, recent_tokens=64):
        self.sink_tokens = sink_tokens
        self.recent_tokens = recent_tokens
        
        # 固定的前缀(sink)
        self.sink_cache = []
        # 最近的 tokens
        self.recent_cache = []
    
    def update(self, token_id, kv):
        """
        更新 rolling cache
        """
        # Sink tokens 固定不变
        if len(self.sink_cache) < self.sink_tokens:
            self.sink_cache.append((token_id, kv))
        
        # Recent tokens 用 buffer
        self.recent_cache.append((token_id, kv))
        if len(self.recent_cache) > self.recent_tokens:
            self.recent_cache.pop(0)
    
    def get_cached_kv(self):
        """
        获取拼接后的缓存
        """
        all_kv = []
        for token_id, kv in self.sink_cache:
            all_kv.append(kv)
        for token_id, kv in self.recent_cache:
            all_kv.append(kv)
        
        return torch.cat(all_kv, dim=2)  # 沿 seq 维度拼接

可视化

Attention Pattern:

        Sink         Recent
          ↓            ↓
Token: [S][S][S][S] ... ... ... [R][R][R][R]
         ↑______________________________↑
              被忽略的中间部分


为什么有效?
- Sink: 提供"起点"信息,帮助理解上下文框架
- Recent: 最近的 tokens 包含即时信息
- 中间部分:信息已被"蒸馏"到 recent 里

5.3 短训长推:SelfExtend / LM-Infinite

5.3.1 SelfExtend

核心思想:不需要重新训练,只需修改 attention mask

python
class SelfExtendAttention:
    """
    SelfExtend: Grouped Attention,不改变模型权重
    
    核心思想:
    - 近距离 tokens: 标准 local attention
    - 远距离 tokens: 分组 group attention
    """
    
    def __init__(self, local_window=256, group_size=16):
        self.local_window = local_window
        self.group_size = group_size
    
    def get_attention_mask(self, seq_len):
        """
        生成 attention mask
        """
        mask = torch.zeros(seq_len, seq_len)
        
        for i in range(seq_len):
            for j in range(seq_len):
                distance = i - j
                
                if distance <= 0:
                    # 未来 token,mask 掉
                    mask[i, j] = float('-inf')
                elif distance <= self.local_window:
                    # 局部窗口,标准 attention
                    mask[i, j] = 0
                else:
                    # 远距离,分组
                    group_id = distance // self.group_size
                    mask[i, j] = 0  # 允许关注同组的
        
        return mask
    
    def forward(self, q, k, v, seq_len):
        """
        带 mask 的 attention
        """
        # 计算 attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        
        # 应用 mask
        mask = self.get_attention_mask(seq_len).to(q.device)
        scores = scores + mask
        
        # softmax
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, v)

5.3.2 LM-Infinite

核心思想:对距离超过 Λ 的 token 应用 uniform attention

python
class LMInfiniteAttention:
    """
    LM-Infinite: ∞-transformer 的简化实现
    
    核心思想:
    - Λ 之内的 token: 正常 local attention
    - Λ 之外的 token: 假设均匀分布,用 O(1) 表示
    """
    
    def __init__(self, lambda_param=4096, local_window=2048):
        self.lambda_param = lambda_param  # 有效距离阈值
        self.local_window = local_window
    
    def forward(self, q, k, v, position_ids):
        """
        position_ids: 每个 token 的位置
        """
        B, T, H, D = q.shape
        
        # 1. 计算每个 token 的有效距离
        distances = position_ids.unsqueeze(-1) - position_ids.unsqueeze(-2)
        
        # 2. 分离 local 和 global
        local_mask = distances.abs() <= self.local_window
        global_mask = (distances > 0) & (distances > self.local_window) & (distances <= self.lambda_param)
        out_of_range_mask = distances > self.lambda_param
        
        # 3. Local attention
        local_scores = torch.matmul(q, k.transpose(-2, -1))
        local_scores = local_scores.masked_fill(~local_mask, float('-inf'))
        
        # 4. Global attention (uniform)
        # 对超过 lambda 的 token,用平均 attention
        global_k = k.masked_fill(out_of_range_mask.unsqueeze(-2), 0)
        global_k_sum = global_k.sum(dim=2, keepdim=True)
        global_count = (~out_of_range_mask).sum(dim=2, keepdim=True).clamp(min=1)
        global_k_avg = global_k_sum / global_count
        
        global_scores = torch.matmul(q, global_k_avg.transpose(-2, -1))
        
        # 5. 合并
        # ... (省略细节)
        
        return attn @ v

5.4 上下文检索增强

5.4.1 RAG vs 长上下文的边界

RAG 适用场景:
- 知识库检索(事实性知识)
- 需要精确检索的场景
- 知识会更新的场景

长上下文适用场景:
- 需要整体理解长文档
- 复杂推理(需要在全文中跳转)
- 代码库理解
- 多文档摘要

5.4.2 上下文压缩

python
class ContextCompressor:
    """
    用小型模型压缩长上下文
    保留关键信息,减少 token 数量
    """
    
    def __init__(self, compressor_model):
        self.compressor = compressor_model
    
    def compress(self, context, max_tokens=1024):
        """
        压缩长上下文到固定长度
        """
        prompt = f"""压缩以下文本,保留关键信息:
        
{context}

压缩后的摘要(不超过 {max_tokens} tokens):"""
        
        compressed = self.compressor.generate(prompt, max_tokens=max_tokens)
        return compressed
    
    def hierarchical_compress(self, chunks, level=2):
        """
        层级压缩:
        L1: 每个 chunk → 摘要
        L2: L1 摘要 → 全局摘要
        """
        # Level 1: chunk → summary
        level1_summaries = [self.compress(chunk) for chunk in chunks]
        
        if level >= 2:
            # Level 2: summaries → global summary
            global_summary = self.compress('\n'.join(level1_summaries))
            return global_summary, level1_summaries
        else:
            return level1_summaries

5.4.3 选择性上下文

python
class SelectiveContext:
    """
    选择性保留上下文:只保留对当前查询重要的部分
    
    步骤:
    1. 给每个 context chunk 打"重要性分数"
    2. 只保留分数超过阈值的前 N 个 chunk
    """
    
    def __init__(self, importance_model):
        self.model = importance_model
    
    def score_chunks(self, query, context_chunks):
        """
        给每个 chunk 打分
        """
        scores = []
        for chunk in context_chunks:
            # 简单方法:算 query 和 chunk 的重叠度
            query_tokens = set(query.lower().split())
            chunk_tokens = set(chunk.lower().split())
            overlap = len(query_tokens & chunk_tokens)
            scores.append(overlap)
        
        return scores
    
    def select(self, query, context_chunks, top_k=5):
        """
        选择最重要的 chunks
        """
        scores = self.score_chunks(query, context_chunks)
        
        # 取 top-k
        top_indices = sorted(range(len(scores)), 
                           key=lambda i: scores[i], 
                           reverse=True)[:top_k]
        
        return [context_chunks[i] for i in sorted(top_indices)]

5.5 本章小结

┌─────────────────────────────────────────────────────────────┐
│                 长上下文技术全景图                             │
├─────────────────────────────────────────────────────────────┤
│  微调阶段                                                   │
│  ├─ LongLoRA    │  S²-Attn 近似,组内 attention,省计算    │
│  ├─ LISA        │  按层重要性分配 LoRA rank                │
│  └─ 长文本数据   │  截取/拼接/合成 多种构造方式              │
├─────────────────────────────────────────────────────────────┤
│  推理阶段                                                   │
│  ├─ H2O         │  驱逐低重要性 KV,保留 Heavy Hitters      │
│  ├─ SnapKV      │  观测窗口预判,只保留关键 KV              │
│  ├─ StreamingLLM│  利用 Attention Sink + Recent Buffer     │
│  └─ SelfExtend  │  Grouped Attention,无需重训练            │
├─────────────────────────────────────────────────────────────┤
│  短训长推                                                   │
│  ├─ LM-Infinite │  Λ之外 uniform attention                  │
│  └─ 位置插值    │  Position Interpolation                  │
└─────────────────────────────────────────────────────────────┘

Attention Sink 可视化

典型 LLM 的 Attention Pattern:

位置   0    1    2    3    4    ...   100   101   102
      ┌────┬────┬────┬────┬────┬        ┌────┬────┬────┐
Token │ CLS│ the│ cat│ sat│ on│  ...   │ the│ mat│ .  │
      ├────┼────┼────┼────┼────┤        ├────┼────┼────┤
 Attn │████████│▒▒▒▒│▒▒▒▒│▒▒▒▒│        │▒▒▒▒│████│▒▒▒▒│
      │ Sink  │    │    │    │        │    │Recent│    │
      └────┴────┴────┴────┴────┘        └────┴────┴────┘
      
██ = 高 attention (Sink + Recent)
▒▒ = 低 attention (可丢弃)

面试高频问题

  1. Attention Sink 是什么?为什么有用? → LLM 对句首 token 有异常高 attention,作为"锚点"帮助理解;中间 token 可以丢弃

  2. H2O 和 SnapKV 的区别? → H2O 是动态驱逐;SnapKV 是预判重要位置

  3. LongLoRA 为什么能省计算? → S²-Attn 把长序列分成组,组内做 attention,减少 O(n²) 到 O(n)

  4. RAG 和长上下文的取舍? → 知识检索用 RAG,复杂推理用长上下文

基于 MIT 许可发布