On-Policy Distillation:让学生自己犯错、自己改正

1. 传统蒸馏的问题:一直在模仿别人的"范文"

在大模型落地过程中,蒸馏是把大模型能力迁移到小模型的常用手段。传统的 logits 蒸馏思路很直接:让教师模型生成一批数据,学生模型学习这些数据的 token 分布。这在学术上叫 Off-policy Distillation——因为训练数据来自教师策略 \(\pi_T\),而学生真正推理时面对的是自己的策略 \(\pi_S\)

这个范式在短文本分类、简单问答上效果不错,但一旦进入长链推理(数学证明、代码生成、多轮工具调用),问题就来了:自回归误差积累。

想象一个数学推理场景。教师在第 3 步写的是"令 \(x=2\)",学生训练时反复看到这个正确示范。但真实推理时,学生可能在第 3 步写成"设 \(x=2\)"——这本身未必错,却改变了后续所有 token 的条件分布。接下来第 4 步,学生面对的前缀变成了自己写的"设 \(x=2\)",而训练时它从未见过教师在这个前缀下会怎么做。一步偏差之后,学生逐渐滑向教师数据分布从未覆盖过的区域,开始胡编乱造。

这就是模仿学习里经典的 Train-Test Mismatch:学生一直在"看别人写的范文"上训练,考试时却必须"自己写作文"。教师的数据不可能覆盖学生所有可能犯错的状态,而 LLM 的自回归特性会让这种分布偏移指数级放大。

2. OPD 的做法:从"抄范文"到"改作业"

On-Policy Distillation(OPD)解决上述问题的思路很直接:既然推理时是学生自己生成,那训练时也让学生自己生成;既然错误会一步步积累,那就在学生每一步的轨迹上给出密集纠正信号。

它的训练循环遵循三步:

  1. Rollout:给定输入 \(x\),学生模型 \(\pi_\theta\) 自回归生成完整回复 \(y\)
  2. Evaluation:教师模型 \(\pi_T\) 不重新生成,而是直接对学生生成的 \(y\) 做前向传播,逐 token 输出 logits;
  3. Alignment:用学生分布与教师分布之间的差异作为损失,更新学生模型。

和传统蒸馏的差异在于:教师不再是"范文生成者",而是"作业批改者"。教师不需要写出一篇新文章,只需要对学生写好的文章逐字打分。这意味着教师可以是更大的模型、黑盒 API,甚至是带有特权信息(如 gold reasoning trace、检索文档)的同一个模型。

3. 算法实现:Reverse KL 与 On-Policy 数据流

3.1 数据流和损失函数

OPD 的标准实现如下:

for batch in dataloader:
    # Step 1: 学生当前策略生成
    y = student.generate(x, temperature=0.7)  # y ~ π_θ
    
    # Step 2: 教师对学生轨迹打分(前向传播,无需反向传播)
    teacher_logits = teacher(context=c, input=x, output=y)
    student_logits = student(input=x, output=y)
    
    # Step 3: 计算 Reverse KL 并更新
    loss = reverse_kl(student_logits, teacher_logits)
    student.update(loss)

这里有两个必须同时满足的条件:

  • On-Policy:序列 \(y\) 必须来自学生当前 checkpoint 的采样分布,而非教师历史数据或静态语料;
  • Token-level Dense Signal:教师对 \(y\) 的每一个位置 \(t\) 都输出完整词表分布,损失函数在词表级别精确计算,而非只在采样到的 token 上算交叉熵。

3.2 为什么用 Reverse KL?

OPD 通常最小化 Reverse KL Divergence。要理解这个选择,我们需要先回到 KL 散度的本质。

4. KL 散度:信息论视角下的"编码浪费"

KL 散度衡量的是:当你用分布 \(Q\)(学生)去近似真实分布 \(P\)(教师)时,平均会多浪费多少比特的信息量

它的定义是:

\[ D_{KL}(P \| Q) = \mathbb{E}_{x \sim P} \left[ \log \frac{P(x)}{Q(x)} \right] \]

注意一个关键性质:KL 散度不对称\(D_{KL}(P\|Q)\)\(D_{KL}(Q\|P)\) 是两个完全不同的目标,对应完全不同的优化行为。

4.1 期望底数:站在谁的角度看问题

两种 KL 的差异在于期望是在哪个分布下求的:

