Deconstructing Denoising Diffusion Models for Self-Supervised Learning
核心思想
扩散模型(DDM)不仅能生成图像,还能学习视觉表征。但 DDM 有很多复杂组件(VQGAN、噪声调度、扩散过程等),哪些对表征学习真正重要?
本文通过”解构“方法——逐步去除现代组件——发现:
- 只有少数组件对表征学习关键
- 最终模型 l-DAE(latent Denoising Autoencoder)极其简单:PCA 投影 + 加噪声 + ViT 去噪
- l-DAE 接近 MAE 的性能,揭示扩散模型学表征的本质机制
背景知识
自监督学习的两大范式
| 范式 | 代表方法 | 核心思想 |
|---|---|---|
| 对比学习 | MoCo v3 | 让相似图像的表征接近,不相似的远离 |
| 掩码预测 | MAE | 遮住图像的一部分,预测被遮住的内容 |
| 去噪 | DDM / l-DAE | 给图像加噪声,预测原始图像 |
扩散模型做表征学习
扩散模型的正向过程:
\[\mathbf{z}_t = \gamma_t \mathbf{z}_0 + \sigma_t \boldsymbol{\varepsilon}, \quad \boldsymbol{\varepsilon} \sim \mathcal{N}(0, \mathbf{I})\]其中 $\gamma_t^2 + \sigma_t^2 = 1$。
训练目标(预测噪声):
\[\mathcal{L} = \|\boldsymbol{\varepsilon} - \text{net}(\mathbf{z}_t)\|^2\]编码器(网络前半部分)学到的特征可以用于下游任务(如分类)。
起点:DiT-Large
| 组件 | 配置 |
|---|---|
| 编码器 | DiT-L 前 12 层(ViT-½L) |
| 解码器 | DiT-L 后 12 层 |
| 分词器 | VQGAN(有感知损失 + 对抗损失 + KL 正则) |
| 噪声调度 | EDM 调度(偏向高噪声) |
| 条件 | 类别标签 |
| 基线性能 | 线性探测准确率 57.5%,FID 11.6 |
方法详解:三阶段解构
Stage 1:重新定向 DDM 用于自监督学习
1.1 去除类别条件
| 配置 | 线性探测准确率 |
|---|---|
| 有类别标签 | 57.5% |
| 无类别标签 | 62.5% |
反直觉:去掉标签反而更好——因为有标签时,网络可以”作弊”(直接从标签推断内容),不需要从图像学习语义。
1.2 简化 VQGAN 分词器
VQGAN 原始损失有 4 项:
\[\mathcal{L}_\text{VQGAN} = \underbrace{\|x - g(f(x))\|^2}_{\text{重建}} + \underbrace{\text{KL}}_{\text{正则}} + \underbrace{\text{Perceptual}}_{\text{感知(用 VGG)}} + \underbrace{\text{GAN}}_{\text{对抗}}\]逐步去除:
| 操作 | 准确率 |
|---|---|
| 完整 VQGAN | 62.5% |
| 去掉感知损失 | 58.4%(下降!) |
| 再去掉对抗损失 | 59.0%(恢复) |
发现:感知损失和对抗损失可以一起去掉。
1.3 替换噪声调度
原始 EDM 调度在高噪声区域花费太多步骤。替换为线性调度:$\gamma_t^2$ 从 1 线性衰减到 0。
\[59.0\% \to 63.4\%\]关键洞察:表征质量与生成质量解耦——更好的噪声调度不一定生成更好的图像,但学到更好的特征。
Stage 2:解构分词器
测试 4 种越来越简单的分词器:
2.1 卷积 VAE(基线)
深度卷积编码器/解码器 + KL 正则。
2.2 Patch-wise VAE
对每个 16×16 patch 做线性编码/解码:
\[\mathcal{L} = \|x - UV^T x\|^2 + \text{KL}[Vx | \mathcal{N}]\]其中 $U, V \in \mathbb{R}^{d \times D}$($d$ 是潜在维度,$D = 768$ 是 patch 维度)。
2.3 Patch-wise AE
去掉 KL 正则:
\[\mathcal{L} = \|x - UV^T x\|^2\]2.4 Patch-wise PCA
用 PCA 做无参数分词:
\[\mathcal{L} = \|x - V^T V x\|^2, \quad VV^T = \mathbf{I}\]不需要梯度训练——直接对 patch 做 PCA 就够了。
关键发现:潜在维度分析
| 维度 $d$ | 卷积 VAE | Patch VAE | Patch AE | Patch PCA |
|---|---|---|---|---|
| 8 | 54.5% | 58.3% | 59.9% | 56.0% |
| 16 | 63.4% | 64.9% | 64.7% | 63.4% |
| 32 | 62.8% | 64.8% | 64.6% | 65.1% |
| 64 | 57.0% | 56.8% | 59.9% | 60.0% |
三大发现:
- 四种分词器趋势一致——架构差异不大
- 最优维度很低(16-32),远低于 patch 维度 768
- PCA 与学习的分词器性能相当——不需要训练!
Stage 3:走向经典去噪自编码器
从 Patch PCA 出发(65.1%),逐步简化到经典 DAE:
3.1 预测干净数据而非噪声
经典 DAE 预测原始数据:
\[\mathcal{L} = \lambda_t \|\mathbf{z}_0 - \text{net}(\mathbf{z}_t)\|^2\]其中 $\lambda_t = \gamma_t^2$ 强调接近干净的样本。
\(65.1\% \to 62.4\%\)(暂时下降,但理论上更纯粹)
3.2 去除输入缩放($\gamma_t \equiv 1$)
令 $\gamma_t = 1$,噪声 $\sigma_t$ 从 0 线性增到 $\sqrt{2}$:
\[\mathbf{z}_t = \mathbf{z}_0 + \sigma_t \boldsymbol{\varepsilon}\]损失权重:$\lambda_t = 1 / (1 + \sigma_t^2)$。
\[62.4\% \to 63.6\%\]3.3 投影回图像空间
流程变为:图像 → PCA 投影到潜在空间 → 加噪 → 逆 PCA 投影回图像 → ViT 去噪。
\[63.6\% \to 63.9\%\]3.4 直接预测原始图像
网络直接预测原始干净图像。损失考虑噪声维度和 PCA 重建误差维度的不同权重:
\[\mathcal{L} = \lambda_t \sum_{i=1}^{D} w_i r_i^2\]其中:
- $w_i = 1$($i \leq d$):噪声损坏的维度
- $w_i = 0.1$($d < i \leq D$):PCA 丢弃的维度
3.5 单一噪声水平实验
固定 $\sigma = \sqrt{1/3}$:
\(64.5\% \to 61.5\%\)(下降 3%)
结论:多噪声水平起数据增强作用,有益但非必需。
最终 l-DAE 配置
1. 对图像 patch 做 PCA → 投影到 d=32 维潜在空间
2. 在潜在空间加高斯噪声
3. 通过逆 PCA 投影回图像空间
4. 训练 ViT 预测原始干净图像
5. 使用多水平噪声调度
实验结果
模型缩放
| 模型 | 参数量 | 线性探测准确率 |
|---|---|---|
| ViT-B | 86M | 66.6% |
| ViT-½L | — | 65.0% |
| ViT-L | 304M | 75.0% |
从 ViT-B 到 ViT-L:+10.6%。
训练时长
| Epochs | 准确率 |
|---|---|
| 400 | 65.0% |
| 800 | 67.5% |
| 1600 | 69.6% |
与 MAE/MoCo v3 对比
线性探测:
| 方法 | ViT-B | ViT-L |
|---|---|---|
| MoCo v3 | 76.7% | 77.6% |
| MAE | 68.0% | 75.8% |
| l-DAE | 66.6% | 75.0% |
微调:
| 方法 | ViT-B | ViT-L |
|---|---|---|
| MoCo v3 | 83.2% | 84.1% |
| MAE | 83.6% | 85.9% |
| l-DAE | 83.7% | 84.7% |
微调时 l-DAE 与 MAE 几乎持平。
数据增强的影响
| 增强 | 准确率 |
|---|---|
| 仅中心裁剪 | 64.5% |
| 随机缩放裁剪 | 65.0% |
l-DAE 对数据增强几乎不依赖(仅 +0.5%),而对比学习高度依赖增强。
个人思考
- “解构”方法论极具启发性:不是构建更复杂的系统,而是拆解现有系统找出真正重要的部分——这是理解深度学习的正确方式。
- 低维潜在空间是关键(16-32 维):不是原始像素(太高维),也不是 VAE 的 4 维(太低维),而是保留主要成分但去除噪声的中间维度。
- PCA = 学习的分词器:如此简单的无参数方法就能匹配复杂的 VAE——说明扩散模型真正需要的只是”降维到合适的维度”。
- 生成 ≠ 识别的又一证据:FID 更好的配置反而线性探测更差——优化生成质量和表征质量是不同的目标。
- 去噪是根本机制:不管是 DDM 的多步扩散还是经典 DAE 的单步去噪,学到的表征质量相似——去噪本身(而非扩散调度)才是自监督学习的核心。