← 返回列表

Deconstructing Denoising Diffusion Models for Self-Supervised Learning

作者 Xinlei Chen, Zhuang Liu, Saining Xie, Kaiming He
年份 2024
会议/期刊 arXiv 2024
评分
标签 自监督学习 扩散模型 表征学习
摘要 系统解构扩散模型中哪些组件对表征学习真正重要:逐步去除现代组件,最终发现低维潜在空间 + 去噪目标是核心,得到接近 MAE 性能的极简 l-DAE

核心思想

扩散模型(DDM)不仅能生成图像,还能学习视觉表征。但 DDM 有很多复杂组件(VQGAN、噪声调度、扩散过程等),哪些对表征学习真正重要

本文通过”解构“方法——逐步去除现代组件——发现:

  1. 只有少数组件对表征学习关键
  2. 最终模型 l-DAE(latent Denoising Autoencoder)极其简单:PCA 投影 + 加噪声 + ViT 去噪
  3. l-DAE 接近 MAE 的性能,揭示扩散模型学表征的本质机制

背景知识

自监督学习的两大范式

范式 代表方法 核心思想
对比学习 MoCo v3 让相似图像的表征接近,不相似的远离
掩码预测 MAE 遮住图像的一部分,预测被遮住的内容
去噪 DDM / l-DAE 给图像加噪声,预测原始图像

扩散模型做表征学习

扩散模型的正向过程:

\[\mathbf{z}_t = \gamma_t \mathbf{z}_0 + \sigma_t \boldsymbol{\varepsilon}, \quad \boldsymbol{\varepsilon} \sim \mathcal{N}(0, \mathbf{I})\]

其中 $\gamma_t^2 + \sigma_t^2 = 1$。

训练目标(预测噪声):

\[\mathcal{L} = \|\boldsymbol{\varepsilon} - \text{net}(\mathbf{z}_t)\|^2\]

编码器(网络前半部分)学到的特征可以用于下游任务(如分类)。

起点:DiT-Large

组件 配置
编码器 DiT-L 前 12 层(ViT-½L)
解码器 DiT-L 后 12 层
分词器 VQGAN(有感知损失 + 对抗损失 + KL 正则)
噪声调度 EDM 调度(偏向高噪声)
条件 类别标签
基线性能 线性探测准确率 57.5%,FID 11.6

方法详解:三阶段解构

Stage 1:重新定向 DDM 用于自监督学习

1.1 去除类别条件

配置 线性探测准确率
有类别标签 57.5%
无类别标签 62.5%

反直觉:去掉标签反而更好——因为有标签时,网络可以”作弊”(直接从标签推断内容),不需要从图像学习语义。

1.2 简化 VQGAN 分词器

VQGAN 原始损失有 4 项:

\[\mathcal{L}_\text{VQGAN} = \underbrace{\|x - g(f(x))\|^2}_{\text{重建}} + \underbrace{\text{KL}}_{\text{正则}} + \underbrace{\text{Perceptual}}_{\text{感知(用 VGG)}} + \underbrace{\text{GAN}}_{\text{对抗}}\]

逐步去除:

操作 准确率
完整 VQGAN 62.5%
去掉感知损失 58.4%(下降!)
再去掉对抗损失 59.0%(恢复)

发现:感知损失和对抗损失可以一起去掉。

1.3 替换噪声调度

原始 EDM 调度在高噪声区域花费太多步骤。替换为线性调度:$\gamma_t^2$ 从 1 线性衰减到 0。

\[59.0\% \to 63.4\%\]

关键洞察:表征质量与生成质量解耦——更好的噪声调度不一定生成更好的图像,但学到更好的特征。

Stage 2:解构分词器

测试 4 种越来越简单的分词器:

2.1 卷积 VAE(基线)

深度卷积编码器/解码器 + KL 正则。

2.2 Patch-wise VAE

对每个 16×16 patch 做线性编码/解码:

\[\mathcal{L} = \|x - UV^T x\|^2 + \text{KL}[Vx | \mathcal{N}]\]

其中 $U, V \in \mathbb{R}^{d \times D}$($d$ 是潜在维度,$D = 768$ 是 patch 维度)。

2.3 Patch-wise AE

去掉 KL 正则:

\[\mathcal{L} = \|x - UV^T x\|^2\]

2.4 Patch-wise PCA

用 PCA 做无参数分词:

\[\mathcal{L} = \|x - V^T V x\|^2, \quad VV^T = \mathbf{I}\]

不需要梯度训练——直接对 patch 做 PCA 就够了。