类型 写法 期望底数 含义
Forward KL \(D_{KL}(P \| Q)\) \(x \sim P\)(教师) 站在教师的角度,看学生的编码多烂
Reverse KL \(D_{KL}(Q \| P)\) \(x \sim Q\)(学生) 站在学生的角度,看教师的编码多烂

这个"期望底数"不是抽象概念,它直接决定了蒙特卡洛估计时你该用谁的样本做平均。

  • 如果你要无偏估计 \(D_{KL}(P\|Q)\),你必须有从 \(P\) 采样的数据;
  • 如果你要无偏估计 \(D_{KL}(Q\|P)\),你必须有从 \(Q\) 采样的数据。

而 OPD 的数据流天然产生的是 \(y \sim Q\)(学生 rollouts)。因此,Reverse KL 可以直接对学生生成的序列求平均,得到无偏估计;Forward KL 若强行用学生数据估计,则需要引入重要性采样权重 \(P(y)/Q(y)\),在 LLM 的长序列空间上方差极大,工程上几乎不可用。

4.2 Mode-covering vs Mode-seeking:看个实际例子

假设在某个推理步骤,教师模型(大模型)对下一个词的概率分布如下:

token 教师 \(P\) 语义
"分析" 0.40 展开逻辑分析
"计算" 0.35 进入数值运算
"推导" 0.15 形式化推导
"求解" 0.10 直接给答案

教师是"多面手",四种表达都合理。但学生模型(小模型或 LoRA)容量有限,在这个位置只能强表达一个偏好(单峰)。假设学生初始分布为:

token 学生 \(Q\)
"分析" 0.85
"计算" 0.08
"推导" 0.04
"求解" 0.03

Forward KL:摊平覆盖(Mode-covering)

Forward KL 在教师分布上求期望。如果学生不给某个教师支持的词分配足够概率,惩罚会急剧上升:

\[ D_{KL}(P \| Q) = 0.40 \log\frac{0.40}{Q(\text{"分析"})} + 0.35 \log\frac{0.35}{Q(\text{"计算"})} + 0.15 \log\frac{0.15}{Q(\text{"推导"})} + 0.10 \log\frac{0.10}{Q(\text{"求解"})} \]

如果学生坚持单峰押注"分析"(\(Q(\text{"分析"})=0.85\),其余接近 0):
- 第一项:\(0.40 \times \log(0.40/0.85) \approx 0.40 \times (-0.76) = -0.30\)(负值,但注意整体被其他项主导)
- 第二项:\(0.35 \times \log(0.35/0.08) \approx 0.35 \times 1.48 = 0.52\)
- 第三项:\(0.15 \times \log(0.15/0.04) \approx 0.15 \times 1.32 = 0.20\)
- 第四项:\(0.10 \times \log(0.10/0.03) \approx 0.10 \times 1.20 = 0.12\)

总损失约为 \(0.54\),且梯度会迫使学生把概率摊向教师的所有模式。最终学生可能变成:

token 学生 \(Q\)(Forward KL 后)
"分析" 0.42
"计算" 0.38
"推导" 0.12
"求解" 0.08

问题出现了:学生原来在"分析"上有 0.85 的置信度,被稀释到了 0.42。推理时,它在"分析"和"计算"之间随机跳跃,前后步骤的术语不一致,导致逻辑连贯性崩坏。这就是 Mode-covering——学生被迫覆盖教师的所有表达模式,结果每个模式都学不精。

Reverse KL:锁定修正(Mode-seeking)

Reverse KL 在学生分布上求期望:

\[ D_{KL}(Q \| P) = 0.85 \log\frac{0.85}{0.40} + 0.08 \log\frac{0.08}{0.35} + 0.04 \log\frac{0.04}{0.15} + 0.03 \log\frac{0.03}{0.10} \]

逐项看:
- 第一项("分析"):\(0.85 \times \log(2.125) \approx 0.85 \times 0.75 = 0.64\)
- 第二项("计算"):\(0.08 \times \log(0.23) \approx 0.08 \times (-1.47) = -0.12\)
- 第三、四项:权重极小,贡献接近 0

可以看到,虽然单个项可能是负的,但 KL 散度的总和必然非负(这是数学性质)。关键在于:"推导"和"求解"这两个词因为学生概率极低,几乎不参与损失计算。

Reverse KL 的梯度会把学生推向什么方向?它只会在学生实际高概率的区域("分析")上,将 \(Q\)\(P\) 靠拢(从 0.85 往 0.40 微调),同时不会强迫学生给"计算""推导"分配显著概率。

