Advertisement

图片速览 NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis

阅读量:

目录

  • Nerf中的体积渲染(Volume Rendering)

    • 积分值估计
  • 相关细节

    • 获取采样点
    • 获取光线r的函数
    • 位置编码
      • 调用函数
      • 实现函数
paper code
https://arxiv.org/pdf/2003.08934 https://github.com/yenchenlin/nerf-pytorch
  • Nerf是一种使用神经网络进行渲染3维数据的方法,其基于一种非刚体渲染的方法(刚体渲染方法可见OpenGL)。
    在这里插入图片描述

NeRF 的神经网络部分(图中的(a)(b))用于从给定的光线信息(位置和方向)计算颜色和密度。该网络的输入为一条光线的空间数据:

复制代码
* **光线采样点** :由光线起点和方向确定光线,并随机采样得到的多个空间中三维点的位置。
* **光线方向** :每个像素的光线方向,即从相机到场景中的某个点的方向。

网络的输出为光线对应空间位置的颜色和密度数据:

复制代码
* **颜色** (RGB):每个光线的颜色(即该点的RGB值)。
* **密度** (sigma):表示光线通过该点时的物质密度。

图中的(c)为体积渲染,体积渲染(Volume Rendering)来从神经网络的输出生成图像。

Nerf中的体积渲染(Volume Rendering)

C(\mathbf{r})=\int_{t_n}^{t_f}T(t)\sigma(\mathbf{r}(t))\mathbf{c}(\mathbf{r}(t),\mathbf{d})dt, \mathrm{where} \ T(t)=\exp\left(-\int_{t_n}^t\sigma(\mathbf{r}(s))ds\right)

  • 体积密度\sigma(x)可以解释为射线在位置x处终止于无穷小粒子的微分概率。

    • 微分概率:对于一个连续型随机变量X,其概率密度函数为f(x)。在某个特定的点x上,微分概率可以表示为f(x)dx,其中dx代表极小的变化量。微分概率表示了在极小区间dx内,随机变量X取值在x附近的概率。
  • C(r) 在光线r通过后最终显示的颜色

    • r(t) = o+t*d ,o应该为光源坐标,d为光照方向,t为光运动距离
    • t_nt_f为光线范围
    • 函数T(t)表示沿着从t_nt的光线的累积透射率。累积透射率为光线剩余的概率,体积密度越大,积分值越小,累积透射率越小。
    • c(r(t),d)为空间位置的颜色

积分值估计

  • [t_n,t_f]中进行分层采样获取采样点的公式:
    \begin{aligned} t_i\sim\mathcal{U}\left[t_n+\frac{i-1}{N}(t_f-t_n),t_n+\frac{i}{N}(t_f-t_n)\right] \end{aligned}

  • 由此得到颜色渲染积分的估计公式:

\begin{aligned} \hat{C}(\mathbf{r}) & =\sum_{i=1}^NT_i(1-\exp(-\sigma_i\delta_i))\mathbf{c}_i,\mathrm{~where~}T_i=\exp\left(-\sum_{j=1}^{i-1}\sigma_j\delta_j\right) \end{aligned}

  • a_i=1-\exp(-\sigma_i\delta_i)

\begin{aligned} \hat{C}(\mathbf{r}) & =\sum_{i=1}^NT_i(a_i)\mathbf{c}_i \\ & = \sum_{i=1}^N \exp^{\left(-\sum_{j=1}^{i-1}\sigma_j\delta_j\right)} (a_i)\mathbf{c}_i \\ & = \sum_{i=1}^N \Pi_{j=1}^{i-1} \exp^{\left(-\sigma_j\delta_j\right)} (a_i)\mathbf{c}_i \\ & = \sum_{i=1}^N \Pi_{j=1}^{i-1} (1-a_j) *a_i*\mathbf{c}_i \\ & = c_0*a_0+c_1*a_i*(1-a_0)+c_2*a_2*(1-a_0)*(1-a_1)+……+c_N*a_N*(1-a_0)*(1-a_1)……(1-a_{N-1}) \end{aligned}

  • 代码实现:
