Lightning OPD: Efficient Post-Training for Large Reasoning Models with Offline On-Policy Distillation

前言

在大型推理模型的后训练中,On-Policy Distillation (OPD) 已被证明是一种高效的训练范式,它通过密集的 per-token advantage 信号让学生模型匹配 teacher 的 token 级分布。然而,标准 OPD 需要在整个训练过程中维持一个实时的 Teacher 推理服务器,带来了巨大的基础设施开销。

Lightning OPD 的核心发现是:“Teacher Consistency”——即 SFT 阶段和 OPD 阶段必须使用同一个 Teacher 模型——是 OPD 成功的必要条件。违反这一条件会引入不可约减的梯度偏差,导致无论训练多久都无法收敛到最优解。基于这一洞察,Lightning OPD 在 SFT rollouts 上一次性预计算 Teacher 的 log-probabilities,完全消除了训练期间对实时 Teacher 的依赖。在 Qwen3-8B-Base 上,仅需 30 GPU 小时就在 AIME 2024 上达到 69.9%,相比标准 OPD 实现了 4.0× 加速。

背景

标准 OPD 的训练流程是:先对 base model 做 SFT,然后用一个更强的 Teacher 模型对 Student 的 on-policy rollout 进行 token 级蒸馏。问题在于,这个 Teacher 必须在每一步梯度更新时对 Student 当前采样的 rollout 进行推理,意味着需要一个持续运行的多 GPU Teacher 推理服务,计算开销巨大,学术研究者难以复现。

一个自然的想法是:能否把 Teacher 的 log-probability 预先计算好存起来,训练时直接查表?但实际尝试发现这种 naive 的离线方法无法稳定地匹配标准 OPD 的性能。

作者深入调查了失败原因,发现问题并不主要来自离线近似本身,而是一个被之前所有 OPD 工作忽视的条件:Teacher Consistency。具体来说:

  1. 现有管线普遍违反 Teacher Consistency:SFT 阶段用一个 Teacher(如 QwQ-32B)生成训练轨迹,OPD 阶段用另一个 Teacher(如 Qwen3-32B)提供参考分布
  2. 这种不一致引入不可约减的梯度偏差:导致无论在线还是离线 OPD 都收敛到次优不动点
  3. 离线设置下问题更严重:固定的 rollout 分布还会反映错误 Teacher 的偏好

框架

预备知识

给定 Teacher 模型 πT\pi_T 和可训练的 Student 模型 πθ\pi_\theta,对于 prompt qq,response x=(a1,,aT)x = (a_1, \ldots, a_T) 以自回归方式生成:

πθ(xq)=t=1Tπθ(atst)\pi_\theta(x | q) = \prod_{t=1}^{T} \pi_\theta(a_t | s_t)

其中 st=(q,a1,,at1)s_t = (q, a_1, \ldots, a_{t-1})。per-token OPD advantage 定义为:

At(θ)=logπT(atst)logπθ(atst)A_t(\theta) = \log \pi_T(a_t | s_t) - \log \pi_\theta(a_t | s_t)

当 Teacher 比 Student 更有信心时 advantage 为正,反之为负。标准 OPD 的目标函数为:

Jon(θ)=Eqp,xπθ[t=1TAt(θ)]J_{\text{on}}(\theta) = \mathbb{E}_{q \sim p, x \sim \pi_\theta} \left[ \sum_{t=1}^{T} A_t(\theta) \right]

Lightning OPD 两阶段流程

Stage 1: SFT

给定 base model πbase\pi_{\text{base}}、Teacher πT\pi_T 和 prompt 数据集 QSFT\mathcal{Q}_{\text{SFT}},先用 Teacher 生成轨迹构成 SFT 数据集:

DSFT={(qi,xi)xiπT(qi),qiQSFT}\mathcal{D}_{\text{SFT}} = \left\{ (q_i, x_i) \mid x_i \sim \pi_T(\cdot | q_i), q_i \in \mathcal{Q}_{\text{SFT}} \right\}

然后通过最大似然估计得到 reference policy πref\pi_{\text{ref}}关键约束:SFT 数据必须由 OPD 阶段同一个 Teacher 生成。

Stage 2: 离线 OPD

分为预处理和训练两个阶段。预处理阶段,从 πref\pi_{\text{ref}} 采样 rollout,一次性查询 Teacher 得到 per-token log-probabilities,构成离线数据集:

DOPD={(qj,xj,{logπT(atjstj)}t=1Tj)xjπref(qj),qjQOPD}\mathcal{D}_{\text{OPD}} = \left\{ \left( q_j, x_j, \left\{ \log \pi_T(a_t^j | s_t^j) \right\}_{t=1}^{T_j} \right) \mid x_j \sim \pi_{\text{ref}}(\cdot | q_j), q_j \in \mathcal{Q}_{\text{OPD}} \right\}

训练阶段,Student 从 πref\pi_{\text{ref}} 初始化,在 DOPD\mathcal{D}_{\text{OPD}} 上训练,无需 Teacher 服务器。每步的 advantage 计算中,Teacher 项直接从数据集读取,Student 项在线计算。Lightning OPD 的目标函数为:

Joff(θ)=Eqp,xπref[t=1TAt(θ)]J_{\text{off}}(\theta) = \mathbb{E}_{q \sim p, x \sim \pi_{\text{ref}}} \left[ \sum_{t=1}^{T} A_t(\theta) \right]

与标准 OPD 的唯一区别在于 rollout 分布:标准 OPD 从当前 πθ\pi_\theta 采样,Lightning OPD 固定使用 πref\pi_{\text{ref}}

理论分析

