From Next-Token to Next-Block: A Principled Adaptation Path for Diffusion LLMs
核心思想
如何把已有的强大 AR 模型(如 Qwen)改造成扩散语言模型(DLM)?
现有方法要么从头训练(浪费 AR 预训练投入),要么简单替换注意力掩码(效果差)。本文发现:块扩散(Block Diffusion)是 AR→DLM 适配的天然终点——AR 就是块大小为 1 的块扩散。
三个关键设计:
- 上下文因果注意力:已提交的上下文保持因果,只有当前块双向 → 平滑过渡
- 辅助 AR 损失:$\lambda = 0.5$ 的 next-token 预测损失 → 保持 AR 能力
- 渐进式块增长:块大小从 1 逐步翻倍到 32 → 避免突变
结果:NBDiff-7B 在 7B 级 DLM 中 SOTA,宏观平均 79.9%,仅需 ~700B 额外 token。
背景知识
AR → DLM 适配的困难
| 挑战 | 说明 |
|---|---|
| 注意力模式冲突 | AR 是严格因果,DLM 需要双向 → 直接切换会崩溃 |
| 训练不稳定 | 全序列扩散的掩码空间组合爆炸 → 损失振荡 |
| 能力遗忘 | 适配过程可能丧失 AR 预训练学到的推理能力 |
为什么块扩散优于全序列扩散
| 维度 | 全序列扩散 | 块扩散 |
|---|---|---|
| 训练稳定性 | 掩码空间组合爆炸 → 振荡 | 块内局部 → 稳定 |
| 推理效率 | 全序列多步去噪 | KV-cache 逐块复用 |
| 与 AR 的关系 | 完全不同的范式 | 块大小=1 就是 AR |
方法详解
1. 上下文因果注意力
区分已提交上下文和当前生成块的注意力机制:
\[M_{\text{all}} = \begin{bmatrix} M_{\text{BD}} & M_{\text{OBC}} \\ 0 & M_{\text{CC}} \end{bmatrix}\]- $M_{\text{BD}}$:块对角自注意力(当前块内双向)
- $M_{\text{OBC}}$:去噪条件(当前块看到之前所有已提交块)
- $M_{\text{CC}}$:因果自注意力(已提交上下文保持 AR 因果性)
关键对比:
| 注意力类型 | 平均性能 |
|---|---|
| Block-Causal(上下文也双向) | 31.4% |
| Context-Causal(上下文因果) | 48.6% |
差距 17.2pp → 保持上下文因果性至关重要。
2. 辅助 AR 损失
\[\mathcal{L}_{\text{total}}(\theta) = \mathcal{L}_{\text{MDM}}(\theta) + \lambda \mathcal{L}_{\text{AR}}(\theta)\]- $\mathcal{L}_{\text{MDM}}$:掩码扩散损失(块内去噪)
- $\mathcal{L}_{\text{AR}}$:next-token 预测损失(应用于干净上下文分支)
- $\lambda = 0.5$
无需额外前向传播——AR 损失直接在同一次前向中计算。
3. 渐进式块增长课程
\[b(s) = \min\{b_{\max}, b_0 \cdot r^{\lfloor(s - s_0)/\Delta\rfloor}\}\]- $b_0 = 1$(从 AR 开始)
- $r = 2$(每次翻倍)
- 目标 $b_{\max} = 32$
- 间隔 $\Delta$ 固定
路径:$1 \to 2 \to 4 \to 8 \to 16 \to 32$
4. NBDiff-7B 训练管线
| 阶段 | 迭代数 | 序列长度 | token 量 |
|---|---|---|---|
| Stage 1:适配 | 84,000 | 8K | ~700B |
| Stage 2:长上下文 | 23,800 | 32K | ~100B |
| Stage 3:指令微调 | 17,000 | — | 10B |
实验结果
NBDiff-7B-Instruct vs 同级 DLM
| 基准 | LLaDA-MoE | Dream-v0 | SDAR | NBDiff-7B |
|---|---|---|---|---|
| MMLU | 67.2 | 67.0 | 78.6 | 82.9 |
| MMLU-Pro | 44.6 | 43.3 | 56.9 | 71.9 |
| GSM8K | 82.4 | 81.0 | 91.3 | 91.0 |
| MATH | 58.7 | 39.2 | 78.6 | 84.0 |
| MBPP | 70.0 | 58.8 | 72.0 | 87.6 |
| HumanEval | 61.6 | 55.5 | 78.7 | 89.0 |
| 宏观平均 | 61.1 | 58.2 | 74.0 | 79.9 |
NBDiff-7B 在 6/6 个基准上取得最佳或接近最佳,宏观平均 79.9% 大幅领先。
NBDiff-7B-Base vs 基线
| 基准 | LLaDA-8B | Dream-v0-Base | NBDiff-7B-Base |
|---|---|---|---|
| MMLU | 65.9 | 69.5 | 70.1 |
| CMMLU | 69.9 | 60.9 | 77.3 |
| BBH | 49.8 | 57.9 | 77.3 |
| MATH | 27.3 | 39.6 | 46.0 |
| 宏观平均 | 52.0 | 60.1 | 65.3 |
适配方法对比
| 模型 | 方法 | GSM8K | MATH | HumanEval | MBPP | 平均 |
|---|---|---|---|---|---|---|
| Qwen3-4B | 注意力退火 | 76.6 | 28.7 | 25.6 | 59.4 | 47.6 |
| Qwen3-4B | 本文 | 79.8 | 32.3 | 27.4 | 61.6 | 50.3 |
| Qwen3-8B | 注意力退火 | 80.7 | 25.8 | 39.0 | 49.8 | 48.8 |
| Qwen3-8B | 本文 | 82.3 | 34.7 | 31.7 | 64.4 | 53.3 |
本文方法一致优于退火式适配 2-5pp。
组件消融(openPangu-7B)
| 配置 | 平均 | MATH 提升 |
|---|---|---|
| 基线 | 48.95 | — |
| + AR 损失 | 52.97 | +7.16 |
| + AR 损失 + 渐进增长 | 54.94 | +7.92 |
数学推理受益最大(+7.92pp)。
个人思考
- “AR 是块大小为 1 的块扩散” 这个观察打通了 AR 和 DLM 的连接 → 适配不再是范式转换,而是同一谱上的平滑移动。
- 上下文因果 vs 块因果 17.2pp 的差距说明:AR 模型学到的因果推理模式是宝贵资产,适配过程应该保留而非破坏。
- 700B token 的适配成本相比从头训练(通常 >10T token)节省了一个数量级 → 这使得将任意强 AR 模型转为 DLM 变得实际可行。
- HumanEval 89.0% 是所有 7B 级 DLM 中的最高分 → 代码生成从块扩散的双向注意力中受益最大(代码的局部结构需要前后一致)。
- 渐进式块增长看似简单但不可或缺:直接跳到大块会导致不稳定,逐步翻倍让模型有时间适应每个新的并行度级别。