关键发现:潜在维度分析

维度 $d$ 卷积 VAE Patch VAE Patch AE Patch PCA
8 54.5% 58.3% 59.9% 56.0%
16 63.4% 64.9% 64.7% 63.4%
32 62.8% 64.8% 64.6% 65.1%
64 57.0% 56.8% 59.9% 60.0%

三大发现

  1. 四种分词器趋势一致——架构差异不大
  2. 最优维度很低(16-32),远低于 patch 维度 768
  3. PCA 与学习的分词器性能相当——不需要训练!

Stage 3:走向经典去噪自编码器

从 Patch PCA 出发(65.1%),逐步简化到经典 DAE:

3.1 预测干净数据而非噪声

经典 DAE 预测原始数据:

\[\mathcal{L} = \lambda_t \|\mathbf{z}_0 - \text{net}(\mathbf{z}_t)\|^2\]

其中 $\lambda_t = \gamma_t^2$ 强调接近干净的样本。

\(65.1\% \to 62.4\%\)(暂时下降,但理论上更纯粹)

3.2 去除输入缩放($\gamma_t \equiv 1$)

令 $\gamma_t = 1$,噪声 $\sigma_t$ 从 0 线性增到 $\sqrt{2}$:

\[\mathbf{z}_t = \mathbf{z}_0 + \sigma_t \boldsymbol{\varepsilon}\]

损失权重:$\lambda_t = 1 / (1 + \sigma_t^2)$。

\[62.4\% \to 63.6\%\]

3.3 投影回图像空间

流程变为:图像 → PCA 投影到潜在空间 → 加噪 → 逆 PCA 投影回图像 → ViT 去噪。

\[63.6\% \to 63.9\%\]

3.4 直接预测原始图像

网络直接预测原始干净图像。损失考虑噪声维度和 PCA 重建误差维度的不同权重:

\[\mathcal{L} = \lambda_t \sum_{i=1}^{D} w_i r_i^2\]

其中:

  • $w_i = 1$($i \leq d$):噪声损坏的维度
  • $w_i = 0.1$($d < i \leq D$):PCA 丢弃的维度
\[63.9\% \to 64.5\%\]

3.5 单一噪声水平实验

固定 $\sigma = \sqrt{1/3}$:

\(64.5\% \to 61.5\%\)(下降 3%)

结论:多噪声水平起数据增强作用,有益但非必需。

最终 l-DAE 配置

1. 对图像 patch 做 PCA → 投影到 d=32 维潜在空间
2. 在潜在空间加高斯噪声
3. 通过逆 PCA 投影回图像空间
4. 训练 ViT 预测原始干净图像
5. 使用多水平噪声调度

实验结果

模型缩放

模型 参数量 线性探测准确率
ViT-B 86M 66.6%
ViT-½L 65.0%
ViT-L 304M 75.0%

从 ViT-B 到 ViT-L:+10.6%

训练时长

Epochs 准确率
400 65.0%
800 67.5%
1600 69.6%

与 MAE/MoCo v3 对比

线性探测

方法 ViT-B ViT-L
MoCo v3 76.7% 77.6%
MAE 68.0% 75.8%
l-DAE 66.6% 75.0%

微调

方法 ViT-B ViT-L
MoCo v3 83.2% 84.1%
MAE 83.6% 85.9%
l-DAE 83.7% 84.7%

微调时 l-DAE 与 MAE 几乎持平

数据增强的影响

增强 准确率
仅中心裁剪 64.5%
随机缩放裁剪 65.0%

l-DAE 对数据增强几乎不依赖(仅 +0.5%),而对比学习高度依赖增强。

个人思考

  1. “解构”方法论极具启发性:不是构建更复杂的系统,而是拆解现有系统找出真正重要的部分——这是理解深度学习的正确方式。
  2. 低维潜在空间是关键(16-32 维):不是原始像素(太高维),也不是 VAE 的 4 维(太低维),而是保留主要成分但去除噪声的中间维度。
  3. PCA = 学习的分词器:如此简单的无参数方法就能匹配复杂的 VAE——说明扩散模型真正需要的只是”降维到合适的维度”。
  4. 生成 ≠ 识别的又一证据:FID 更好的配置反而线性探测更差——优化生成质量和表征质量是不同的目标。
  5. 去噪是根本机制:不管是 DDM 的多步扩散还是经典 DAE 的单步去噪,学到的表征质量相似——去噪本身(而非扩散调度)才是自监督学习的核心。
← 返回列表