复制代码
    def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """将模型的预测结果(多条射线,每条射线上多个采样点,)转换为最终的颜色和距离。
    参数:
        raw: [num_rays, num_samples along ray, 4] 来自模型的预测值,三维RGB+一维密度
        z_vals: [num_rays, num_samples along ray] 射线上采样点距离
        rays_d: [num_rays, 3] 每条光线的方向向量
    返回值:
        rgb_map: [num_rays, 3] 每条光线的估计RGB颜色。
        disp_map: [num_rays] 视差图,深度图的倒数。
        acc_map: [num_rays] 每条光线的权重和(透明度的累积值)
        weights: [num_rays, num_samples] 每个采样点的权重 (通过计算每个采样点的透明度并使用累积乘积方法来得到每个采样点对最终结果的贡献权重)
        depth_map: [num_rays] 每条光线的估计深度
    """
    
    # 定义函数,根据密度(raw的第三个通道)和距离计算点的不透明度(alpha)
    raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists)
    
    # 计算每条光线中相邻采样点的距离(公式中的delta)
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    
    # 将一个很大的值(1e10)追加到最后一个距离上,表示光线的末端(背景)
    dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1)  # [N_rays, N_samples]
    
    # 将距离乘以光线的方向向量的模长,以便考虑空间几何。
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
    
    # 对raw的前3个通道(RGB值)应用sigmoid函数,确保颜色值在[0, 1]范围内
    rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]
    
    # 初始化噪声值为0。噪声会在透明度计算时使用(如果raw_noise_std > 0)
    noise = 0.
    
    # 如果raw_noise_std > 0,则生成随机噪声,添加到透明度(alpha)中
    if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape) * raw_noise_std
    
        if pytest: # 若为pytest模式,使用固定的随机种子,以便调试时保持一致。
            np.random.seed(0)
            noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
            noise = torch.Tensor(noise)
    
    # 计算透明度(alpha),使用raw的密度值(并加上噪声)以及距离来计算
    alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]
    
    # 计算每个采样点的权重,权重是不透明度(alpha)与前面所有采样点透明度(1-alpha)的累积乘积
    # torch.cumprod函数用于计算张量(Tensor)在指定维度上的元素的累积乘积
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:, :-1]
    
    # 计算每条光线的RGB颜色,将所有采样点的RGB值按照权重求和,这里就是公式中hat_C的估计值
    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]
    
    # 计算每条光线的深度图,通过将权重与深度(z_vals)相乘得到。
    depth_map = torch.sum(weights * z_vals, -1)
    
    # 计算视差图,它是深度图的倒数,避免除以零。
    disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
    
    # 计算每条光线的累积透明度(权重和),即所有采样点的权重之和。
    acc_map = torch.sum(weights, -1)
    
    # 如果背景是白色(white_bkgd = True),则将背景的颜色(1 - acc_map)加到最终的RGB颜色上。
    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])
    
    # 返回最终的RGB颜色图、视差图、累积透明度图、每个采样点的权重和深度图。
    return rgb_map, disp_map, acc_map, weights, depth_map
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/VezyI4QkdM2XUSauCBt9NlqpvD7r.png)

相关细节

获取采样点

  • 分层采样(Hierarchical Sampling),用于通过累积分布函数(CDF)从概率密度函数(PDF)中进行采样
复制代码
    # Hierarchical sampling (section 5.2)
    def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # Get pdf
    weights = weights + 1e-5 # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))
    
    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
    
    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)
    
    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)
    
    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
    
    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
    
    return samples
    
    
    c
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/3LNgwEiqhzKOAVP7yCcJBW0sefkd.png)
  • 相关调用:
    • z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
    • pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]#为最终的输入点数据

获取光线r的函数

  • get_rays()函数:根据相机模型获取从相机中心位置点到图片上的像素点的射线:
复制代码
    # Ray helpers
    def get_rays(H, W, K, c2w):
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'
    i = i.t()
    j = j.t()
    dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3,-1].expand(rays_d.shape)
    return rays_o, rays_d # 光线坐标原点,方向
    
    
    c
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/rO2etcTAXBPR8lupImKhnYLqgzNG.png)
  • 实际训练时会随机在图片中选择一个batch数量的射线select_coords = coords[select_inds].long() # (N_rand, 2)组成训练batch
    • 然后在每条射线上采样64个点,通过这些点对使用求积法对渲染公式中的连续积分C(r)进行数值估计。

位置编码

  • 用于对输入位置和方向数据进行编码,防止输入过平滑

\gamma(p)=\left(\sin\left(2^0\pi p\right),\cos\left(2^0\pi p\right),\cdots,\sin\left(2^{L-1}\pi p\right),\cos\left(2^{L-1}\pi p\right)\right)

  • L为超参数,对应代码中的multires
  • p为归一化到[-1,1]的位置坐标x或空间方向坐标d

调用函数

复制代码
    # inputs torch.Size([1024, 64, 3]), 64为采样点的个数
    # viewdirs torch.Size([1024, 3])
    # fn NeRF()
    # embed_fn & embeddirs_fn 位置编码,输入数据预处理编码
    def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """Prepares inputs and applies network 'fn'.
    """
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) # torch.Size([65536, 3])
    embedded = embed_fn(inputs_flat) # torch.Size([65536, 63]),L=10,63=3[]*(10[sin]+10[cos])+3[原始数据]
    
    if viewdirs is not None:
        input_dirs = viewdirs[:,None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)# torch.Size([65536, 63])
        embedded = torch.cat([embedded, embedded_dirs], -1)
    
    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs
    
    
    c
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/8IjMASCgOkFolsDQZu56XaHrcLb3.png)

实现函数

复制代码
    # Positional encoding (section 5.1)
    class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()
        
    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x : x)
            out_dim += d
            
        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']
        
        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else: # 2^L
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
            
        for freq in freq_bands: # 'periodic_fns' : [torch.sin, torch.cos]
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                out_dim += d
                    
        self.embed_fns = embed_fns # 21
        self.out_dim = out_dim # 63
        
    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
    
    
    def get_embedder(multires, i=0):
    if i == -1:
        return nn.Identity(), 3
    
    embed_kwargs = {
                'include_input' : True,
                'input_dims' : 3,
                'max_freq_log2' : multires-1,
                'num_freqs' : multires,
                'log_sampling' : True,
                'periodic_fns' : [torch.sin, torch.cos],
    }
    
    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj : eo.embed(x)
    return embed, embedder_obj.out_dim
    
    
    c
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-16/RLahOqJbjSMCX4KBlw9p1uEIWYx6.png)

全部评论 (0)

还没有任何评论哟~