07. 贝叶斯神经网络
算法思路
普通的神经网络的权值是确定的,而贝叶斯神经网络的权值是不确定的,他服从于一个概率分布,这便是贝叶斯神经网络和普通神经网络的差别。
可以简单认为,贝叶斯神经网络是无穷个神经网络的融合,不过给每个神经网络标上一个 重要度 而已。
普通神经网络训练其权重时,采用的方法无外乎两种:最大似然估计 和最大后验估计 。即:
\begin{aligned} w^{MLE} &= \arg\max_{w} \log P(D|w)\\ &= \arg\max_w\sum_i\log P(y_i|x_i,w) \end{aligned}
\begin{aligned} w^{MAP}&=\arg\max_w\log P(w|D)\\ &= \arg\max_w\log P(D|w)+\log P(w) - P(D)\\ &= \arg\max_w\log P(D|w)+\log P(w) \end{aligned}
普通的神经网络就是这样求最值,其每个节点都是一个确定的值,w^\star。
但是神经网络是在假设P(w)为先验估计的条件下,直接求解P(w|D),每个节点都是一个概率分布。
模型推导
我们的最终目标是求出,对于输入x,输出y的概率分布,即P(y|x),显然:
P(y|x)=\mathbb E_{P(w|D)}P(y|x,w)
那么当我们预测神经网络的时候,并不是真真正正的通过概率算出来的最后结果的分布。
而是通过多次采样,估计出最后结果的值。
所以我们的关键是求出P(w|D),之后我们才可以对w进行采样。
根据贝叶斯公式:
P(w|D)=\frac{P(D|w)P(w)}{P(D)}
这个式子中P(w)是先验,P(D|w)是后验,这些都是可求的。
但是P(D)或为积分,或为累和,这个东西是难求的。
这里一开始我有疑问,为什么这里不可以对w进行采样,然后算出P(D)的概率。
后来想想确实不能,因为蒙特卡洛求P(D)之后,我们仍然需要蒙特卡洛,双蒙特卡洛的可靠性着实不敢相信。
所以引入变分估计,引入q(w|\theta)估计P(w|D),那么我们的优化目标就为:
\begin{aligned} \theta^\star&=\arg\min_\theta D_{KL}[q(w|\theta)\|P(w|D)]\\ &=\arg\min_\theta \mathbb{E}_{q(w|\theta)}\log\frac{q(w|\theta)}{P(w|D)}\\ &=\arg\min_\theta \mathbb{E}_{q(w|\theta)}\log\frac{q(w|\theta)P(D)}{P(D|w)P(w)}\\ &=\arg\min_\theta \mathbb{E}_{q(w|\theta)}\log q(w|\theta)-\mathbb{E}_{q(w|\theta)}\log P(w)-\mathbb{E}_{q(w|\theta)}\log P(D|w) \end{aligned}
中间有一步骤正好略去常数P(D),因为其取值和\theta无关。从而式子变得稍微更好解了一点。
设目标函数\mathcal F(D,\theta)为:
\mathcal F(D,\theta)= \mathbb{E}_{q(w|\theta)}\log q(w|\theta)-\mathbb{E}_{q(w|\theta)}\log P(w)-\mathbb{E}_{q(w|\theta)}\log P(D|w)
注意到他们都是q(w|\theta)的期望,所以我们直接蒙特卡洛采样,求解\mathcal F(D,\theta)近似值为:
\mathcal F(D,\theta)\approx \frac{1}{N}\sum_{i=1}^N\log q(w^{(i)}|\theta)-\log P(w^{(i)})-\log P(D|w^{(i)})
采样的时候不能直接\mathcal N(\mu,\sigma^2).sample(shape)这样直接采样,这样\mu和\sigma是不可导的。我们可以\mu+\sigma*\mathcal N(0,1).sample(),这样就是可导了。
怎么求导直接扔给框架就行了。
算法流程(demo)
假设w的概率服从于大概两个正太分布叠加的形式;
假设先验为w\sim\pi\mathcal{N}(\mu_1,\sigma_1^2)+(1-\pi)\mathcal{N}(\mu_2,\sigma_2^2);
w分布是有点复杂的,所以我们用\mathcal N(\mu,\sigma^2)去近似真实的分布;
优化目标自然是\mathcal N(\mu,\sigma^2)中的\mu,\sigma;
对于每次迭代,需要多次采样计算\mathcal F(D,[\mu,\sigma])的均值;
对于每次采样,对\mathcal N(\mu,\sigma^2)进行采样得到w,b;下面之说w,因为b也一模一样。
此时先验P(w)等于w在\pi\mathcal{N}(\mu_1,\sigma_1^2)+(1-\pi)\mathcal{N}(\mu_2,\sigma_2^2)中的概率密度,q(w|\theta)等于w在\mathcal N(\mu,\sigma^2)中的概率密度;P(D|w)是输出值\hat y在\mathcal N(y,\hat\sigma^2)中的概率密度。
上面我们又引入一个\hat\sigma,我感觉对于训练最终结果没有太大影响,因为不影响梯度的最终方向。但是应该稍微影响训练速度,什么的,其具体值个人觉得可以随便给。
算法实现
import torch
import torch.nn as nn
from torch import exp
from torch import optim
from torch import tensor
import torch.nn.functional as F
from torch.distributions import Normal
from torch.nn.functional import softplus
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
class Layer(nn.Module):
def __init__(self, input_features, output_features, prior_rho_1=1., prior_rho_2=1., prior_pi=0.5):
super().__init__()
self.input_features = input_features
self.output_features = output_features
self.weight_mu = nn.Parameter(torch.zeros(output_features, input_features))
self.weight_rho = nn.Parameter(torch.zeros(output_features, input_features))
self.bias_mu = nn.Parameter(torch.zeros(output_features))
self.bias_rho = nn.Parameter(torch.zeros(output_features))
self.log_prior, self.log_post = None, None
self.prior1 = torch.distributions.Normal(0, prior_rho_1)
self.prior2 = torch.distributions.Normal(0, prior_rho_2)
self.prior_pi_1, self.prior_pi_2 = np.log(prior_pi), np.log(1 - prior_pi)
def forward(self, inputs):
weight = self.weight_mu + softplus(self.weight_rho) * Normal(0, 1).sample(self.weight_mu.shape)
bias = self.bias_mu + softplus(self.bias_rho) * Normal(0, 1).sample(self.bias_mu.shape)
w_log_prior = torch.log(exp(self.prior_pi_1 + self.prior1.log_prob(weight)) + exp(self.prior_pi_2 + self.prior2.log_prob(weight)))
b_log_prior = torch.log(exp(self.prior_pi_1 + self.prior1.log_prob(bias)) + exp(self.prior_pi_2 + self.prior2.log_prob(bias)))
self.log_prior = torch.sum(w_log_prior) + torch.sum(b_log_prior)
w_post = Normal(self.weight_mu.data, softplus(self.weight_rho))
b_post = Normal(self.bias_mu.data, softplus(self.bias_rho))
self.log_post = w_post.log_prob(weight).sum() + b_post.log_prob(bias).sum()
return F.linear(inputs, weight, bias)
class BNN(nn.Module):
def __init__(self, noise_tol=.1, prior_var=1.):
super().__init__()
self.hidden_layer_1 = Layer(1, 32, prior_var)
self.output_layer = Layer(32, 1, prior_var)
self.noise_tol = noise_tol
self.log_prior, self.log_post = None, None
def forward(self, x):
x = torch.relu(self.hidden_layer_1(x))
x = self.output_layer(x)
self.log_prior = self.hidden_layer_1.log_prior + self.output_layer.log_prior
self.log_post = self.hidden_layer_1.log_post + self.output_layer.log_post
return x
def get_loss(self, inputs, target):
outputs = self(inputs).reshape(-1)
return self.log_post - self.log_prior - Normal(outputs, self.noise_tol).log_prob(target.reshape(-1)).sum()
def toy_fun(x):
return 10 * np.sin(2 * np.pi * x)
def load_data(num_samples):
x = tensor(np.linspace(-0.8, 0.8, num_samples).reshape(-1, 1)).float()
y = toy_fun(x) + np.random.randn(num_samples, 1)
return x, y
if __name__ == '__main__':
# 导入数据
data, label = load_data(32)
# 导入模型
net = BNN(prior_var=10)
# 构造优化器
optimizer = optim.Adam(net.parameters(), lr=.1)
for epoch in tqdm(range(2000)):
optimizer.zero_grad()
loss = net.get_loss(data, label)
loss.backward()
optimizer.step()
samples, dpi = 100, 100
x_tmp = torch.linspace(-1.2, 1.2, dpi).reshape(-1, 1)
y_samp = np.zeros((samples, dpi))
for s in range(samples):
y_tmp = net(x_tmp).detach().numpy()
y_samp[s] = y_tmp.reshape(-1)
plt.plot(x_tmp.numpy(), np.mean(y_samp, axis=0), label='Mean Posterior Predictive')
plt.fill_between(x_tmp.numpy().reshape(-1), np.percentile(y_samp, 2.5, axis=0), np.percentile(y_samp, 97.5, axis=0), alpha=0.25, label='95% Confidence')
plt.legend()
plt.scatter(data, toy_fun(data))
plt.title('Posterior Predictive')
plt.show()