作者给出了严谨的理论保证,基于三个标准假设:

  1. Assumption 3.1:Advantage 的二阶矩有界,Exπref[(tAt(θ))2]σA2\mathbb{E}_{x \sim \pi_{\text{ref}}} [(\sum_t A_t(\theta))^2] \leq \sigma_A^2
  2. Assumption 3.2:支持覆盖,supp(πθ)supp(πref)\text{supp}(\pi_\theta) \subseteq \text{supp}(\pi_{\text{ref}})
  3. Assumption 3.3:Score function 有界,logπθ(atst)2G\|\nabla \log \pi_\theta(a_t|s_t)\|_2 \leq G

梯度差距有界(Theorem 3.5)

Jon(θ)Joff(θ)2GσAχ2(πθπref)\|\nabla J_{\text{on}}(\theta) - \nabla J_{\text{off}}(\theta)\|_2 \leq G \sigma_A \sqrt{\chi^2(\pi_\theta \| \pi_{\text{ref}})}

在初始化时 πθ=πref\pi_\theta = \pi_{\text{ref}}χ2=0\chi^2 = 0,两者梯度完全一致。训练过程中差距以 O(ηk)O(\eta^k) 增长。

共享最优解(Theorem 3.7):在 Teacher Consistency 下,JonJ_{\text{on}}JoffJ_{\text{off}} 共享相同的稳定点集,每个稳定点都是 KL(πθπT)\text{KL}(\pi_\theta \| \pi_T) 的局部最小值。

隐式正则化(Theorem 3.9):离线目标引入了一个协方差项,当策略偏离 πref\pi_{\text{ref}} 时自动产生恢复力,起到隐式正则化效果,防止策略漂移。

Teacher 不一致的危害(Theorem 3.11 & 3.13):当 SFT Teacher 和 OPD Teacher 不同时(πTSFTπTOPD\pi_T^{\text{SFT}} \neq \pi_T^{\text{OPD}}),引入不可约减偏差 GσΔG\sigma_\Delta,无论标准 OPD 还是离线 OPD 都会收敛到偏移的不动点。

实验

主要结果

在 Qwen3-4B 和 Qwen3-8B 两个规模上,对比 SFT baseline、标准 OPD 和 Lightning OPD:

Student 模型 方法 AIME 2024 (%) LiveCodeBench v6 (%) GPU Hours 加速比
Qwen3-4B-Base SFT 56.7 31.5 72 -
Qwen3-4B-Base OPD 65.4 39.3 20 -
Qwen3-4B-Base Lightning OPD 68.1 40.3 - 3.6×
Qwen3-8B-Base SFT 63.7 36.8 120 -
Qwen3-8B-Base OPD 68.5 41.2 30 -
Qwen3-8B-Base Lightning OPD 69.9 43.9 - 4.0×

从表中可以总结:

  1. Lightning OPD 在所有 benchmark 和模型规模上都持平或超过标准 OPD
  2. 8B 规模上,Lightning OPD 以仅 30 GPU 小时达到 69.9% AIME 2024,比标准 OPD 节省 75% 计算量
  3. 在 LiveCodeBench v6 代码生成任务上,Lightning OPD 的优势同样明显(43.9% vs 41.2%)

训练效率对比

模型规模 标准 OPD GPU Hours Lightning OPD GPU Hours 加速比 Teacher 服务器需求
4B (Teacher: 8B) 72 20 3.6× 不需要
8B (Teacher: 32B) 120 30 4.0× 不需要

关键发现:

  1. 消除实时 Teacher 推理服务是效率提升的主要来源——标准 OPD 需要并行运行 Teacher 服务器,而 Lightning OPD 只需一次预处理
  2. 随着模型规模增大(Teacher 从 8B 到 32B),加速比也相应提升,说明 Lightning OPD 在大规模设置下优势更明显
  3. 30 GPU 小时的训练成本大幅降低了学术研究进入 LLM 后训练领域的门槛

超参数配置

SFT 阶段配置:

超参数 4B Scale 8B Scale
Training steps 3000 3000
Global batch size 256 128
Max sequence length 16384 16384
Learning rate 8×10⁻⁵ 8×10⁻⁵
LR schedule cosine cosine
Warmup ratio 0.1 0.1
Packing
DeepSpeed stage ZeRO-0 ZeRO-1

OPD 阶段配置:

超参数 4B Scale 8B Scale
Training steps 150 150
Global batch size 256 256
Max response length 4096 4096
Learning rate 2×10⁻⁶ 2×10⁻⁶
LR schedule constant constant
Weight decay 0.1 0.1
Rollout temperature 0.8 0.8
Rollout top-p 1.0 1.0
Tensor parallel size 2 4

值得注意:

  1. OPD 阶段仅需 150 步 即可收敛,远少于 SFT 的 3000 步
  2. 代码生成任务的 OPD 从数学训练的 checkpoint 初始化(而非直接从 SFT),这比直接从 SFT 初始化效果更好,印证了数学推理训练对代码训练的正向迁移
  3. OPD 使用 constant learning rate 和很小的学习率(2×10⁻⁶),表明这一阶段主要做精细的策略对齐

总结

Lightning OPD 的核心贡献不仅是工程上的加速,更在于理论上揭示了 Teacher Consistency 是 OPD 管线的必要条件——SFT 阶段和 OPD 阶段必须使用同一个 Teacher。这一发现对整个 LLM 后训练领域都有指导意义,因为此前的工作(如 Thinking Machines Lab)在实践中经常忽略这一条件。理论分析证明了在 Teacher Consistency 下,离线 OPD 与标准 OPD 共享最优解,且离线目标的梯度差距有界、自带隐式正则化效果。实验上,30 GPU 小时达到 AIME 2024 69.9% 的结果,使得高质量的推理模型后训练不再是大厂专属的资源密集型任务。