On-Policy Distillation:让学生自己犯错、自己改正
1. 传统蒸馏的问题:一直在模仿别人的"范文"
在大模型落地过程中,蒸馏是把大模型能力迁移到小模型的常用手段。传统的 logits 蒸馏思路很直接:让教师模型生成一批数据,学生模型学习这些数据的 token 分布。这在学术上叫 Off-policy Distillation——因为训练数据来自教师策略 \(\pi_T\),而学生真正推理时面对的是自己的策略 \(\pi_S\)。
这个范式在短文本分类、简单问答上效果不错,但一旦进入长链推理(数学证明、代码生成、多轮工具调用),问题就来了:自回归误差积累。
想象一个数学推理场景。教师在第 3 步写的是"令 \(x=2\)",学生训练时反复看到这个正确示范。但真实推理时,学生可能在第 3 步写成"设 \(x=2\)"——这本身未必错,却改变了后续所有 token 的条件分布。接下来第 4 步,学生面对的前缀变成了自己写的"设 \(x=2\)",而训练时它从未见过教师在这个前缀下会怎么做。一步偏差之后,学生逐渐滑向教师数据分布从未覆盖过的区域,开始胡编乱造。
这就是模仿学习里经典的 Train-Test Mismatch:学生一直在"看别人写的范文"上训练,考试时却必须"自己写作文"。教师的数据不可能覆盖学生所有可能犯错的状态,而 LLM 的自回归特性会让这种分布偏移指数级放大。
2. OPD 的做法:从"抄范文"到"改作业"
On-Policy Distillation(OPD)解决上述问题的思路很直接:既然推理时是学生自己生成,那训练时也让学生自己生成;既然错误会一步步积累,那就在学生每一步的轨迹上给出密集纠正信号。
它的训练循环遵循三步:
- Rollout:给定输入 \(x\),学生模型
\(\pi_\theta\) 自回归生成完整回复 \(y\);
- Evaluation:教师模型 \(\pi_T\)
不重新生成,而是直接对学生生成的 \(y\)
做前向传播,逐 token 输出 logits;
- Alignment:用学生分布与教师分布之间的差异作为损失,更新学生模型。
和传统蒸馏的差异在于:教师不再是"范文生成者",而是"作业批改者"。教师不需要写出一篇新文章,只需要对学生写好的文章逐字打分。这意味着教师可以是更大的模型、黑盒 API,甚至是带有特权信息(如 gold reasoning trace、检索文档)的同一个模型。
3. 算法实现:Reverse KL 与 On-Policy 数据流
3.1 数据流和损失函数
OPD 的标准实现如下:
for batch in dataloader:
# Step 1: 学生当前策略生成
y = student.generate(x, temperature=0.7) # y ~ π_θ
# Step 2: 教师对学生轨迹打分(前向传播,无需反向传播)
teacher_logits = teacher(context=c, input=x, output=y)
student_logits = student(input=x, output=y)
# Step 3: 计算 Reverse KL 并更新
loss = reverse_kl(student_logits, teacher_logits)
student.update(loss)
这里有两个必须同时满足的条件:
- On-Policy:序列 \(y\)
必须来自学生当前 checkpoint
的采样分布,而非教师历史数据或静态语料;
- Token-level Dense Signal:教师对 \(y\) 的每一个位置 \(t\) 都输出完整词表分布,损失函数在词表级别精确计算,而非只在采样到的 token 上算交叉熵。
3.2 为什么用 Reverse KL?
OPD 通常最小化 Reverse KL Divergence。要理解这个选择,我们需要先回到 KL 散度的本质。
4. KL 散度:信息论视角下的"编码浪费"
KL 散度衡量的是:当你用分布 \(Q\)(学生)去近似真实分布 \(P\)(教师)时,平均会多浪费多少比特的信息量。
它的定义是:
\[ D_{KL}(P \| Q) = \mathbb{E}_{x \sim P} \left[ \log \frac{P(x)}{Q(x)} \right] \]
注意一个关键性质:KL 散度不对称。\(D_{KL}(P\|Q)\) 和 \(D_{KL}(Q\|P)\) 是两个完全不同的目标,对应完全不同的优化行为。
4.1 期望底数:站在谁的角度看问题
两种 KL 的差异在于期望是在哪个分布下求的:
| 类型 | 写法 | 期望底数 | 含义 |
|---|---|---|---|
| Forward KL | \(D_{KL}(P \| Q)\) | \(x \sim P\)(教师) | 站在教师的角度,看学生的编码多烂 |
| Reverse KL | \(D_{KL}(Q \| P)\) | \(x \sim Q\)(学生) | 站在学生的角度,看教师的编码多烂 |
这个"期望底数"不是抽象概念,它直接决定了蒙特卡洛估计时你该用谁的样本做平均。
- 如果你要无偏估计 \(D_{KL}(P\|Q)\),你必须有从 \(P\) 采样的数据;
- 如果你要无偏估计 \(D_{KL}(Q\|P)\),你必须有从 \(Q\) 采样的数据。
而 OPD 的数据流天然产生的是 \(y \sim Q\)(学生 rollouts)。因此,Reverse KL 可以直接对学生生成的序列求平均,得到无偏估计;Forward KL 若强行用学生数据估计,则需要引入重要性采样权重 \(P(y)/Q(y)\),在 LLM 的长序列空间上方差极大,工程上几乎不可用。
4.2 Mode-covering vs Mode-seeking:看个实际例子
假设在某个推理步骤,教师模型(大模型)对下一个词的概率分布如下:
| token | 教师 \(P\) | 语义 |
|---|---|---|
| "分析" | 0.40 | 展开逻辑分析 |
| "计算" | 0.35 | 进入数值运算 |
| "推导" | 0.15 | 形式化推导 |
| "求解" | 0.10 | 直接给答案 |
教师是"多面手",四种表达都合理。但学生模型(小模型或 LoRA)容量有限,在这个位置只能强表达一个偏好(单峰)。假设学生初始分布为:
| token | 学生 \(Q\) |
|---|---|
| "分析" | 0.85 |
| "计算" | 0.08 |
| "推导" | 0.04 |
| "求解" | 0.03 |
Forward KL:摊平覆盖(Mode-covering)
Forward KL 在教师分布上求期望。如果学生不给某个教师支持的词分配足够概率,惩罚会急剧上升:
\[ D_{KL}(P \| Q) = 0.40 \log\frac{0.40}{Q(\text{"分析"})} + 0.35 \log\frac{0.35}{Q(\text{"计算"})} + 0.15 \log\frac{0.15}{Q(\text{"推导"})} + 0.10 \log\frac{0.10}{Q(\text{"求解"})} \]
如果学生坚持单峰押注"分析"(\(Q(\text{"分析"})=0.85\),其余接近
0):
- 第一项:\(0.40 \times \log(0.40/0.85)
\approx 0.40 \times (-0.76) =
-0.30\)(负值,但注意整体被其他项主导)
- 第二项:\(0.35 \times \log(0.35/0.08)
\approx 0.35 \times 1.48 = 0.52\)
- 第三项:\(0.15 \times \log(0.15/0.04)
\approx 0.15 \times 1.32 = 0.20\)
- 第四项:\(0.10 \times \log(0.10/0.03)
\approx 0.10 \times 1.20 = 0.12\)
总损失约为 \(0.54\),且梯度会迫使学生把概率摊向教师的所有模式。最终学生可能变成:
| token | 学生 \(Q\)(Forward KL 后) |
|---|---|
| "分析" | 0.42 |
| "计算" | 0.38 |
| "推导" | 0.12 |
| "求解" | 0.08 |
问题出现了:学生原来在"分析"上有 0.85 的置信度,被稀释到了 0.42。推理时,它在"分析"和"计算"之间随机跳跃,前后步骤的术语不一致,导致逻辑连贯性崩坏。这就是 Mode-covering——学生被迫覆盖教师的所有表达模式,结果每个模式都学不精。
Reverse KL:锁定修正(Mode-seeking)
Reverse KL 在学生分布上求期望:
\[ D_{KL}(Q \| P) = 0.85 \log\frac{0.85}{0.40} + 0.08 \log\frac{0.08}{0.35} + 0.04 \log\frac{0.04}{0.15} + 0.03 \log\frac{0.03}{0.10} \]
逐项看:
- 第一项("分析"):\(0.85 \times \log(2.125)
\approx 0.85 \times 0.75 = 0.64\)
- 第二项("计算"):\(0.08 \times \log(0.23)
\approx 0.08 \times (-1.47) = -0.12\)
- 第三、四项:权重极小,贡献接近 0
可以看到,虽然单个项可能是负的,但 KL 散度的总和必然非负(这是数学性质)。关键在于:"推导"和"求解"这两个词因为学生概率极低,几乎不参与损失计算。
Reverse KL 的梯度会把学生推向什么方向?它只会在学生实际高概率的区域("分析")上,将 \(Q\) 向 \(P\) 靠拢(从 0.85 往 0.40 微调),同时不会强迫学生给"计算""推导"分配显著概率。
最终学生可能变成:
| token | 学生 \(Q\)(Reverse KL 后) |
|---|---|
| "分析" | 0.72 |
| "计算" | 0.15 |
| "推导" | 0.08 |
| "求解" | 0.05 |
结果:学生保持单峰特性,稳定地使用"分析"这一表达,内部逻辑一致。它放弃了模仿教师的所有措辞风格,只确保自己走的每一步在教师看来"合理即可"。这就是 Mode-seeking——学生寻找并锁定一个自己擅长、教师也认可的模式。
4.3 Reverse KL 的"硬边界"反而是好事
Reverse KL 有一个严格的数学约束:如果学生给某个词分配了非零概率,而教师给零概率,损失会爆炸到无穷大:
\[ Q(x) > 0, \ P(x) = 0 \ \Rightarrow \ Q(x) \log\frac{Q(x)}{0} \to +\infty \]
这意味着 \(\text{supp}(Q) \subseteq \text{supp}(P)\)——学生的支持集必须是教师支持集的子集。学生"不敢"探索教师认为绝对不可能的区域。
在 OPD
场景中,这反而是安全机制:它防止学生自由发挥到低质量或语法崩坏的区域,确保蒸馏始终发生在教师认可的语言空间内。(工程上通常会对教师概率做
epsilon 平滑,如 clamp(min=1e-6),来避免数值溢出)
5. 工程实现的几个细节
5.1 采样频率:要不要每步都重新生成?
理论上,一旦学生参数更新(\(\theta_t \to \theta_{t+1}\)),之前的 rollouts 就"过期"了。完全严格的 OPD 每步都重新采样:
for step in range(total_steps):
y = student.generate(x) # 来自 π_θ_t
loss = reverse_kl(student(x,y), teacher(c,x,y))
student.update(loss) # 变成 π_θ_{t+1}
# 下一步必须重新 generate
但 LLM 的生成代价远高于训练(通常 10-100 倍)。工程上普遍采用 PPO 风格的缓冲池策略:每隔 \(N\) 步用当前 checkpoint 采样一批数据,重复训练 \(K\) 个 epoch(通常 \(K \leq 4\))。只要 policy lag 不大,效果与严格 on-policy 几乎无差异。
5.2 和现有框架的兼容性
OPD 对基础设施很友好:
- 教师无需反向传播:只需前向传播输出 logits,可以是更大的模型或黑盒
API;
- 与 LoRA 天然适配:Qwen3 技术报告显示,SFT 后 LoRA 与全量微调差距约
13%,经过 OPD 后差距缩小到 6%,且 OPD 阶段可用更小的 batch size(因为
token-level 信号密度高,梯度噪声低);
- 接入 RL 框架很简单:如果你已有 VeRL / verl 等 RL 流水线,只需将
reference model 的 KL penalty 替换为 teacher model 的 log-prob
评估,即完成 OPD 改造。
5.3 温度和采样策略
- 学生生成:通常使用 temperature > 0(如
0.7-1.0)采样,覆盖自身分布中的多种错误模式;
- 教师评估:可用 temperature=0 获取确定性反馈,降低方差。
6. 什么时候用 OPD?
OPD 最适合的场景:
- 长链推理任务(数学、代码):早期错误会导致后期严重偏离,OPD
能纠正中间步骤;
- 持续学习/对齐恢复:当 SFT 会遗忘 post-training 行为时,OPD
可以恢复已学习的特性,而不会像 SFT 那样在自身样本上训练导致退化;
- 跨架构蒸馏:如从 Dense 模型蒸馏到 MoE,OPD 比静态数据更鲁棒。
局限:
- 需要在线生成学生轨迹,训练吞吐低于静态 SFT(但比 RL 好很多,因为不需要
Value Model 或大量 rollout);
-
若学生与教师能力差距过大(教师的高概率区域完全不在学生的支撑集内),需配合
warm-up SFT 缩小初始分布差异。
7. 小结
On-Policy Distillation 的核心思路可以概括为一句话:
学生生成自己的轨迹 → 教师对这条轨迹的每个 token 打分 → 用 Reverse KL 将学生分布拉向教师分布,但只纠正学生实际会犯的错误。
它与传统蒸馏的差异不在于"有没有教师信号",而在于信号以什么密度、在什么分布上传递:
| 方法 | 数据来源 | 反馈密度 | 核心行为 |
|---|---|---|---|
| Off-policy Distillation | 教师生成 | Token-level Dense | 模仿教师范文,但分布不匹配 |
| RL (PPO/GRPO) | 学生当前策略 | Episode-level Sparse | 探索自身分布,但信用分配困难 |
| OPD | 学生当前策略 | Token-level Dense | 在自己分布上被密集纠正 |
从信息论视角看,RL 每个 episode 只传输 \(O(1)\) bits 信号,而 OPD 传输 \(O(N)\) bits(\(N\) 为序列长度)。这种效率优势使得 OPD 在复现同等策略时,所需梯度步数通常比 RL 少一个数量级。
如果你已有 RL 训练框架,实现 OPD 的工程改动很小;如果你正在用 LoRA 做领域适配,在 SFT 后追加一轮 OPD 往往是提升推理稳定性的最高性价比手段。
【推荐文章】
- Agent:
Karpathy所说的LLM
Wiki
Harness
Engineer
O-MEM
字节的M3-Agent
DeepResearch的报告生成方法
从RAG到DeepSearch
阿里通义Lab:
WebWalker,WebDancer和WebSailor
Agent评测数据集
Agent完全手册(零):三大模块,三个理念
agent调研(1)--MetaGPT,OpenManus和OWL
Devin和Anthropic的Agent开发经验
- MoE:
DeepSeek-V3细节探索
MoE模型的前世今生
DeepSeek-V2和MLA
昆仑万维-SkyworkMoE
成本10w刀的JetMoE
MoE的top-p
routing
对MoE模型的一些观察
从dense到MoE -- sparse
upcycling
MoE路由--expert choice
routing
- 端侧模型:
苹果智能系统模型--AFM
MiniCPM
适合移动设备的语言模型--MobileLLM
phi系列模型
Gemma2
苹果的OpenELM
bilibili的index-1.9B
- 预训练:
预训练经验
Qwen3实测&技术报告
代码大模型(一)--业界现状
代码大模型(二)--OpenCoder
LLM高效预训练(一)
LLM高效预训练(二)
Llama3.1--预训练要点一览
Qwen2技术报告
Yi技术报告-划重点看细节
InternLM系列模型
GLM4报告的一些技术点
从Yuan2.0到Yuan2.0-M32
从loss视角理解大模型涌现能力
- 数据:
训练数据合成(一)
训练数据合成(二)
训练数据合成(三)
LLM预训练数据策略(一)
预训练数据处理--长度分解
- 长上下文:
Qwen2.5-1M技术解析
LLM长上下文的问题
解锁大模型长上下文能力
大模型推理窗口-从有限到无限大
prompt压缩(一)
prompt压缩(二)
reasoning压缩(一)
- 推理加速:
大模型推理加速-投机解码
大模型推理加速-MEDUSA
- 对齐:
VeRA,LoRA-XS和TinyLoRA
腾讯的Training-Free
GRPO
深度求索DeepSeek-R1详解
基模型Cognitive
Behaviors对RL的影响
Llama3.1--post-training要点一览
模型平均 -- model
soup
大模型偏好对齐-DPO
大模型偏好对齐-ODPO
大模型偏好对齐-simPO
大模型偏好对齐-IPO
- Transformer:
Attention
Residuals
理解Attention:从起源到MHA,MQA和GQA
LLM的重复生成和ICL
transformer中normalization的二三事
从代码实现看normalization-到底做了什么
稀疏注意力计算:sliding
window attention
理解LLM位置编码:RoPE
RoPE的远距离衰减
LLM水印
- 训练框架
Muon优化器
LLM训练框架:从优化器和精度讲到ZeRO
LLM训练各种并行策略
- 项目应用:
一个模型支持智能助手系统
关于The Bitter
Lesson
- CV:
CV入门--关于Vision
Transformer
CV入门--无监督学习
- 多模态:
多模态入门(一)--CLIP
多模态入门(二)--Flamingo,LLaVA系列和BLIP系列
多模态入门(三)--MiniGPT4,DeepSeekVL,InternVL系列和QwenVL系列
多模态入门(四)--CogVLM,VILA,MM1,MM1.5和Pixtral-12B
多模态入门(五)--InternVL系列
小米的移动UI多模态模型--MobileVLM
DeepSeek-VL2的细节
- 论文阅读:
最近阅读--关于数据合成、agent、reasoning和多任务
最近阅读2-关于自适应深度思考、context
engineering和模型训练
- 大模型算法题:
(1)、 (2)、 (3)、 (4)、 (5)、 (6)、 (7)、 (8)、 (9)