← 返回列表

From Next-Token to Next-Block: A Principled Adaptation Path for Diffusion LLMs

作者 Yuchuan Tian, Yuchen Liang, Shuo Zhang, Yingte Shu, Guangwen Yang, Wei He, Sibo Fang, Tianyu Guo, Kai Han, Chao Xu, Hanting Chen, Xinghao Chen, Yunhe Wang
年份 2025
会议/期刊 arXiv 2025
评分
标签 扩散语言模型 模型适配 块扩散
摘要 AR→扩散 LLM 的原则性适配路径:上下文因果注意力 + 辅助 AR 损失 + 渐进式块增长课程,NBDiff-7B 在 7B 级 DLM 中 SOTA(宏观平均 79.9%),仅需 ~700B 额外 token

核心思想

如何把已有的强大 AR 模型(如 Qwen)改造成扩散语言模型(DLM)?

现有方法要么从头训练(浪费 AR 预训练投入),要么简单替换注意力掩码(效果差)。本文发现:块扩散(Block Diffusion)是 AR→DLM 适配的天然终点——AR 就是块大小为 1 的块扩散。

三个关键设计:

  1. 上下文因果注意力:已提交的上下文保持因果,只有当前块双向 → 平滑过渡
  2. 辅助 AR 损失:$\lambda = 0.5$ 的 next-token 预测损失 → 保持 AR 能力
  3. 渐进式块增长:块大小从 1 逐步翻倍到 32 → 避免突变

结果:NBDiff-7B 在 7B 级 DLM 中 SOTA,宏观平均 79.9%,仅需 ~700B 额外 token。

背景知识

AR → DLM 适配的困难

挑战 说明
注意力模式冲突 AR 是严格因果,DLM 需要双向 → 直接切换会崩溃
训练不稳定 全序列扩散的掩码空间组合爆炸 → 损失振荡
能力遗忘 适配过程可能丧失 AR 预训练学到的推理能力

为什么块扩散优于全序列扩散

维度 全序列扩散 块扩散
训练稳定性 掩码空间组合爆炸 → 振荡 块内局部 → 稳定
推理效率 全序列多步去噪 KV-cache 逐块复用
与 AR 的关系 完全不同的范式 块大小=1 就是 AR

方法详解

1. 上下文因果注意力

区分已提交上下文和当前生成块的注意力机制:

\[M_{\text{all}} = \begin{bmatrix} M_{\text{BD}} & M_{\text{OBC}} \\ 0 & M_{\text{CC}} \end{bmatrix}\]
  • $M_{\text{BD}}$:块对角自注意力(当前块内双向)
  • $M_{\text{OBC}}$:去噪条件(当前块看到之前所有已提交块)
  • $M_{\text{CC}}$:因果自注意力(已提交上下文保持 AR 因果性)

关键对比

注意力类型 平均性能
Block-Causal(上下文也双向) 31.4%
Context-Causal(上下文因果) 48.6%

差距 17.2pp → 保持上下文因果性至关重要

2. 辅助 AR 损失

\[\mathcal{L}_{\text{total}}(\theta) = \mathcal{L}_{\text{MDM}}(\theta) + \lambda \mathcal{L}_{\text{AR}}(\theta)\]
  • $\mathcal{L}_{\text{MDM}}$:掩码扩散损失(块内去噪)
  • $\mathcal{L}_{\text{AR}}$:next-token 预测损失(应用于干净上下文分支)
  • $\lambda = 0.5$

无需额外前向传播——AR 损失直接在同一次前向中计算。

3. 渐进式块增长课程

\[b(s) = \min\{b_{\max}, b_0 \cdot r^{\lfloor(s - s_0)/\Delta\rfloor}\}\]
  • $b_0 = 1$(从 AR 开始)
  • $r = 2$(每次翻倍)
  • 目标 $b_{\max} = 32$
  • 间隔 $\Delta$ 固定

路径:$1 \to 2 \to 4 \to 8 \to 16 \to 32$

4. NBDiff-7B 训练管线

阶段 迭代数 序列长度 token 量
Stage 1:适配 84,000 8K ~700B
Stage 2:长上下文 23,800 32K ~100B
Stage 3:指令微调 17,000 10B

实验结果

NBDiff-7B-Instruct vs 同级 DLM

基准 LLaDA-MoE Dream-v0 SDAR NBDiff-7B
MMLU 67.2 67.0 78.6 82.9
MMLU-Pro 44.6 43.3 56.9 71.9
GSM8K 82.4 81.0 91.3 91.0
MATH 58.7 39.2 78.6 84.0
MBPP 70.0 58.8 72.0 87.6
HumanEval 61.6 55.5 78.7 89.0
宏观平均 61.1 58.2 74.0 79.9

NBDiff-7B 在 6/6 个基准上取得最佳或接近最佳,宏观平均 79.9% 大幅领先。

NBDiff-7B-Base vs 基线

基准 LLaDA-8B Dream-v0-Base NBDiff-7B-Base
MMLU 65.9 69.5 70.1
CMMLU 69.9 60.9 77.3
BBH 49.8 57.9 77.3
MATH 27.3 39.6 46.0
宏观平均 52.0 60.1 65.3

适配方法对比

模型 方法 GSM8K MATH HumanEval MBPP 平均
Qwen3-4B 注意力退火 76.6 28.7 25.6 59.4 47.6
Qwen3-4B 本文 79.8 32.3 27.4 61.6 50.3
Qwen3-8B 注意力退火 80.7 25.8 39.0 49.8 48.8
Qwen3-8B 本文 82.3 34.7 31.7 64.4 53.3

本文方法一致优于退火式适配 2-5pp

组件消融(openPangu-7B)

配置 平均 MATH 提升
基线 48.95
+ AR 损失 52.97 +7.16
+ AR 损失 + 渐进增长 54.94 +7.92

数学推理受益最大(+7.92pp)。

个人思考

  1. “AR 是块大小为 1 的块扩散” 这个观察打通了 AR 和 DLM 的连接 → 适配不再是范式转换,而是同一谱上的平滑移动
  2. 上下文因果 vs 块因果 17.2pp 的差距说明:AR 模型学到的因果推理模式是宝贵资产,适配过程应该保留而非破坏
  3. 700B token 的适配成本相比从头训练(通常 >10T token)节省了一个数量级 → 这使得将任意强 AR 模型转为 DLM 变得实际可行。
  4. HumanEval 89.0% 是所有 7B 级 DLM 中的最高分 → 代码生成从块扩散的双向注意力中受益最大(代码的局部结构需要前后一致)。
  5. 渐进式块增长看似简单但不可或缺:直接跳到大块会导致不稳定,逐步翻倍让模型有时间适应每个新的并行度级别
← 返回列表