← 返回列表

Mean-field Chaos Diffusion Models

作者 Sungwoo Park, Dongjun Kim, Ahmed M. Alaa
年份 2024
会议/期刊 ICML 2024
评分
标签 扩散模型 平均场理论 高基数数据
摘要 平均场混沌扩散模型:将高基数数据视为 N 粒子交互系统 + Wasserstein 变分方程 + 混沌熵细分策略,突破维度灾难,支持可变基数生成

核心思想

标准扩散模型处理固定维度的数据(如 256×256 图像)。但许多数据的维度是可变的

  • 3D 点云:几百到几万个点
  • 分子结构:原子数量不定
  • 集合数据:元素个数可变

这类高基数(high-cardinality) 数据的维度极高且可变 → 标准扩散模型面临维度灾难

MF-CDM平均场理论重新构想扩散模型:

  1. 将数据视为 $N$ 个交互粒子的系统
  2. 利用混沌传播性质——当 $N \to \infty$ 时,粒子变得近似独立
  3. 只需学习有限 $M \ll N$ 个粒子的行为即可近似无穷维系统

背景知识

什么是高基数数据

数据类型 基数 $N$ 特点
图像 固定(如 $256^2$) 维度固定,结构规则
3D 点云 可变(100-10000) 维度可变,无序
分子图 可变(10-100 原子) 维度可变,置换不变
集合数据 可变 维度可变,无序

维度灾难

标准扩散模型的分数估计误差随维度 $d$ 增长。当 $N$ 个粒子各有 $d$ 维 → 总维度 $Nd$ → 当 $N$ 很大时,分数匹配变得极其困难。

什么是平均场理论

物理学中处理大量粒子交互系统的方法。核心思想:

当粒子数 $N \to \infty$ 时,每个粒子受到的影响可以用所有粒子的平均效应(平均场)来近似。

什么是混沌传播

Kac 混沌(Definition 2.2):如果 $N$ 个粒子是可交换的,那么当 $N \to \infty$ 时,任意 $M$ 个粒子的联合分布趋近于独立同分布

\[\nu_t^{M,N} \xrightarrow{\text{弱收敛}} \mu_t^{\otimes M} \quad \text{当 } N \to \infty\]

这意味着:只需要学习单个粒子的边缘分布 $\mu_t$,就能重建整个系统。

方法详解

1. 平均场 SDE

前向动力学(加噪):

\[d\mathbf{X}_u^{i,N} = f_s(\mathbf{X}_u^{i,N}) du + \sigma_u d\mathbf{B}_u^{i,N}\]

反向动力学(去噪):

\[d\mathbf{X}_t^{i,N} = \left[f_t(\mathbf{X}_t^{i,N}) - \sigma_t^2 \nabla \log \zeta_t(\mathbf{X}_t^{i,N})\right] dt + \sigma_t d\mathbf{B}_t^{i,N}\]

$N$ 个粒子共享对称性(可交换性),每个粒子受其他粒子通过平均场的影响。

2. 混沌熵

标准化的 KL 散度:

\[\mathcal{H}(\nu_t^N | \zeta_0^{\otimes N}) = \frac{1}{N} \int \log \frac{\varrho_t^N}{\zeta_0^{\otimes N}} \varrho_t^N d\mathbf{x}^N\]

除以 $N$ → 使目标对基数鲁棒(不会随 $N$ 增长)。

3. Wasserstein 变分方程(Theorem 3.1)

在 Wasserstein 空间(概率测度空间)中用 Itô-Wentzell-Lions 公式推导:

\[\mathcal{H}_t^N(\nu_t^N) \lesssim \mathcal{H}_s^N(\nu_s^N) + C_0 \int_s^t \mathcal{O}\left(\mathbb{E}\|\nabla_{\mathcal{P}_2} \mathcal{H}_r^N\|_e^2\right) dr + C_1 \int_s^t \mathcal{O}\left(\mathbb{E}\|\nabla_x \nabla_{\mathcal{P}_2} \mathcal{H}_r^N\|_F^2\right) dr\]

常数:$C_0 \lesssim \mathcal{O}(\sqrt{d} + \mathcal{M}^2)$,$C_1 \lesssim \mathcal{O}(T)$

关键:用 Wasserstein 梯度信息(而非时间导数)来约束熵演化。

4. 平均场分数匹配(Corollary 3.2)