最终学生可能变成:

token 学生 \(Q\)(Reverse KL 后)
"分析" 0.72
"计算" 0.15
"推导" 0.08
"求解" 0.05

结果:学生保持单峰特性,稳定地使用"分析"这一表达,内部逻辑一致。它放弃了模仿教师的所有措辞风格,只确保自己走的每一步在教师看来"合理即可"。这就是 Mode-seeking——学生寻找并锁定一个自己擅长、教师也认可的模式。

4.3 Reverse KL 的"硬边界"反而是好事

Reverse KL 有一个严格的数学约束:如果学生给某个词分配了非零概率,而教师给零概率,损失会爆炸到无穷大:

\[ Q(x) > 0, \ P(x) = 0 \ \Rightarrow \ Q(x) \log\frac{Q(x)}{0} \to +\infty \]

这意味着 \(\text{supp}(Q) \subseteq \text{supp}(P)\)——学生的支持集必须是教师支持集的子集。学生"不敢"探索教师认为绝对不可能的区域。

在 OPD 场景中,这反而是安全机制:它防止学生自由发挥到低质量或语法崩坏的区域,确保蒸馏始终发生在教师认可的语言空间内。(工程上通常会对教师概率做 epsilon 平滑,如 clamp(min=1e-6),来避免数值溢出)

5. 工程实现的几个细节

5.1 采样频率:要不要每步都重新生成?

理论上,一旦学生参数更新(\(\theta_t \to \theta_{t+1}\)),之前的 rollouts 就"过期"了。完全严格的 OPD 每步都重新采样:

for step in range(total_steps):
    y = student.generate(x)        # 来自 π_θ_t
    loss = reverse_kl(student(x,y), teacher(c,x,y))
    student.update(loss)           # 变成 π_θ_{t+1}
    # 下一步必须重新 generate

但 LLM 的生成代价远高于训练(通常 10-100 倍)。工程上普遍采用 PPO 风格的缓冲池策略:每隔 \(N\) 步用当前 checkpoint 采样一批数据,重复训练 \(K\) 个 epoch(通常 \(K \leq 4\))。只要 policy lag 不大,效果与严格 on-policy 几乎无差异。

5.2 和现有框架的兼容性

OPD 对基础设施很友好:
- 教师无需反向传播:只需前向传播输出 logits,可以是更大的模型或黑盒 API;
- 与 LoRA 天然适配:Qwen3 技术报告显示,SFT 后 LoRA 与全量微调差距约 13%,经过 OPD 后差距缩小到 6%,且 OPD 阶段可用更小的 batch size(因为 token-level 信号密度高,梯度噪声低);
- 接入 RL 框架很简单:如果你已有 VeRL / verl 等 RL 流水线,只需将 reference model 的 KL penalty 替换为 teacher model 的 log-prob 评估,即完成 OPD 改造。

5.3 温度和采样策略

  • 学生生成:通常使用 temperature > 0(如 0.7-1.0)采样,覆盖自身分布中的多种错误模式;
  • 教师评估:可用 temperature=0 获取确定性反馈,降低方差。

6. 什么时候用 OPD?

OPD 最适合的场景:
- 长链推理任务(数学、代码):早期错误会导致后期严重偏离,OPD 能纠正中间步骤;
- 持续学习/对齐恢复:当 SFT 会遗忘 post-training 行为时,OPD 可以恢复已学习的特性,而不会像 SFT 那样在自身样本上训练导致退化;
- 跨架构蒸馏:如从 Dense 模型蒸馏到 MoE,OPD 比静态数据更鲁棒。

局限:
- 需要在线生成学生轨迹,训练吞吐低于静态 SFT(但比 RL 好很多,因为不需要 Value Model 或大量 rollout);
- 若学生与教师能力差距过大(教师的高概率区域完全不在学生的支撑集内),需配合 warm-up SFT 缩小初始分布差异。

7. 小结

On-Policy Distillation 的核心思路可以概括为一句话:

学生生成自己的轨迹 → 教师对这条轨迹的每个 token 打分 → 用 Reverse KL 将学生分布拉向教师分布,但只纠正学生实际会犯的错误。

它与传统蒸馏的差异不在于"有没有教师信号",而在于信号以什么密度、在什么分布上传递:

