Advertisement

07. 贝叶斯神经网络

阅读量:

算法思路

普通的神经网络的权值是确定的,而贝叶斯神经网络的权值是不确定的,他服从于一个概率分布,这便是贝叶斯神经网络和普通神经网络的差别。

可以简单认为,贝叶斯神经网络是无穷个神经网络的融合,不过给每个神经网络标上一个 重要度 而已。

普通神经网络训练其权重时,采用的方法无外乎两种:最大似然估计最大后验估计 。即:
\begin{aligned} w^{MLE} &= \arg\max_{w} \log P(D|w)\\ &= \arg\max_w\sum_i\log P(y_i|x_i,w) \end{aligned}
\begin{aligned} w^{MAP}&=\arg\max_w\log P(w|D)\\ &= \arg\max_w\log P(D|w)+\log P(w) - P(D)\\ &= \arg\max_w\log P(D|w)+\log P(w) \end{aligned}
普通的神经网络就是这样求最值,其每个节点都是一个确定的值,w^\star
但是神经网络是在假设P(w)为先验估计的条件下,直接求解P(w|D),每个节点都是一个概率分布。

模型推导

我们的最终目标是求出,对于输入x,输出y的概率分布,即P(y|x),显然:
P(y|x)=\mathbb E_{P(w|D)}P(y|x,w)

那么当我们预测神经网络的时候,并不是真真正正的通过概率算出来的最后结果的分布。
而是通过多次采样,估计出最后结果的值。

所以我们的关键是求出P(w|D),之后我们才可以对w进行采样。
根据贝叶斯公式:
P(w|D)=\frac{P(D|w)P(w)}{P(D)}
这个式子中P(w)是先验,P(D|w)是后验,这些都是可求的。
但是P(D)或为积分,或为累和,这个东西是难求的。

这里一开始我有疑问,为什么这里不可以对w进行采样,然后算出P(D)的概率。
后来想想确实不能,因为蒙特卡洛求P(D)之后,我们仍然需要蒙特卡洛,双蒙特卡洛的可靠性着实不敢相信。

所以引入变分估计,引入q(w|\theta)估计P(w|D),那么我们的优化目标就为:
\begin{aligned} \theta^\star&=\arg\min_\theta D_{KL}[q(w|\theta)\|P(w|D)]\\ &=\arg\min_\theta \mathbb{E}_{q(w|\theta)}\log\frac{q(w|\theta)}{P(w|D)}\\ &=\arg\min_\theta \mathbb{E}_{q(w|\theta)}\log\frac{q(w|\theta)P(D)}{P(D|w)P(w)}\\ &=\arg\min_\theta \mathbb{E}_{q(w|\theta)}\log q(w|\theta)-\mathbb{E}_{q(w|\theta)}\log P(w)-\mathbb{E}_{q(w|\theta)}\log P(D|w) \end{aligned}
中间有一步骤正好略去常数P(D),因为其取值和\theta无关。从而式子变得稍微更好解了一点。
设目标函数\mathcal F(D,\theta)为:
\mathcal F(D,\theta)= \mathbb{E}_{q(w|\theta)}\log q(w|\theta)-\mathbb{E}_{q(w|\theta)}\log P(w)-\mathbb{E}_{q(w|\theta)}\log P(D|w)
注意到他们都是q(w|\theta)的期望,所以我们直接蒙特卡洛采样,求解\mathcal F(D,\theta)近似值为:
\mathcal F(D,\theta)\approx \frac{1}{N}\sum_{i=1}^N\log q(w^{(i)}|\theta)-\log P(w^{(i)})-\log P(D|w^{(i)})

采样的时候不能直接\mathcal N(\mu,\sigma^2).sample(shape)这样直接采样,这样\mu\sigma是不可导的。我们可以\mu+\sigma*\mathcal N(0,1).sample(),这样就是可导了。

怎么求导直接扔给框架就行了。

算法流程(demo)

假设w的概率服从于大概两个正太分布叠加的形式;
假设先验为w\sim\pi\mathcal{N}(\mu_1,\sigma_1^2)+(1-\pi)\mathcal{N}(\mu_2,\sigma_2^2)
w分布是有点复杂的,所以我们用\mathcal N(\mu,\sigma^2)去近似真实的分布;
优化目标自然是\mathcal N(\mu,\sigma^2)中的\mu,\sigma
对于每次迭代,需要多次采样计算\mathcal F(D,[\mu,\sigma])的均值;
对于每次采样,对\mathcal N(\mu,\sigma^2)进行采样得到w,b;下面之说w,因为b也一模一样。
此时先验P(w)等于w\pi\mathcal{N}(\mu_1,\sigma_1^2)+(1-\pi)\mathcal{N}(\mu_2,\sigma_2^2)中的概率密度,q(w|\theta)等于w\mathcal N(\mu,\sigma^2)中的概率密度;P(D|w)是输出值\hat y\mathcal N(y,\hat\sigma^2)中的概率密度。
上面我们又引入一个\hat\sigma,我感觉对于训练最终结果没有太大影响,因为不影响梯度的最终方向。但是应该稍微影响训练速度,什么的,其具体值个人觉得可以随便给。