上界化简为 Sobolev 范数:

\[\mathcal{H}_t^N(\nu_t^N) \lesssim \frac{\mathcal{M}}{\sqrt{Nd}} \int_0^T \|\mathcal{G}_t\|_V^2 dt\]

其中 $\mathcal{G}t = \nabla \log \varrho_t^N - \nabla \log \zeta{t-t}^{\otimes N}$。

系数 $1/\sqrt{Nd}$ 使框架对大基数鲁棒——不同于标准 SGM 的 $1/\sqrt{d}$。

分数匹配目标

\[\mathcal{J}^{MF}_N(\theta, \nu_{(0,t)}^N) = \mathbb{E}_{t \sim p(t)} \|\mathbf{s}_\theta(t, \mathbf{X}_t^N, \nu_t^N) - \nabla \log \zeta_{t-t}^{\otimes N}(\mathbf{X}_t^N)\|_V^2\]

5. 分数网络架构

\[\mathbf{s}_\theta(t, \mathbf{x}^N, \nu_t^N) = A_\theta(t, \mathbf{x}^N) + B_\theta[\nu_t^N](\mathbf{x}^N)\]
  • $A_\theta$:独立项(每个粒子自身的分数)
  • $B_\theta$:交互项(通过平均场的粒子间交互)

交互项用欧几里得球 $\mathbb{B}_R$ 上的截断卷积实现 → 捕获局部交互同时保持置换不变性。

6. 混沌熵细分(Proposition 4.1)

将问题分割为 $K$ 个子问题,基数逐步增长 $N_{k+1} = \mathfrak{b} N_k$:

\[\mathcal{H}_t^\infty(\mu_t) \lesssim \lim_{K \to \infty} \sum_{k=0}^{K} \left[\sigma^{-2}_\zeta(T) E(N_{k+1}) + \frac{\mathcal{M}}{\sqrt{d}} \frac{1}{(\mathfrak{b}\sqrt{N_{k+1}})^k} \mathcal{J}^{MF}(N_k, \theta, \nu_{(t_k, t_{k+1})}^{N_k})\right]\]

分而治之:先用少量粒子学习粗略结构,再用更多粒子精细化。

7. 粒子分支函数 $\Psi_\theta$

在分段边界将 $N_k$ 个粒子扩展为 $\mathfrak{b} N_k$ 个:

\[(\mathbf{Id}^{\otimes(\mathfrak{b}-1)} \otimes \Psi_\theta)_\# \nu_{t_k}^{N_k} \to \hat{\nu}_{t_k}^{\otimes \mathfrak{b} N_k}\]

8. 集中不等式(Theorem 4.2)

\[\mathbb{P}\left[\mathcal{H}(\nu_t^{M,N} | \mu_t^{\otimes M}) \ge \varepsilon\right] \lesssim \mathcal{O}(\varepsilon^{-\varepsilon-d}) \cdot \mathcal{O}\left(\exp[-M\mathfrak{f}(\kappa)\varepsilon^2 - M\mathfrak{f}(\kappa)\mathfrak{h}(R)]\right)\]

证明 $M \ll N$ 个粒子足以近似平均场 → 不需要学习全部 $N$ 个粒子

实验结果

合成数据集

验证了方法对基数变化的鲁棒性和分数估计质量。

3D 点云生成

将每个 3D 点视为一个”粒子”,利用粒子的可交换性进行置换不变学习。展示了在非结构化高基数数据上的有效性。

个人思考

  1. “粒子 = 数据元素”的类比是核心洞察:将点云中的点、分子中的原子视为物理学中的交互粒子 → 自然引入平均场理论的强大工具。
  2. $1/\sqrt{Nd}$ 的缩放系数是突破维度灾难的关键:标准 SGM 的误差随维度增长,而 MF-CDM 通过粒子间的统计独立性(混沌传播)抵消了 $N$ 的增长。
  3. 细分策略的分而治之思想优雅:先用少量粒子捕获全局结构 → 再逐步增加粒子精细化 → 类似多分辨率方法。
  4. 局部卷积交互是实用的近似:全局交互(所有粒子对所有粒子)计算量 $O(N^2)$,局部交互在欧几里得球内的截断使计算可行。
  5. 理论贡献远超实验:大量深刻的定理和推论,但实验仅在合成数据和点云上 → 未来在分子生成等更广泛领域的验证值得期待。
← 返回列表