← 返回列表

Grendel-GS: On Scaling Up 3D Gaussian Splatting Training

作者 Hexu Zhao, Haoyang Weng, Daohan Lu, Ang Li, Jinyang Li, Aurojit Panda, Saining Xie
年份 2024
会议/期刊 arXiv 2024
评分
标签 3D重建 高斯溅射 分布式训练
摘要 首个多 GPU 分布式 3DGS 训练系统:稀疏 All-to-All 通信 + 动态负载均衡 + sqrt(batch) 超参缩放规则,16 GPU 支持 4000 万高斯,4K 场景 PSNR 27.28

核心思想

3D 高斯溅射(3DGS) 是一种高效的 3D 重建和渲染方法,但训练只能在单 GPU 上运行 → 显存限制了高斯数量 → 限制了大场景重建质量。

Grendel 是首个多 GPU 分布式 3DGS 训练系统:

  1. 混合并行:高斯级分布 + 像素级分布
  2. 稀疏通信:利用空间局部性,通信量减少 90%
  3. 动态负载均衡:避免天空等简单区域浪费 GPU
  4. 超参数缩放规则:$\sqrt{\text{batch_size}}$ 学习率缩放

背景知识

什么是 3D 高斯溅射

3DGS 用数百万个 3D 高斯椭球表示场景。每个高斯有以下参数:

参数 维度 含义
位置 $\mathbf{x}_i$ $\mathbb{R}^3$ 高斯中心的 3D 坐标
缩放 $\mathbf{s}_i$ $\mathbb{R}^3$ 椭球三个轴的大小
旋转 $\mathbf{q}_i$ $\mathbb{R}^4$ 四元数表示的朝向
不透明度 $\alpha_i$ $\mathbb{R}^1$ 透明度
球谐系数 $\text{sh}_i$ $\mathbb{R}^{48}$ 视角相关的颜色

训练流程

  1. 高斯变换:将 3D 高斯投影到屏幕空间,得到 2D 位置、深度、覆盖半径、颜色
  2. 图像渲染:对每个像素,找到覆盖它的所有高斯 → 按深度排序 → Alpha 合成
  3. 损失计算:L1 + SSIM 损失
  4. 反向传播:梯度更新高斯参数
  5. 致密化:克隆/分裂高质量高斯,剪枝低透明度高斯

单 GPU 的瓶颈

瓶颈 原因
高斯参数显存 1000 万高斯 ≈ 数 GB
Z-buffer 显存 存储每像素交叉的高斯索引
4K 分辨率 单帧渲染需要大量显存

单张 A100-40G 最多支持约 1100 万高斯——对于城市级场景远远不够。

方法详解

1. 混合并行策略

1.1 高斯级分布(Gaussian-wise Distribution)

将 $N$ 个高斯均匀分配给 $K$ 个 GPU:

  • GPU 0:高斯 $1$ 到 $N/K$
  • GPU 1:高斯 $N/K+1$ 到 $2N/K$

每个 GPU 独立计算所分配高斯的变换(3D → 2D 投影)。

1.2 像素级分布(Pixel-wise Distribution)

将图像分割为连续的像素块(16×16 粒度),分配给不同 GPU 做渲染和损失计算

1.3 两阶段切换

阶段 1:高斯级(每 GPU 处理自己的高斯子集)
    ↓ 稀疏 All-to-All 通信
阶段 2:像素级(每 GPU 渲染自己的像素区域)
    ↓ 反向稀疏 All-to-All 通信
阶段 1:高斯级(梯度更新回各 GPU 的高斯)

2. 稀疏 All-to-All 通信

核心观察:90% 的高斯在屏幕上的覆盖半径 < 图像宽度的 2%。

这意味着每个像素块只需要很少一部分高斯的信息 → 不需要全局广播所有高斯。

过程

  1. 每个 GPU 计算自己高斯的屏幕覆盖范围
  2. 确定每个高斯只需发送给哪些像素分区的 GPU
  3. 只发送必要的高斯数据

通信量比稠密 All-to-All 减少约 90%

