ClipGS 中的 STE

什么是 STE

Straight-Through Estimator,简称 STE,是一种深度学习技巧,主要用于解决模型中‌离散操作不可导导致梯度无法反向传播的问题,通过在前向计算保留离散反向传播近似梯度的方式实现端到端训练.‌‌‌

ClipGS 的问题描述

在 ClipGS 中,判断一个高斯的可见性的表达式是

M=1[(μ+δ)n<z],(1)\mathcal{M} = \bm{1} \, [(\bm{\mu} + \bm{\delta}) \cdot \bm{n} < z], \tag{1}

其中 μ\bm{\mu} 是高斯的均值,δ\bm{\delta} 是使 μ+δ\bm{\mu} + \bm{\delta} 更接近贡献中心的偏移,n\bm{n} 是裁剪平面的法线.由于离散的 M{0,1}\mathcal{M} \in \{0, 1\} 无法提供梯度信息,所以需要利用 STE 使 M\mathcal{M} 可学习.

Sigmoid 函数

m=(μ+δ)n,(2)m = (\bm{\mu} + \bm{\delta}) \cdot \bm{n}, \tag{2}

M=1[m<z].(3)\mathcal{M} = \bm{1} \, [m < z]. \tag{3}

通过引入 Sigmoid 函数

σ(x)=11+ex,\sigma(x) = \frac{1}{1 + e ^ {-x}},

可以将 mzm - z 变为 (0,1)(0, 1) 的连续值 σ(mz)\sigma(m - z),再做一步 step 处理,得到

1[σ(mz)<ϵ],(4)\bm{1} \, [\sigma(m - z) < \epsilon], \tag{4}

其中 ϵ\epsilon 为超参数阈值.当 ϵ=0.5\epsilon = 0.5 时,(3)(3) 等号右边与 (4)(4) 在前向结果上等价.

应用 STE

s:=σ(mz),h:=1[σ(mz)<ϵ],\begin{aligned} s &:= \sigma(m - z), \\ h &:= \bm{1} \, [\sigma(m - z) < \epsilon], \end{aligned}

那么我们需要让 M\mathcal{M} 在前向计算时,结果为与原始形式等效的 hh;而在反向传播时,结果为连续可导的 ss.于是就有

M=sg(hs)+s,(5)\mathcal{M} = \mathrm{sg}(h - s) + s, \tag{5}

其中 sg()\mathrm{sg}(\cdot)stop gradient\text{stop gradient} 运算,即反向传播时不求梯度.展开 (5)(5) 式得到

M=sg(1[σ(mz)<ϵ]σ(mz))+σ(mz).(6)\mathcal{M} = \mathrm{sg}(\bm{1}\,[\sigma(m - z) < \epsilon] - \sigma(m - z)) + \sigma(m - z). \tag{6}