
Spacetime Gaussian Feature Splatting for Real-Time Dynamic View Synthesis (Notes)


Spacetime Gaussian Feature Splatting for Real-Time Dynamic View Synthesis

🚩 Goal: view synthesis for dynamic scenes with high-resolution photorealistic quality, real-time rendering, and a small model size.

💡 The paper proposes Spacetime Gaussian Feature Splatting, which consists of three parts:

  • A new Gaussian representation: spacetime Gaussians (STGs), enhanced with temporal opacity and parametric motion/rotation, so that they can capture static, dynamic, and transient content.
  • Splatted feature rendering: neural features replace spherical harmonics (SH). The splatted features model view- and time-dependent appearance while keeping the model compact.
  • Guided sampling of new Gaussians using training error and coarse depth: this helps regions that are hard to converge with the existing pipeline.

🚀 Performance: 8K resolution; the lite version renders at 60 FPS on an RTX 4090.

3D Gaussian Splatting

Given multi-view images with known camera poses, 3D Gaussian Splatting represents a static 3D scene by optimizing anisotropic 3D Gaussians through differentiable rasterization. The efficient rasterization lets the model render high-fidelity views in real time.

Each 3D Gaussian $i$ is associated with a position $\mu_i$, a covariance matrix $\Sigma_i$, an opacity $\sigma_i$, and spherical harmonics coefficients $h_i$.
At any spatial point $x$, the final opacity of 3D Gaussian $i$ is:
$$\alpha_i = \sigma_i \exp\Big(-\tfrac{1}{2}(x-\mu_i)^T \Sigma_i^{-1}(x-\mu_i)\Big). \tag{1}$$
$\Sigma_i$ is positive semi-definite and can be decomposed into a scaling matrix $S_i$ and a rotation matrix $R_i$:
$$\Sigma_i = R_i S_i S_i^T R_i^T$$
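
As a concrete illustration of this decomposition (a minimal sketch, not the paper's code; it assumes batched unit quaternions and per-axis scales), the covariance can be assembled as follows:

    import torch

    def quaternion_to_rotation(q):
        """q: (N, 4) unit quaternions in (w, x, y, z) order -> (N, 3, 3) rotation matrices."""
        w, x, y, z = q.unbind(-1)
        return torch.stack([
            1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
            2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
            2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
        ], dim=-1).reshape(-1, 3, 3)

    def build_covariance(quaternion, scale):
        """Sigma = R S S^T R^T with S = diag(scale); positive semi-definite by construction."""
        R = quaternion_to_rotation(torch.nn.functional.normalize(quaternion, dim=-1))
        S = torch.diag_embed(scale)    # (N, 3, 3)
        M = R @ S
        return M @ M.transpose(1, 2)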

Rendering pipeline, 3D → 2D:

The 3D Gaussians are first projected into 2D image space through an approximation of the perspective transformation: each projected 3D Gaussian is approximated by a 2D Gaussian with center $\mu_i^{2D}$ and covariance $\Sigma_i^{2D}$. $W$ and $K$ denote the view transformation and the projection matrix, respectively.
Computing $\mu_i^{2D}$:
$$\mu_i^{2D} = \big(K\,(W\mu_i)/(W\mu_i)_z\big)_{1:2}$$
Computing $\Sigma_i^{2D}$:
$$\Sigma_i^{2D} = \big(J\,W\,\Sigma_i\,W^T J^T\big)_{1:2,\,1:2}$$
$W$ is the rigid (rotation and translation) view transformation, and $J$ is the Jacobian of the projective transformation $K$.
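
A minimal sketch of this projection step, assuming `W` is a 4×4 world-to-camera matrix, `K` the intrinsics, and `focal_x`/`focal_y` the focal lengths (the real rasterizer performs this per Gaussian in CUDA):

    import torch

    def project_gaussian(mu, Sigma, W, K, focal_x, focal_y):
        """Project one 3D Gaussian (mu, Sigma) to its 2D mean and covariance (EWA-style approximation)."""
        t = W[:3, :3] @ mu + W[:3, 3]              # center in camera space, W @ mu
        mu_2d = (K[:3, :3] @ (t / t[2]))[:2]       # perspective divide, then intrinsics

        tx, ty, tz = t.tolist()
        # Jacobian of the perspective projection, evaluated at t (local affine approximation)
        J = torch.tensor([[focal_x / tz, 0.0, -focal_x * tx / tz**2],
                          [0.0, focal_y / tz, -focal_y * ty / tz**2],
                          [0.0, 0.0, 0.0]])
        cov_cam = W[:3, :3] @ Sigma @ W[:3, :3].T  # rotate covariance into camera space
        Sigma_2d = (J @ cov_cam @ J.T)[:2, :2]
        return mu_2d, Sigma_2d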

The Gaussians are sorted by depth, and the pixel color is then obtained by volume rendering (alpha blending):
$$I = \sum_{i \in N} c_i\,\alpha_i^{2D} \prod_{j=1}^{i-1}\big(1-\alpha_j^{2D}\big)$$
$\alpha_i^{2D}$ is the 2D version of Eq. (1), with $\mu_i$, $\Sigma_i$, $x$ replaced by their pixel-space counterparts;
$c_i$ is the RGB color obtained by evaluating the SH with the view direction and the coefficients $h_i$.
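
For intuition, front-to-back alpha compositing at a single pixel can be sketched as follows (illustrative only; the actual rasterizer does this tile-wise on the GPU):

    import torch

    def composite_pixel(colors, alphas):
        """Front-to-back alpha compositing for one pixel.

        colors: (N, 3) colors of the depth-sorted Gaussians covering the pixel.
        alphas: (N,)   their 2D opacities, already evaluated at the pixel.
        """
        pixel = torch.zeros(3)
        transmittance = 1.0
        for c, a in zip(colors, alphas):
            pixel += transmittance * a * c
            transmittance = transmittance * (1.0 - a)
            if transmittance < 1e-4:    # early termination, as in 3DGS
                break
        return pixel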

Method

Point clouds serve as the input.
Spacetime Gaussian

  • Adds time to the 3D Gaussians.
  • Introduces a temporal radial basis function to encode temporal opacity, which effectively models the emergence and vanishing of scene content.
  • Uses time-conditioned parametric functions to model the motion and rotation of the 3D Gaussians.

For a spacetime point $(x, t)$, the opacity of an STG is:
$$\alpha_i(t) = \sigma_i(t)\exp\Big(-\tfrac{1}{2}(x-\mu_i(t))^T \Sigma_i(t)^{-1}(x-\mu_i(t))\Big)$$

Temporal radial basis function → temporal opacity

$$\sigma_i(t) = \sigma_i^{s}\,\exp\big(-s_i^{\tau}\,|t-\mu_i^{\tau}|^2\big)$$
$\mu_i^{\tau}$ is the temporal center; $s_i^{\tau}$ is the temporal scaling factor; $\sigma_i^{s}$ is the time-independent spatial opacity.
The temporal radial basis function is passed into the renderer:

    render_pkg = render(viewpoint_cam, gaussians, pipe, background,  
                    override_color=None,  
                    basicfunction=rbfbasefunction,     # temporal radial basis function
                    GRsetting=GRsetting, GRzer=GRzer)
    
    
    
    

Implementation details in the renderer:

    # pc: GaussianModel holding the spacetime Gaussians
    pointtimes = torch.ones((pc.get_xyz.shape[0], 1),
                            dtype=pc.get_xyz.dtype,
                            requires_grad=False, device="cuda") + 0

    # temporal center and temporal scale parameters
    trbfcenter = pc.get_trbfcenter
    trbfscale = pc.get_trbfscale
    pointopacity = pc.get_opacity    # time-independent spatial opacity

    trbfdistanceoffset = viewpoint_camera.timestamp * pointtimes - trbfcenter    # t - mu_tau
    trbfdistance = trbfdistanceoffset / torch.exp(trbfscale)                     # scale by the temporal scale factor
    trbfoutput = basicfunction(trbfdistance)

    opacity = pointopacity * trbfoutput    # temporal opacity
    pc.trbfoutput = trbfoutput

    # the temporal radial basis function itself: exp(-x^2)
    def trbfunction(x):
        return torch.exp(-1 * x.pow(2))
    
    
    
    

Time-conditioned function → motion & rotation

Polynomial motion trajectory

For each STG, a time-conditioned function models its motion.
$$\mu_i(t) = \sum_{k=0}^{n_p} b_{i,k}\,(t-\mu_i^{\tau})^k$$
$\mu_i(t)$ is the spatial position of the STG at time $t$; $\{b_{i,k}\}_{k=0}^{n_p}$ are the corresponding learnable polynomial coefficients.
Combining the temporal radial basis function with the time-conditioned polynomial motion trajectory lets a complex, long motion be represented by several simple motions.
$n_p = 3$ in this paper.

Implementation in the renderer:

    means3D = pc.get_xyz
    tforpoly = trbfdistanceoffset.detach()    # t - mu_tau
    # polynomial motion: mu_i(t) = mu_i + b1*(t - mu_tau) + b2*(t - mu_tau)^2 + b3*(t - mu_tau)^3
    means3D = (means3D
               + pc._motion[:, 0:3] * tforpoly
               + pc._motion[:, 3:6] * tforpoly * tforpoly
               + pc._motion[:, 6:9] * tforpoly * tforpoly * tforpoly)
    
    
    
    

Polynomial rotation

A real-valued quaternion is used to parameterize the rotation matrix $R_i$ in $\Sigma_i = R_i S_i S_i^T R_i^T$.
Analogous to the motion trajectory, the quaternion is represented with a polynomial:
$$q_i(t) = \sum_{k=0}^{n_q} c_{i,k}\,(t-\mu_i^{\tau})^k$$
$q_i(t)$ is the rotation of the STG at time $t$; $\{c_{i,k}\}_{k=0}^{n_q}$ are learnable polynomial coefficients.
$n_q = 1$ in this paper.
After converting $q_i(t)$ to the rotation matrix $R_i(t)$, the covariance $\Sigma_i(t)$ is obtained via $\Sigma_i(t) = R_i(t) S_i S_i^T R_i(t)^T$.

Implementation in the renderer:

    rotations = pc.get_rotation(tforpoly)    # t - mu_tau

    # linear (n_q = 1) quaternion polynomial, normalized by rotation_activation
    def get_rotation(self, delta_t):
        rotation = self._rotation + delta_t * self._omega
        self.delta_t = delta_t
        return self.rotation_activation(rotation)
    
    
    
    

Splatted Feature Rendering

To encode view- and time-dependent radiance both accurately and compactly, the authors change how per-Gaussian appearance is stored.
The features $f_i(t) \in \mathbb{R}^9$:
$$f_i(t) = \big[\,f_i^{base},\; f_i^{dir},\; (t-\mu_i^{\tau})\,f_i^{time}\,\big]$$
$f_i^{base} \in \mathbb{R}^3$ is the base RGB color; $f_i^{dir}, f_i^{time} \in \mathbb{R}^3$ encode information related to the view direction and time.
The features $f_i(t)$ replace the RGB color $c_i$ in $I = \sum_{i\in N} c_i\,\alpha_i^{2D} \prod_{j=1}^{i-1}(1-\alpha_j^{2D})$.

After splatting to image space, the splatted features at each pixel are split into $F^{base}$, $F^{dir}$, $F^{time}$.

Implementation in the renderer:

    colors_precomp = pc.get_features(tforpoly)    # per-Gaussian features, splatted in place of RGB colors
    rendered_image, radii, depth = rasterizer(
        means3D=means3D,                  # Gaussian centers (xyz)
        means2D=means2D,                  # screen-space points
        shs=shs,
        colors_precomp=colors_precomp,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp)

    # f_i(t) = [f_base, f_dir, (t - mu_tau) * f_time]
    def get_features(self, deltat):
        return torch.cat((self._features_dc, deltat * self._features_t), dim=1)
    
    
    
    

The final RGB color at each pixel is obtained after going through a 2-layer MLP $\Phi$:
$$I = F^{base} + \Phi(F^{dir}, F^{time}, r)$$
$r$ is the view direction at the pixel and is additionally concatenated with the features as input.

    rgbdecoder = Sandwich(9, 3)

    rendered_image = pc.rgbdecoder(rendered_image.unsqueeze(0),
                                   viewpoint_camera.rays,         # r, the view direction at the pixel
                                   viewpoint_camera.timestamp)
    rendered_image = rendered_image.squeeze(0)


    class Sandwich(nn.Module):
        def __init__(self, dim, outdim=3, bias=False):
            super(Sandwich, self).__init__()

            self.mlp1 = nn.Conv2d(12, 6, kernel_size=1, bias=bias)
            self.mlp2 = nn.Conv2d(6, 3, kernel_size=1, bias=bias)
            self.relu = nn.ReLU()

            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, input, rays, time=None):
            albedo, spec, timefeature = input.chunk(3, dim=1)        # F_base, F_dir, F_time
            specular = torch.cat([spec, timefeature, rays], dim=1)   # 3 + 3 + 6 = 12 channels
            specular = self.mlp1(specular)
            specular = self.relu(specular)
            specular = self.mlp2(specular)

            result = albedo + specular    # I = F_base + Phi(F_dir, F_time, r)
            result = self.sigmoid(result)
            return result
    
    
    
    

Sampling

Problem addressed: regions with only sparse Gaussians, and regions far from the training cameras, are hard to converge to high-quality rendering.

Approach

Find regions with large training errors → locate the center pixel of each region → restrict a depth range → take the ray through that center pixel → add new Gaussians along the ray

  • Sampling is performed only after the training loss has stabilized, which ensures the sampling is meaningful;
  • Regions with large errors are selected by aggregating the training errors over patches;
    ssimcurrent = ssim(image.detach(), gt_image.detach()).item()
    if ssimcurrent < 0.88:
        imageadjust = image / (torch.mean(image) + 0.01)
        gtadjust = gt_image / (torch.mean(gt_image) + 0.01)
        diff = torch.abs(imageadjust - gtadjust)
        diff = torch.sum(diff, dim=0)    # h, w

        # use the diff value near the middle of the sorted errors as the threshold
        diff_sorted, _ = torch.sort(diff.reshape(-1))
        numpixels = diff.shape[0] * diff.shape[1]
        threshold = diff_sorted[int(numpixels * opt.emsthr)].item()    # opt.emsthr = 0.6

        # mark the pixels with large errors
        outmask = diff > threshold
        kh, kw = 16, 16    # kernel size
        dh, dw = 16, 16    # vertical and horizontal stride
        # compute padding so the mask splits evenly into 16x16 patches
        idealh, idealw = int(image.shape[1] / dh + 1) * kw, int(image.shape[2] / dw + 1) * kw
        outmask = torch.nn.functional.pad(
            outmask,
            (0, idealw - outmask.shape[1], 0, idealh - outmask.shape[0]),
            mode='constant', value=0)

        # partition outmask into patches with a sliding window
        patches = outmask.unfold(0, kh, dh).unfold(1, kw, dw)

        # keep only the patches with a large enough error count
        dummypatch = torch.ones_like(patches)
        # sum the errors inside each patch
        patchessum = patches.sum(dim=(2, 3))
        patchesmusk = patchessum > kh * kh * 0.85
        patchesmusk = patchesmusk.unsqueeze(2).unsqueeze(3).repeat(1, 1, kh, kh).float()
        patches = dummypatch * patchesmusk
    
    
    
    
  • Find the center of each qualifying region and select its center pixels;
    midpatch = torch.ones_like(patches)

    # zero out the entries whose row and column indices are both even
    for i in range(0, kh, 2):
        for j in range(0, kw, 2):
            midpatch[:, :, i, j] = 0.0

    # keep only the center pattern of each patch
    centerpatches = patches * midpatch

    unfold_shape = patches.size()
    patches_orig = patches.view(unfold_shape)
    centerpatches_orig = centerpatches.view(unfold_shape)

    output_h = unfold_shape[0] * unfold_shape[2]
    output_w = unfold_shape[1] * unfold_shape[3]
    patches_orig = patches_orig.permute(0, 2, 1, 3).contiguous()
    centerpatches_orig = centerpatches_orig.permute(0, 2, 1, 3).contiguous()
    centermask = centerpatches_orig.view(output_h, output_w).float()    # H x W mask, 1 for error, 0 for no error
    # crop back to the original image size
    centermask = centermask[:image.shape[1], :image.shape[2]]

    errormask = patches_orig.view(output_h, output_w).float()    # H x W mask, 1 for error, 0 for no error
    errormask = errormask[:image.shape[1], :image.shape[2]]    # crop back to the original image size

    # ignore a 10% border so new points are only sampled near the image center
    H, W = centermask.shape

    offsetH = int(H / 10)
    offsetW = int(W / 10)

    centermask[0:offsetH, :] = 0.0
    centermask[:, 0:offsetW] = 0.0

    centermask[-offsetH:, :] = 0.0
    centermask[:, -offsetW:] = 0.0

    # pixel indices with large errors
    badindices = centermask.nonzero()
    
    
    
    
  • To avoid sampling over an overly large region, a depth range is specified and new Gaussians are sampled only within it. A coarse depth map rendered from the Gaussian centers defines this sampling depth range;
    render_pkg = render(viewpoint_cam, gaussians, pipe, background, override_color=None,
                        basicfunction=rbfbasefunction, GRsetting=GRsetting, GRzer=GRzer)

    depth = render_pkg["depth"]    # coarse depth map rendered from the Gaussian centers

    # clamp the depth map from below at its 70th percentile
    diff_sorted, _ = torch.sort(depth.reshape(-1))
    N = diff_sorted.shape[0]
    mediandepth = int(0.7 * N)
    mediandepth = diff_sorted[mediandepth]
    depth = torch.where(depth > mediandepth, depth, mediandepth)

    # add new Gaussians within the depth range (mediandepth, maxdepth]
    totalNnewpoints = gaussians.addgaussians(badindices,
                                             viewpoint_cam,
                                             depth,
                                             gt_image,
                                             numperay=opt.farray, ratioend=opt.rayends,
                                             depthmax=depthdict[viewpoint_cam.image_name],
                                             shuffle=(opt.shuffleems != 0))

    # pad the per-Gaussian bookkeeping tensors for the newly added points
    visibility_filter = torch.cat((visibility_filter, torch.zeros(totalNnewpoints).cuda(0)), dim=0)
    radii = torch.cat((radii, torch.zeros(totalNnewpoints).cuda(0)), dim=0)
    viewspace_point_tensor = torch.cat((viewspace_point_tensor, torch.zeros(totalNnewpoints, 3).cuda(0)), dim=0)
    
    
    
    
  • Sample along the camera rays of the pixels with large training errors;
    def addgaussians(self, baduvidx, viewpoint_cam, depthmap, gt_image, numperay=3, ratioend=2,
                     trbfcenter=0.5, depthmax=None, shuffle=False):
        def pix2ndc(v, S):
            return (v * 2.0 + 1.0) / S - 1.0

        # accumulators for the attributes of the new Gaussians
        new_xyz, new_features_dc, new_trbf_center, new_trbf_scale = [], [], [], []
        new_motion, new_omega, new_featuret = [], [], []

        rgbs = gt_image[:, baduvidx[:, 0], baduvidx[:, 1]]
        rgbs = rgbs.permute(1, 0)
        # should we add the feature dc with non zero values? direction feature
        featuredc = torch.cat((rgbs, torch.zeros_like(rgbs)), dim=1)

        depths = depthmap[:, baduvidx[:, 0], baduvidx[:, 1]]
        # only use depth map > 15
        depths = depths.permute(1, 0)

        # use the max local depth for the scene
        depths = torch.ones_like(depths) * depthmax

        maxx, minx = self.maxx, self.minx

        # baduvidx holds the pixel coordinates of the selected high-error pixels
        u = baduvidx[:, 0]    # height (y)
        v = baduvidx[:, 1]    # width (x)

        # depth ratios from self.raystart (0.7) to ratioend, one new point per ratio
        ratiaolist = torch.linspace(self.raystart, ratioend, numperay)
        for zscale in ratiaolist:
            ndcu, ndcv = pix2ndc(u, viewpoint_cam.image_height), pix2ndc(v, viewpoint_cam.image_width)
            if shuffle == True:
                randomdepth = torch.rand_like(depths) - 0.5    # -0.5 to 0.5
                # jitter the depth around the target value
                targetPz = (depths + depths / 10 * (randomdepth)) * zscale
            else:
                targetPz = depths * zscale    # depth in local camera space

            ndcu = ndcu.unsqueeze(1)
            ndcv = ndcv.unsqueeze(1)

            ndccamera = torch.cat((ndcv,
                                   ndcu,
                                   torch.ones_like(ndcu) * (1.0),
                                   torch.ones_like(ndcu)), dim=1)    # N, 4
            # unproject to camera coordinates (projectinverse: inverse projection matrix)
            localpointuv = ndccamera @ projectinverse.T

            # ray direction in camera space
            diretioninlocal = localpointuv / localpointuv[:, 3:]

            # ratio between the target depth and the ray's z coordinate
            rate = targetPz / diretioninlocal[:, 2:3]

            # point on the ray at the target depth
            localpoint = diretioninlocal * rate
            localpoint[:, -1] = 1

            # transform to world coordinates (camera2wold: camera-to-world matrix)
            worldpointH = localpoint @ camera2wold.T
            worldpoint = worldpointH / worldpointH[:, 3:]

            # target points on the ray, in world coordinates
            xyz = worldpoint[:, :3]

            # points whose x lies in (minx, maxx)
            xmask = torch.logical_and(xyz[:, 0] > minx, xyz[:, 0] < maxx)
            # union with its complement: effectively keep every point
            selectedmask = torch.logical_or(xmask, torch.logical_not(xmask))
            # store the target points on the ray
            new_xyz.append(xyz[selectedmask])
            new_features_dc.append(featuredc.cuda(0)[selectedmask])

            selectnumpoints = torch.sum(selectedmask).item()
            new_trbf_center.append(torch.rand((selectnumpoints, 1)).cuda())

            assert self.trbfslinit < 1
            new_trbf_scale.append(self.trbfslinit * torch.ones((selectnumpoints, 1), device="cuda"))
            new_motion.append(torch.zeros((selectnumpoints, 9), device="cuda"))
            new_omega.append(torch.zeros((selectnumpoints, 4), device="cuda"))
            new_featuret.append(torch.zeros((selectnumpoints, 3), device="cuda"))

        new_xyz = torch.cat(new_xyz, dim=0)
        new_rotation = torch.zeros((new_xyz.shape[0], 4), device="cuda")
        new_rotation[:, 1] = 0

        new_features_dc = torch.cat(new_features_dc, dim=0)
        new_opacity = inverse_sigmoid(0.1 * torch.ones_like(new_xyz[:, 0:1]))
        new_trbf_center = torch.cat(new_trbf_center, dim=0)
        new_trbf_scale = torch.cat(new_trbf_scale, dim=0)
        new_motion = torch.cat(new_motion, dim=0)
        new_omega = torch.cat(new_omega, dim=0)
        new_featuret = torch.cat(new_featuret, dim=0)

        # initialize scales from the distance to the nearest existing Gaussians
        tmpxyz = torch.cat((new_xyz, self._xyz), dim=0)
        dist2 = torch.clamp_min(distCUDA2(tmpxyz), 0.0000001)
        dist2 = dist2[:new_xyz.shape[0]]
        scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
        scales = torch.clamp(scales, -10, 1.0)
        new_scaling = scales

        # densify: append the new 3D Gaussians to the model
        self.densification_postfix(new_xyz, new_features_dc, new_opacity, new_scaling, new_rotation,
                                   new_trbf_center, new_trbf_scale, new_motion, new_omega, new_featuret)
    
    
    
    

Focusing on the marked pixels
Correctly rendered pixels are masked out and only the high-error pixels are kept, so the next iterations concentrate on optimizing them:

    # mask out correctly rendered pixels, keeping only the high-error ones
    gt_image = gt_image * errormask
    image = render_pkg["render"] * errormask

    torchvision.utils.save_image(gt_image,
                                 os.path.join(pathdir, "maskedudgt" + str(iteration) + ".png"))
    torchvision.utils.save_image(image,
                                 os.path.join(pathdir, "maskedrender" + str(iteration) + ".png"))
    
    
    
    

Advantages

Compared with spherical-harmonics encoding, the feature-based approach needs fewer parameters and renders faster.
The MLP $\Phi$ can also be dropped to speed up rendering further (the lite version).
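
A minimal sketch of the contrast (my own illustration based on the equation above, not the released code): removing $\Phi$ leaves $I = F^{base}$, i.e. the lite model reads the color directly from the splatted base channels.

    import torch

    def decode_full(splatted, rgbdecoder, rays, timestamp):
        # full model: I = F_base + Phi(F_dir, F_time, r), computed inside the 2-layer decoder
        return rgbdecoder(splatted.unsqueeze(0), rays, timestamp).squeeze(0)

    def decode_lite(splatted):
        # lite variant (assumed here): drop Phi and keep only the base color channels
        return splatted[:3]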

Loss

The loss is computed between the rendered image and the ground-truth image,
using an $L_1$ term and a D-SSIM term.
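
A minimal sketch of this combination, following the 3DGS convention; the weight `lambda_dssim = 0.2` and the differentiable `ssim` helper are assumptions borrowed from the 3DGS training loop:

    import torch

    def training_loss(image, gt_image, ssim, lambda_dssim=0.2):
        l1 = torch.abs(image - gt_image).mean()          # L1 term
        dssim = 1.0 - ssim(image, gt_image)              # D-SSIM term
        return (1.0 - lambda_dssim) * l1 + lambda_dssim * dssim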
