Improving Diffusion Language Model Decoding through Joint Search in Generation Order and Token Space
核心思想
掩码扩散模型(MDM) 如 LLaDA 生成文本时,不是从左到右逐个 token 生成,而是从全掩码序列逐步”揭示” token。这引入了一个自回归模型没有的自由度:生成顺序——先揭示哪些位置的 token?
OTS(Order-Token Search) 发现:
- 保守策略(低置信度重掩码)→ 单样本准确但探索不足
- 随机策略 → 探索充分但选择不精准
- OTS 同时搜索顺序和 token → 兼顾探索与利用
结果:在 GSM8K、MATH500、Countdown、HumanEval 上提升 3.1-7.9%。
背景知识
自回归 vs 扩散语言模型
| 特性 | 自回归(GPT 等) | 扩散(LLaDA 等) | |
|---|---|---|---|
| 生成方式 | 从左到右逐个 token | 并行去噪,灵活顺序 | |
| 生成顺序 | 固定(左→右) | 可变(任意顺序) | |
| 搜索方法 | beam search 成熟 | 缺乏有效搜索 | |
| 似然计算 | $p(x_i | x_{<i})$ 直接可得 | 联合似然难计算 |
探索-利用权衡
MDM 的重掩码策略决定了每一步揭示哪些位置:
| 策略 | 做法 | 优势 | 劣势 |
|---|---|---|---|
| 低置信度重掩码 | 保留最确定的 token | pass@1 高(~79%) | pass@256 低(~0.4),探索窄 |
| 随机重掩码 | 随机选择保留 | pass@256 高(~0.8),探索广 | pass@1 低,选择无方向 |
| OTS | 结构化搜索 | 兼顾 pass@1 和 pass@k | 计算成本 2-3× |
为什么自回归搜索方法不适用于 MDM
自回归模型中,似然可以分解为条件概率的乘积:
\[\log p(x) = \sum_i \log p(x_i | x_{<i})\]但 MDM 的去噪是双向的——每个 token 的预测依赖所有已揭示的 token,不仅是”之前”的。因此自回归式的似然分解在 MDM 上失效。
方法详解
1. 将解码视为轨迹搜索
MDM 从全掩码序列出发,经过 $S$ 步去噪到达完整序列。每条”轨迹”对应一种特定的生成顺序 + token 选择组合。OTS 在这个轨迹空间中做结构化搜索。
2. 增量似然估计器
核心公式(Equation 2):
\[s(x_t; x_s) = \mathbb{E}_{x_0 \sim p_\theta(x_0 | x_t)} \log p(x_0 | b(x_s, x_t, x_0))\]- $x_s, x_t$:两个时间步的部分掩码序列($s > t$,$x_t$ 比 $x_s$ 揭示了更多 token)
- 仅评分新揭示的 token 块($x_s$ 和 $x_t$ 之间新揭示的位置)
- 以模型对完整序列的预测 $x_0$ 为条件
关键设计:增量评分与 MDM 的训练目标对齐(预测完整序列),而非强行适配自回归式评分。
3. 搜索过程
维护 $K$ 条 beam 候选:
- 扩展:每条 beam 生成 $K$ 条独立的去噪轨迹 → 共 $K^2$ 候选
- 评分:用增量似然估计器评分每个候选
- 剪枝:保留 Top-$K$ 候选
- 重复:每 $N$ 步在块边界做一次搜索
块级扩散降低复杂度:从 $O(S \cdot K^2 \cdot L)$ 降到 $O(S \cdot K \cdot L + B \cdot K^2 \cdot L)$,其中 $B$ 是块数。
4. 算法伪代码
输入:模型 p_θ,beam 大小 K,块大小 N
初始化 K 条全掩码序列
for 每个时间步 t = T → 0:
对每条 beam 做标准去噪(加 Gumbel 噪声增加随机性)
if t 是块边界:
扩展:每条 beam × K 条轨迹
评分:增量似然估计
剪枝:保留 Top-K
返回最优序列
实验结果
主要性能(LLaDA-8B-Instruct)
| 任务 | Low-conf+MV | OTS | 提升 |
|---|---|---|---|
| GSM8K | 70.7% | 70.1% | -0.6% |
| MATH500 | 29.7% | 32.8% | +3.1% |
| Countdown | 20.7% | 28.4% | +7.7% |
| HumanEval | 20.9% | 27.4% | +6.5% |
| 平均 | 37.1% | 39.7% | +2.6% |
LLaDA-1.5 结果
| 任务 | 最佳基线 | OTS | 提升 |
|---|---|---|---|
| MATH500 | 31.3% | 33.8% | +2.5% |
| Countdown | 22.5% | 28.0% | +5.5% |
| HumanEval | 22.0% | 27.6% | +5.6% |
| 平均 | 37.2% | 40.4% | +3.2% |
与自回归式搜索对比(序列长度 256)
| 搜索方法 | GSM8K | MATH500 | Countdown | HumanEval |
|---|---|---|---|---|
| Order Search(AR 式) | 79.2% | 35.8% | 15.2% | 32.9% |
| Token Search(AR 式) | — | — | 0% | — |
| OTS | 79.8% | 36.0% | 26.2% | 34.2% |
Token Search 在 Countdown 上完全崩溃(0%)→ 证明自回归似然在 MDM 上根本不适用。
似然估计器消融
| 评分方式 | 平均准确率 |
|---|---|
| OTS(增量评分) | 39.7% |
| OTS All Blocks(评分完整前缀) | 32.6% |
| OTS Future Blocks(无预测上下文) | 33.9% |
增量评分显著优于替代方案。
计算效率
| 方法 | Countdown 准确率 | 相对时间 |
|---|---|---|
| OTS (beam=6) | 29.3% | ~2-3× 单次推理 |
| AR+MV (5 samples) | 19.9% | ~5× 单次推理 |
| Random+MV (5 samples) | 18.4% | ~5× 单次推理 |
OTS 比多数投票更快(2-3× vs 5×)且更准。
负面结果:生成顺序无最优解
对 256 个样本做大规模相关性分析:生成顺序与”从左到右”的相似度和解的正确性之间几乎没有相关性(相关系数 ~10⁻³)→ MDM 发现了多条等效的解题路径。
个人思考
- 生成顺序的灵活性是 MDM 的独特优势:自回归模型只能从左到右,但 MDM 可以”先写关键步骤、再填充细节”——OTS 正是在利用这种灵活性。
- 增量似然估计器是技术核心:直接套用自回归评分会在 Countdown 上崩溃到 0% → 说明 MDM 需要量身定制的评分机制。
- “无最优生成顺序” 是有趣的发现:MDM 不需要模仿人类的从左到右思考方式——它找到了人类不会采用的解题路径,但同样有效。
- 与后训练方法互补:OTS 是纯推理时方法(不改权重),而 diffu-GRPO 是训练方法 → 两者可以叠加使用。
- Countdown 上 +7.7% 但 GSM8K 上 -0.6%:说明搜索在需要全局一致性的规划任务上收益最大,在已经表现良好的任务上边际收益递减。