3. 动态负载均衡

问题:不同像素区域的渲染计算量差异巨大——天空区域几乎没有高斯,建筑区域密集。

3.1 像素级均衡算法

for 每个 epoch:
    记录每个 16×16 像素块的渲染时间 ET[i]
    计算累积时间 CumSum = cumsum(ET)
    目标:每 GPU 分配的总时间 = Total / K
    用 searchsorted 找到分割点
    重新分配像素块给各 GPU

3.2 高斯级均衡

致密化(克隆/分裂)会导致高斯分布不均匀。定期重新分配高斯给各 GPU。

4. 超参数缩放规则

当 batch size(同时训练的视角数)增大时,如何调整超参数?

4.1 理论推导

假设不同视角的梯度独立

Batch size = 1 的 Adam 更新

\[\Delta^{(k)} = \frac{g_k}{\sqrt{\mathbb{E}[|V| g^2]}} = \frac{g_k}{\sqrt{|V|} \cdot \sqrt{\mathbb{E}[g^2]}}\]

Batch size = b 的 Adam 更新

\[\Delta^{(B)} = \frac{\sum_{k \in B} g_k / b}{\sqrt{\mathbb{E}[(|V|/b) g^2]}} = \frac{1}{\sqrt{b}} \cdot \frac{\sum_{k \in B} g_k}{\sqrt{|V|} \cdot \sqrt{\mathbb{E}[g^2]}}\]

比较两者:batch 增大 $b$ 倍 → 更新量缩小 $\sqrt{b}$ 倍。

4.2 缩放规则

学习率缩放

\[\lambda' = \lambda \times \sqrt{\text{batch\_size}}\]

动量缩放

\[\beta_1' = \beta_1^{\text{batch\_size}}, \quad \beta_2' = \beta_2^{\text{batch\_size}}\]

这确保不同 batch size 下更新的余弦相似度和幅度保持一致。

实验结果

显存扩展

GPU 数 batch=1 最大高斯数 batch=16 最大高斯数
1 12.71M OOM
4 63.44M 19.55M
16 230.41M 74.98M

显存随 GPU 数线性扩展

4K Rubble 场景质量

方法 高斯数 PSNR
单 GPU 基线 11.2M 26.28
Grendel 4 GPU
Grendel 16 GPU 40.4M 27.28

高斯数增加 3.6 倍 → PSNR 提升 1.0 dB。

吞吐量

配置 吞吐量(图/秒)
4 GPU, batch=1 5.55
32 GPU, batch=64 38.03

与 CityGaussian 对比

方法 Rubble PSNR 时间
CityGaussian 25.88 2.88-8.25 小时
Grendel 27.39 0.85-1.22 小时

质量更好,速度快 3-7 倍

缩放规则消融

缩放方式 余弦相似度
常数学习率 低(不同 batch 更新方向不一致)
线性缩放 发散
$\sqrt{\text{batch}}$ 缩放 高(一致)
动量缩放 效果
常数动量 随 batch 增大退化
指数缩放 $\beta^b$ 最佳

梯度独立性验证

方差的倒数随 batch size 线性增长(直到 batch=32 后趋于平稳)→ 支持独立性假设在实际范围内成立。

个人思考

  1. 系统论文的价值:3DGS 的算法创新很多,但”如何训练大规模场景”这个工程问题一直未解——Grendel 填补了这个空白。
  2. 稀疏通信的 90% 节省来自一个简单观察(高斯覆盖范围小)——最有效的优化往往来自对问题结构的理解。
  3. $\sqrt{\text{batch}}$ 规则与分布式深度学习的线性缩放规则不同——因为 3DGS 用 Adam 而非 SGD,且高斯参数的更新模式特殊。
  4. 动态负载均衡的必要性:3D 场景的密度极不均匀(天空 vs 建筑),静态分区会导致严重的 GPU 利用不均。
  5. 硬件依赖:实验在高端集群(A100 + NVLink + 200Gbps 网络)上完成——通信效率在消费级硬件上可能更关键。
← 返回列表