Mean-field Chaos Diffusion Models
核心思想
标准扩散模型处理固定维度的数据(如 256×256 图像)。但许多数据的维度是可变的:
- 3D 点云:几百到几万个点
- 分子结构:原子数量不定
- 集合数据:元素个数可变
这类高基数(high-cardinality) 数据的维度极高且可变 → 标准扩散模型面临维度灾难。
MF-CDM 用平均场理论重新构想扩散模型:
- 将数据视为 $N$ 个交互粒子的系统
- 利用混沌传播性质——当 $N \to \infty$ 时,粒子变得近似独立
- 只需学习有限 $M \ll N$ 个粒子的行为即可近似无穷维系统
背景知识
什么是高基数数据
| 数据类型 | 基数 $N$ | 特点 |
|---|---|---|
| 图像 | 固定(如 $256^2$) | 维度固定,结构规则 |
| 3D 点云 | 可变(100-10000) | 维度可变,无序 |
| 分子图 | 可变(10-100 原子) | 维度可变,置换不变 |
| 集合数据 | 可变 | 维度可变,无序 |
维度灾难
标准扩散模型的分数估计误差随维度 $d$ 增长。当 $N$ 个粒子各有 $d$ 维 → 总维度 $Nd$ → 当 $N$ 很大时,分数匹配变得极其困难。
什么是平均场理论
物理学中处理大量粒子交互系统的方法。核心思想:
当粒子数 $N \to \infty$ 时,每个粒子受到的影响可以用所有粒子的平均效应(平均场)来近似。
什么是混沌传播
Kac 混沌(Definition 2.2):如果 $N$ 个粒子是可交换的,那么当 $N \to \infty$ 时,任意 $M$ 个粒子的联合分布趋近于独立同分布:
\[\nu_t^{M,N} \xrightarrow{\text{弱收敛}} \mu_t^{\otimes M} \quad \text{当 } N \to \infty\]这意味着:只需要学习单个粒子的边缘分布 $\mu_t$,就能重建整个系统。
方法详解
1. 平均场 SDE
前向动力学(加噪):
\[d\mathbf{X}_u^{i,N} = f_s(\mathbf{X}_u^{i,N}) du + \sigma_u d\mathbf{B}_u^{i,N}\]反向动力学(去噪):
\[d\mathbf{X}_t^{i,N} = \left[f_t(\mathbf{X}_t^{i,N}) - \sigma_t^2 \nabla \log \zeta_t(\mathbf{X}_t^{i,N})\right] dt + \sigma_t d\mathbf{B}_t^{i,N}\]$N$ 个粒子共享对称性(可交换性),每个粒子受其他粒子通过平均场的影响。
2. 混沌熵
标准化的 KL 散度:
\[\mathcal{H}(\nu_t^N | \zeta_0^{\otimes N}) = \frac{1}{N} \int \log \frac{\varrho_t^N}{\zeta_0^{\otimes N}} \varrho_t^N d\mathbf{x}^N\]除以 $N$ → 使目标对基数鲁棒(不会随 $N$ 增长)。
3. Wasserstein 变分方程(Theorem 3.1)
在 Wasserstein 空间(概率测度空间)中用 Itô-Wentzell-Lions 公式推导:
\[\mathcal{H}_t^N(\nu_t^N) \lesssim \mathcal{H}_s^N(\nu_s^N) + C_0 \int_s^t \mathcal{O}\left(\mathbb{E}\|\nabla_{\mathcal{P}_2} \mathcal{H}_r^N\|_e^2\right) dr + C_1 \int_s^t \mathcal{O}\left(\mathbb{E}\|\nabla_x \nabla_{\mathcal{P}_2} \mathcal{H}_r^N\|_F^2\right) dr\]常数:$C_0 \lesssim \mathcal{O}(\sqrt{d} + \mathcal{M}^2)$,$C_1 \lesssim \mathcal{O}(T)$
关键:用 Wasserstein 梯度信息(而非时间导数)来约束熵演化。
4. 平均场分数匹配(Corollary 3.2)
上界化简为 Sobolev 范数:
\[\mathcal{H}_t^N(\nu_t^N) \lesssim \frac{\mathcal{M}}{\sqrt{Nd}} \int_0^T \|\mathcal{G}_t\|_V^2 dt\]其中 $\mathcal{G}t = \nabla \log \varrho_t^N - \nabla \log \zeta{t-t}^{\otimes N}$。
系数 $1/\sqrt{Nd}$ 使框架对大基数鲁棒——不同于标准 SGM 的 $1/\sqrt{d}$。
分数匹配目标:
\[\mathcal{J}^{MF}_N(\theta, \nu_{(0,t)}^N) = \mathbb{E}_{t \sim p(t)} \|\mathbf{s}_\theta(t, \mathbf{X}_t^N, \nu_t^N) - \nabla \log \zeta_{t-t}^{\otimes N}(\mathbf{X}_t^N)\|_V^2\]5. 分数网络架构
\[\mathbf{s}_\theta(t, \mathbf{x}^N, \nu_t^N) = A_\theta(t, \mathbf{x}^N) + B_\theta[\nu_t^N](\mathbf{x}^N)\]- $A_\theta$:独立项(每个粒子自身的分数)
- $B_\theta$:交互项(通过平均场的粒子间交互)
交互项用欧几里得球 $\mathbb{B}_R$ 上的截断卷积实现 → 捕获局部交互同时保持置换不变性。
6. 混沌熵细分(Proposition 4.1)
将问题分割为 $K$ 个子问题,基数逐步增长 $N_{k+1} = \mathfrak{b} N_k$:
\[\mathcal{H}_t^\infty(\mu_t) \lesssim \lim_{K \to \infty} \sum_{k=0}^{K} \left[\sigma^{-2}_\zeta(T) E(N_{k+1}) + \frac{\mathcal{M}}{\sqrt{d}} \frac{1}{(\mathfrak{b}\sqrt{N_{k+1}})^k} \mathcal{J}^{MF}(N_k, \theta, \nu_{(t_k, t_{k+1})}^{N_k})\right]\]分而治之:先用少量粒子学习粗略结构,再用更多粒子精细化。
7. 粒子分支函数 $\Psi_\theta$
在分段边界将 $N_k$ 个粒子扩展为 $\mathfrak{b} N_k$ 个:
\[(\mathbf{Id}^{\otimes(\mathfrak{b}-1)} \otimes \Psi_\theta)_\# \nu_{t_k}^{N_k} \to \hat{\nu}_{t_k}^{\otimes \mathfrak{b} N_k}\]8. 集中不等式(Theorem 4.2)
\[\mathbb{P}\left[\mathcal{H}(\nu_t^{M,N} | \mu_t^{\otimes M}) \ge \varepsilon\right] \lesssim \mathcal{O}(\varepsilon^{-\varepsilon-d}) \cdot \mathcal{O}\left(\exp[-M\mathfrak{f}(\kappa)\varepsilon^2 - M\mathfrak{f}(\kappa)\mathfrak{h}(R)]\right)\]证明 $M \ll N$ 个粒子足以近似平均场 → 不需要学习全部 $N$ 个粒子。
实验结果
合成数据集
验证了方法对基数变化的鲁棒性和分数估计质量。
3D 点云生成
将每个 3D 点视为一个”粒子”,利用粒子的可交换性进行置换不变学习。展示了在非结构化高基数数据上的有效性。
个人思考
- “粒子 = 数据元素”的类比是核心洞察:将点云中的点、分子中的原子视为物理学中的交互粒子 → 自然引入平均场理论的强大工具。
- $1/\sqrt{Nd}$ 的缩放系数是突破维度灾难的关键:标准 SGM 的误差随维度增长,而 MF-CDM 通过粒子间的统计独立性(混沌传播)抵消了 $N$ 的增长。
- 细分策略的分而治之思想优雅:先用少量粒子捕获全局结构 → 再逐步增加粒子精细化 → 类似多分辨率方法。
- 局部卷积交互是实用的近似:全局交互(所有粒子对所有粒子)计算量 $O(N^2)$,局部交互在欧几里得球内的截断使计算可行。
- 理论贡献远超实验:大量深刻的定理和推论,但实验仅在合成数据和点云上 → 未来在分子生成等更广泛领域的验证值得期待。