← 返回列表

Improving Diffusion Language Model Decoding through Joint Search in Generation Order and Token Space

作者 Yangyi Shen, Tianjian Feng, Jiaqi Han, Wen Wang, Tianlang Chen, Chunhua Shen, Jure Leskovec, Stefano Ermon
年份 2025
会议/期刊 arXiv 2025
评分
标签 扩散语言模型 解码策略 测试时搜索
摘要 扩散语言模型的联合搜索解码(OTS):同时搜索生成顺序和 token 选择 + 增量似然估计器 + 块级扩散降低复杂度,GSM8K/MATH/HumanEval 提升 3-8%

核心思想

掩码扩散模型(MDM) 如 LLaDA 生成文本时,不是从左到右逐个 token 生成,而是从全掩码序列逐步”揭示” token。这引入了一个自回归模型没有的自由度:生成顺序——先揭示哪些位置的 token?

OTS(Order-Token Search) 发现:

  1. 保守策略(低置信度重掩码)→ 单样本准确但探索不足
  2. 随机策略 → 探索充分但选择不精准
  3. OTS 同时搜索顺序和 token → 兼顾探索与利用

结果:在 GSM8K、MATH500、Countdown、HumanEval 上提升 3.1-7.9%

背景知识

自回归 vs 扩散语言模型

特性 自回归(GPT 等) 扩散(LLaDA 等)  
生成方式 从左到右逐个 token 并行去噪,灵活顺序  
生成顺序 固定(左→右) 可变(任意顺序)  
搜索方法 beam search 成熟 缺乏有效搜索  
似然计算 $p(x_i x_{<i})$ 直接可得 联合似然难计算

探索-利用权衡

MDM 的重掩码策略决定了每一步揭示哪些位置:

策略 做法 优势 劣势
低置信度重掩码 保留最确定的 token pass@1 高(~79%) pass@256 低(~0.4),探索窄
随机重掩码 随机选择保留 pass@256 高(~0.8),探索广 pass@1 低,选择无方向
OTS 结构化搜索 兼顾 pass@1 和 pass@k 计算成本 2-3×

为什么自回归搜索方法不适用于 MDM

自回归模型中,似然可以分解为条件概率的乘积:

\[\log p(x) = \sum_i \log p(x_i | x_{<i})\]

但 MDM 的去噪是双向的——每个 token 的预测依赖所有已揭示的 token,不仅是”之前”的。因此自回归式的似然分解在 MDM 上失效

方法详解

1. 将解码视为轨迹搜索

MDM 从全掩码序列出发,经过 $S$ 步去噪到达完整序列。每条”轨迹”对应一种特定的生成顺序 + token 选择组合。OTS 在这个轨迹空间中做结构化搜索。

2. 增量似然估计器

核心公式(Equation 2)

\[s(x_t; x_s) = \mathbb{E}_{x_0 \sim p_\theta(x_0 | x_t)} \log p(x_0 | b(x_s, x_t, x_0))\]
  • $x_s, x_t$:两个时间步的部分掩码序列($s > t$,$x_t$ 比 $x_s$ 揭示了更多 token)
  • 仅评分新揭示的 token 块($x_s$ 和 $x_t$ 之间新揭示的位置)
  • 以模型对完整序列的预测 $x_0$ 为条件

关键设计:增量评分与 MDM 的训练目标对齐(预测完整序列),而非强行适配自回归式评分。

3. 搜索过程

维护 $K$ 条 beam 候选:

  1. 扩展:每条 beam 生成 $K$ 条独立的去噪轨迹 → 共 $K^2$ 候选
  2. 评分:用增量似然估计器评分每个候选
  3. 剪枝:保留 Top-$K$ 候选
  4. 重复:每 $N$ 步在块边界做一次搜索

块级扩散降低复杂度:从 $O(S \cdot K^2 \cdot L)$ 降到 $O(S \cdot K \cdot L + B \cdot K^2 \cdot L)$,其中 $B$ 是块数。

4. 算法伪代码

输入:模型 p_θ,beam 大小 K,块大小 N
初始化 K 条全掩码序列
for 每个时间步 t = T → 0:
    对每条 beam 做标准去噪(加 Gumbel 噪声增加随机性)
    if t 是块边界:
        扩展:每条 beam × K 条轨迹
        评分:增量似然估计
        剪枝:保留 Top-K
返回最优序列

实验结果

主要性能(LLaDA-8B-Instruct)

任务 Low-conf+MV OTS 提升
GSM8K 70.7% 70.1% -0.6%
MATH500 29.7% 32.8% +3.1%
Countdown 20.7% 28.4% +7.7%
HumanEval 20.9% 27.4% +6.5%
平均 37.1% 39.7% +2.6%

LLaDA-1.5 结果

任务 最佳基线 OTS 提升
MATH500 31.3% 33.8% +2.5%
Countdown 22.5% 28.0% +5.5%
HumanEval 22.0% 27.6% +5.6%
平均 37.2% 40.4% +3.2%

与自回归式搜索对比(序列长度 256)

搜索方法 GSM8K MATH500 Countdown HumanEval
Order Search(AR 式) 79.2% 35.8% 15.2% 32.9%
Token Search(AR 式) 0%
OTS 79.8% 36.0% 26.2% 34.2%

Token Search 在 Countdown 上完全崩溃(0%)→ 证明自回归似然在 MDM 上根本不适用。

似然估计器消融

评分方式 平均准确率
OTS(增量评分) 39.7%
OTS All Blocks(评分完整前缀) 32.6%
OTS Future Blocks(无预测上下文) 33.9%

增量评分显著优于替代方案。

计算效率

方法 Countdown 准确率 相对时间
OTS (beam=6) 29.3% ~2-3× 单次推理
AR+MV (5 samples) 19.9% ~5× 单次推理
Random+MV (5 samples) 18.4% ~5× 单次推理

OTS 比多数投票更快(2-3× vs 5×)且更准

负面结果:生成顺序无最优解

对 256 个样本做大规模相关性分析:生成顺序与”从左到右”的相似度和解的正确性之间几乎没有相关性(相关系数 ~10⁻³)→ MDM 发现了多条等效的解题路径。

个人思考

  1. 生成顺序的灵活性是 MDM 的独特优势:自回归模型只能从左到右,但 MDM 可以”先写关键步骤、再填充细节”——OTS 正是在利用这种灵活性。
  2. 增量似然估计器是技术核心:直接套用自回归评分会在 Countdown 上崩溃到 0% → 说明 MDM 需要量身定制的评分机制。
  3. “无最优生成顺序” 是有趣的发现:MDM 不需要模仿人类的从左到右思考方式——它找到了人类不会采用的解题路径,但同样有效。
  4. 与后训练方法互补:OTS 是纯推理时方法(不改权重),而 diffu-GRPO 是训练方法 → 两者可以叠加使用。
  5. Countdown 上 +7.7% 但 GSM8K 上 -0.6%:说明搜索在需要全局一致性的规划任务上收益最大,在已经表现良好的任务上边际收益递减。
← 返回列表