算法实现

复制代码
    import torch
    import torch.nn as nn
    from torch import exp
    from torch import optim
    from torch import tensor
    import torch.nn.functional as F
    from torch.distributions import Normal
    from torch.nn.functional import softplus
    import numpy as np
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    
    
    class Layer(nn.Module):
    def __init__(self, input_features, output_features, prior_rho_1=1., prior_rho_2=1., prior_pi=0.5):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
    
        self.weight_mu = nn.Parameter(torch.zeros(output_features, input_features))
        self.weight_rho = nn.Parameter(torch.zeros(output_features, input_features))
    
        self.bias_mu = nn.Parameter(torch.zeros(output_features))
        self.bias_rho = nn.Parameter(torch.zeros(output_features))
    
        self.log_prior, self.log_post = None, None
    
        self.prior1 = torch.distributions.Normal(0, prior_rho_1)
        self.prior2 = torch.distributions.Normal(0, prior_rho_2)
        self.prior_pi_1, self.prior_pi_2 = np.log(prior_pi), np.log(1 - prior_pi)
    
    def forward(self, inputs):
        weight = self.weight_mu + softplus(self.weight_rho) * Normal(0, 1).sample(self.weight_mu.shape)
        bias = self.bias_mu + softplus(self.bias_rho) * Normal(0, 1).sample(self.bias_mu.shape)
    
        w_log_prior = torch.log(exp(self.prior_pi_1 + self.prior1.log_prob(weight)) + exp(self.prior_pi_2 + self.prior2.log_prob(weight)))
        b_log_prior = torch.log(exp(self.prior_pi_1 + self.prior1.log_prob(bias)) + exp(self.prior_pi_2 + self.prior2.log_prob(bias)))
        self.log_prior = torch.sum(w_log_prior) + torch.sum(b_log_prior)
    
        w_post = Normal(self.weight_mu.data, softplus(self.weight_rho))
        b_post = Normal(self.bias_mu.data, softplus(self.bias_rho))
        self.log_post = w_post.log_prob(weight).sum() + b_post.log_prob(bias).sum()
    
        return F.linear(inputs, weight, bias)
    
    
    class BNN(nn.Module):
    def __init__(self, noise_tol=.1,  prior_var=1.):
        super().__init__()
        self.hidden_layer_1 = Layer(1, 32, prior_var)
        self.output_layer = Layer(32, 1, prior_var)
        self.noise_tol = noise_tol
    
        self.log_prior, self.log_post = None, None
    
    def forward(self, x):
        x = torch.relu(self.hidden_layer_1(x))
        x = self.output_layer(x)
    
        self.log_prior = self.hidden_layer_1.log_prior + self.output_layer.log_prior
        self.log_post = self.hidden_layer_1.log_post + self.output_layer.log_post
    
        return x
    
    def get_loss(self, inputs, target):
        outputs = self(inputs).reshape(-1)
        return self.log_post - self.log_prior - Normal(outputs, self.noise_tol).log_prob(target.reshape(-1)).sum()
    
    
    def toy_fun(x):
    return 10 * np.sin(2 * np.pi * x)
    
    
    def load_data(num_samples):
    x = tensor(np.linspace(-0.8, 0.8, num_samples).reshape(-1, 1)).float()
    y = toy_fun(x) + np.random.randn(num_samples, 1)
    return x, y
    
    
    if __name__ == '__main__':
    # 导入数据
    data, label = load_data(32)
    # 导入模型
    net = BNN(prior_var=10)
    # 构造优化器
    optimizer = optim.Adam(net.parameters(), lr=.1)
    for epoch in tqdm(range(2000)):
        optimizer.zero_grad()
        loss = net.get_loss(data, label)
        loss.backward()
        optimizer.step()
    
    samples, dpi = 100, 100
    x_tmp = torch.linspace(-1.2, 1.2, dpi).reshape(-1, 1)
    y_samp = np.zeros((samples, dpi))
    for s in range(samples):
        y_tmp = net(x_tmp).detach().numpy()
        y_samp[s] = y_tmp.reshape(-1)
    plt.plot(x_tmp.numpy(), np.mean(y_samp, axis=0), label='Mean Posterior Predictive')
    plt.fill_between(x_tmp.numpy().reshape(-1), np.percentile(y_samp, 2.5, axis=0), np.percentile(y_samp, 97.5, axis=0), alpha=0.25, label='95% Confidence')
    plt.legend()
    plt.scatter(data, toy_fun(data))
    plt.title('Posterior Predictive')
    plt.show()
在这里插入图片描述

全部评论 (0)

还没有任何评论哟~