Grendel-GS: On Scaling Up 3D Gaussian Splatting Training
核心思想
3D 高斯溅射(3DGS) 是一种高效的 3D 重建和渲染方法,但训练只能在单 GPU 上运行 → 显存限制了高斯数量 → 限制了大场景重建质量。
Grendel 是首个多 GPU 分布式 3DGS 训练系统:
- 混合并行:高斯级分布 + 像素级分布
- 稀疏通信:利用空间局部性,通信量减少 90%
- 动态负载均衡:避免天空等简单区域浪费 GPU
- 超参数缩放规则:$\sqrt{\text{batch_size}}$ 学习率缩放
背景知识
什么是 3D 高斯溅射
3DGS 用数百万个 3D 高斯椭球表示场景。每个高斯有以下参数:
| 参数 | 维度 | 含义 |
|---|---|---|
| 位置 $\mathbf{x}_i$ | $\mathbb{R}^3$ | 高斯中心的 3D 坐标 |
| 缩放 $\mathbf{s}_i$ | $\mathbb{R}^3$ | 椭球三个轴的大小 |
| 旋转 $\mathbf{q}_i$ | $\mathbb{R}^4$ | 四元数表示的朝向 |
| 不透明度 $\alpha_i$ | $\mathbb{R}^1$ | 透明度 |
| 球谐系数 $\text{sh}_i$ | $\mathbb{R}^{48}$ | 视角相关的颜色 |
训练流程
- 高斯变换:将 3D 高斯投影到屏幕空间,得到 2D 位置、深度、覆盖半径、颜色
- 图像渲染:对每个像素,找到覆盖它的所有高斯 → 按深度排序 → Alpha 合成
- 损失计算:L1 + SSIM 损失
- 反向传播:梯度更新高斯参数
- 致密化:克隆/分裂高质量高斯,剪枝低透明度高斯
单 GPU 的瓶颈
| 瓶颈 | 原因 |
|---|---|
| 高斯参数显存 | 1000 万高斯 ≈ 数 GB |
| Z-buffer 显存 | 存储每像素交叉的高斯索引 |
| 4K 分辨率 | 单帧渲染需要大量显存 |
单张 A100-40G 最多支持约 1100 万高斯——对于城市级场景远远不够。
方法详解
1. 混合并行策略
1.1 高斯级分布(Gaussian-wise Distribution)
将 $N$ 个高斯均匀分配给 $K$ 个 GPU:
- GPU 0:高斯 $1$ 到 $N/K$
- GPU 1:高斯 $N/K+1$ 到 $2N/K$
- …
每个 GPU 独立计算所分配高斯的变换(3D → 2D 投影)。
1.2 像素级分布(Pixel-wise Distribution)
将图像分割为连续的像素块(16×16 粒度),分配给不同 GPU 做渲染和损失计算。
1.3 两阶段切换
阶段 1:高斯级(每 GPU 处理自己的高斯子集)
↓ 稀疏 All-to-All 通信
阶段 2:像素级(每 GPU 渲染自己的像素区域)
↓ 反向稀疏 All-to-All 通信
阶段 1:高斯级(梯度更新回各 GPU 的高斯)
2. 稀疏 All-to-All 通信
核心观察:90% 的高斯在屏幕上的覆盖半径 < 图像宽度的 2%。
这意味着每个像素块只需要很少一部分高斯的信息 → 不需要全局广播所有高斯。
过程:
- 每个 GPU 计算自己高斯的屏幕覆盖范围
- 确定每个高斯只需发送给哪些像素分区的 GPU
- 只发送必要的高斯数据
通信量比稠密 All-to-All 减少约 90%。
3. 动态负载均衡
问题:不同像素区域的渲染计算量差异巨大——天空区域几乎没有高斯,建筑区域密集。
3.1 像素级均衡算法
for 每个 epoch:
记录每个 16×16 像素块的渲染时间 ET[i]
计算累积时间 CumSum = cumsum(ET)
目标:每 GPU 分配的总时间 = Total / K
用 searchsorted 找到分割点
重新分配像素块给各 GPU
3.2 高斯级均衡
致密化(克隆/分裂)会导致高斯分布不均匀。定期重新分配高斯给各 GPU。
4. 超参数缩放规则
当 batch size(同时训练的视角数)增大时,如何调整超参数?
4.1 理论推导
假设不同视角的梯度独立。
Batch size = 1 的 Adam 更新:
\[\Delta^{(k)} = \frac{g_k}{\sqrt{\mathbb{E}[|V| g^2]}} = \frac{g_k}{\sqrt{|V|} \cdot \sqrt{\mathbb{E}[g^2]}}\]Batch size = b 的 Adam 更新:
\[\Delta^{(B)} = \frac{\sum_{k \in B} g_k / b}{\sqrt{\mathbb{E}[(|V|/b) g^2]}} = \frac{1}{\sqrt{b}} \cdot \frac{\sum_{k \in B} g_k}{\sqrt{|V|} \cdot \sqrt{\mathbb{E}[g^2]}}\]比较两者:batch 增大 $b$ 倍 → 更新量缩小 $\sqrt{b}$ 倍。
4.2 缩放规则
学习率缩放:
\[\lambda' = \lambda \times \sqrt{\text{batch\_size}}\]动量缩放:
\[\beta_1' = \beta_1^{\text{batch\_size}}, \quad \beta_2' = \beta_2^{\text{batch\_size}}\]这确保不同 batch size 下更新的余弦相似度和幅度保持一致。
实验结果
显存扩展
| GPU 数 | batch=1 最大高斯数 | batch=16 最大高斯数 |
|---|---|---|
| 1 | 12.71M | OOM |
| 4 | 63.44M | 19.55M |
| 16 | 230.41M | 74.98M |
显存随 GPU 数线性扩展。
4K Rubble 场景质量
| 方法 | 高斯数 | PSNR |
|---|---|---|
| 单 GPU 基线 | 11.2M | 26.28 |
| Grendel 4 GPU | — | — |
| Grendel 16 GPU | 40.4M | 27.28 |
高斯数增加 3.6 倍 → PSNR 提升 1.0 dB。
吞吐量
| 配置 | 吞吐量(图/秒) |
|---|---|
| 4 GPU, batch=1 | 5.55 |
| 32 GPU, batch=64 | 38.03 |
与 CityGaussian 对比
| 方法 | Rubble PSNR | 时间 |
|---|---|---|
| CityGaussian | 25.88 | 2.88-8.25 小时 |
| Grendel | 27.39 | 0.85-1.22 小时 |
质量更好,速度快 3-7 倍。
缩放规则消融
| 缩放方式 | 余弦相似度 |
|---|---|
| 常数学习率 | 低(不同 batch 更新方向不一致) |
| 线性缩放 | 发散 |
| $\sqrt{\text{batch}}$ 缩放 | 高(一致) |
| 动量缩放 | 效果 |
|---|---|
| 常数动量 | 随 batch 增大退化 |
| 指数缩放 $\beta^b$ | 最佳 |
梯度独立性验证
方差的倒数随 batch size 线性增长(直到 batch=32 后趋于平稳)→ 支持独立性假设在实际范围内成立。
个人思考
- 系统论文的价值:3DGS 的算法创新很多,但”如何训练大规模场景”这个工程问题一直未解——Grendel 填补了这个空白。
- 稀疏通信的 90% 节省来自一个简单观察(高斯覆盖范围小)——最有效的优化往往来自对问题结构的理解。
- $\sqrt{\text{batch}}$ 规则与分布式深度学习的线性缩放规则不同——因为 3DGS 用 Adam 而非 SGD,且高斯参数的更新模式特殊。
- 动态负载均衡的必要性:3D 场景的密度极不均匀(天空 vs 建筑),静态分区会导致严重的 GPU 利用不均。
- 硬件依赖:实验在高端集群(A100 + NVLink + 200Gbps 网络)上完成——通信效率在消费级硬件上可能更关键。