方法 数据来源 反馈密度 核心行为
Off-policy Distillation 教师生成 Token-level Dense 模仿教师范文,但分布不匹配
RL (PPO/GRPO) 学生当前策略 Episode-level Sparse 探索自身分布,但信用分配困难
OPD 学生当前策略 Token-level Dense 在自己分布上被密集纠正

从信息论视角看,RL 每个 episode 只传输 \(O(1)\) bits 信号,而 OPD 传输 \(O(N)\) bits(\(N\) 为序列长度)。这种效率优势使得 OPD 在复现同等策略时,所需梯度步数通常比 RL 少一个数量级。

如果你已有 RL 训练框架,实现 OPD 的工程改动很小;如果你正在用 LoRA 做领域适配,在 SFT 后追加一轮 OPD 往往是提升推理稳定性的最高性价比手段。


【推荐文章】
- Agent:
Karpathy所说的LLM Wiki
Harness Engineer
O-MEM
字节的M3-Agent
DeepResearch的报告生成方法
从RAG到DeepSearch
阿里通义Lab: WebWalker,WebDancer和WebSailor
Agent评测数据集
Agent完全手册(零):三大模块,三个理念
agent调研(1)--MetaGPT,OpenManus和OWL
Devin和Anthropic的Agent开发经验
- MoE:
DeepSeek-V3细节探索
MoE模型的前世今生
DeepSeek-V2和MLA
昆仑万维-SkyworkMoE
成本10w刀的JetMoE
MoE的top-p routing
对MoE模型的一些观察
从dense到MoE -- sparse upcycling
MoE路由--expert choice routing
- 端侧模型:
苹果智能系统模型--AFM
MiniCPM
适合移动设备的语言模型--MobileLLM
phi系列模型
Gemma2
苹果的OpenELM
bilibili的index-1.9B
- 预训练:
预训练经验
Qwen3实测&技术报告
代码大模型(一)--业界现状
代码大模型(二)--OpenCoder
LLM高效预训练(一)
LLM高效预训练(二)
Llama3.1--预训练要点一览
Qwen2技术报告
Yi技术报告-划重点看细节
InternLM系列模型
GLM4报告的一些技术点
从Yuan2.0到Yuan2.0-M32
从loss视角理解大模型涌现能力
- 数据:
训练数据合成(一)
训练数据合成(二)
训练数据合成(三)
LLM预训练数据策略(一)
预训练数据处理--长度分解
- 长上下文:
Qwen2.5-1M技术解析
LLM长上下文的问题
解锁大模型长上下文能力
大模型推理窗口-从有限到无限大
prompt压缩(一)
prompt压缩(二)
reasoning压缩(一)
- 推理加速:
大模型推理加速-投机解码
大模型推理加速-MEDUSA
- 对齐:
VeRA,LoRA-XS和TinyLoRA
腾讯的Training-Free GRPO
深度求索DeepSeek-R1详解
基模型Cognitive Behaviors对RL的影响
Llama3.1--post-training要点一览
模型平均 -- model soup
大模型偏好对齐-DPO
大模型偏好对齐-ODPO
大模型偏好对齐-simPO
大模型偏好对齐-IPO
- Transformer:
Attention Residuals
理解Attention:从起源到MHA,MQA和GQA
LLM的重复生成和ICL
transformer中normalization的二三事
从代码实现看normalization-到底做了什么
稀疏注意力计算:sliding window attention
理解LLM位置编码:RoPE
RoPE的远距离衰减
LLM水印
- 训练框架
Muon优化器
LLM训练框架:从优化器和精度讲到ZeRO
LLM训练各种并行策略
- 项目应用:
一个模型支持智能助手系统
关于The Bitter Lesson
- CV:
CV入门--关于Vision Transformer
CV入门--无监督学习
- 多模态:
多模态入门(一)--CLIP
多模态入门(二)--Flamingo,LLaVA系列和BLIP系列
多模态入门(三)--MiniGPT4,DeepSeekVL,InternVL系列和QwenVL系列
多模态入门(四)--CogVLM,VILA,MM1,MM1.5和Pixtral-12B
多模态入门(五)--InternVL系列
小米的移动UI多模态模型--MobileVLM
DeepSeek-VL2的细节
- 论文阅读:
最近阅读--关于数据合成、agent、reasoning和多任务
最近阅读2-关于自适应深度思考、context engineering和模型训练
- 大模型算法题:
(1)(2)(3)(4)(5)(6)(7)(8)(9)