
[Deep Learning] Semi-Supervised and Unsupervised Learning: the Variational Auto-Encoder (VAE), with Code

Paper: "Auto-Encoding Variational Bayes"

Paper URL: https://arxiv.org/pdf/1312.6114.pdf

Code:

Keras version: https://github.com/bojone/vae

PyTorch version: https://colab.research.google.com/github/smartgeometry-ucl/dl4g/blob/master/variational_autoencoder.ipynb

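There are plenty of VAE tutorials on blogs, but few of them explain the model clearly, and even fewer tie the explanation to code.

"Talk is cheap, show me the code."

Here is a well-written analysis of the VAE by another blogger: https://spaces.ac.cn/archives/5253

That post lines up closely with the code below, so I see no need to re-explain the many formulas behind the VAE here.

First, import the packages and set the hyperparameters: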

    import os

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # 2-d latent space, parameter count in same order of magnitude
    # as in the original VAE paper (VAE paper has about 3x as many)
    latent_dims = 2
    num_epochs = 100
    batch_size = 128
    capacity = 64           # number of channels in the first conv layer
    learning_rate = 1e-3
    variational_beta = 1    # weight of the KL term in the loss
    use_gpu = True

Load the MNIST dataset:

    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST

    img_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

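Define the VAE architecture. The overall structure closely resembles a plain autoencoder, but the encoder learns the mean and variance of the latent variable (parameterized in the code as the log-variance), and the latent variable is then sampled from them. Note that the encoder has two separate fully connected layers, one for the mean and one for the log-variance.

The latent_sample function implements the reparameterization trick: sample ε from N(0, I), then set z = μ + ε × σ, where σ = exp(0.5 × logvar). In evaluation mode it skips the sampling and returns μ.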

    class Encoder(nn.Module):
        def __init__(self):
            super(Encoder, self).__init__()
            c = capacity
            self.conv1 = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=4, stride=2, padding=1) # out: c x 14 x 14
            self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1) # out: 2c x 7 x 7
            self.fc_mu = nn.Linear(in_features=c*2*7*7, out_features=latent_dims)
            self.fc_logvar = nn.Linear(in_features=c*2*7*7, out_features=latent_dims)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = x.view(x.size(0), -1) # flatten batch of multi-channel feature maps to a batch of feature vectors
            x_mu = self.fc_mu(x)
            x_logvar = self.fc_logvar(x)
            return x_mu, x_logvar

    class Decoder(nn.Module):
        def __init__(self):
            super(Decoder, self).__init__()
            c = capacity
            self.fc = nn.Linear(in_features=latent_dims, out_features=c*2*7*7)
            self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1) # out: c x 14 x 14
            self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1) # out: 1 x 28 x 28

        def forward(self, x):
            x = self.fc(x)
            x = x.view(x.size(0), capacity*2, 7, 7) # unflatten batch of feature vectors to a batch of multi-channel feature maps
            x = F.relu(self.conv2(x))
            x = torch.sigmoid(self.conv1(x)) # last layer before output is sigmoid, since we are using BCE as reconstruction loss
            return x

    class VariationalAutoencoder(nn.Module):
        def __init__(self):
            super(VariationalAutoencoder, self).__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, x):
            latent_mu, latent_logvar = self.encoder(x)
            latent = self.latent_sample(latent_mu, latent_logvar)
            x_recon = self.decoder(latent)
            return x_recon, latent_mu, latent_logvar

        def latent_sample(self, mu, logvar):
            if self.training:
                # the reparameterization trick
                std = logvar.mul(0.5).exp_()
                eps = torch.empty_like(std).normal_()
                return eps.mul(std).add_(mu)
            else:
                return mu
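To see why the reparameterization trick in `latent_sample` matters, it helps to run it in isolation and inspect the gradients. The sketch below is not part of the original notebook. `reparameterize` is a hypothetical standalone copy of the same computation, written without in-place operations.

```python
import torch

def reparameterize(mu, logvar):
    """Draw z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(0.5 * logvar)."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

torch.manual_seed(0)
mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)

z = reparameterize(mu, logvar)
z.sum().backward()

# The randomness lives only in eps, so z is a differentiable function
# of mu and logvar: dz/dmu is 1 everywhere, dz/dlogvar is 0.5 * eps * std.
print(mu.grad)
print(logvar.grad)
```

Because the noise is drawn outside the learnable path, backpropagation can reach the encoder outputs, which would not be possible if z were sampled directly from N(μ, σ²).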

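Define the loss function. The first term measures how closely the output reconstructs the input. The second term uses the KL divergence to measure how far the latent distribution produced by the encoder is from the prior over latent variables, N(0, I).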

    def vae_loss(recon_x, x, mu, logvar):
        # recon_x is the probability of a multivariate Bernoulli distribution p.
        # -log(p(x)) is then the pixel-wise binary cross-entropy.
        # Averaging or not averaging the binary cross-entropy over all pixels here
        # is a subtle detail with big effect on training, since it changes the weight
        # we need to pick for the other loss term by several orders of magnitude.
        # Not averaging is the direct implementation of the negative log likelihood,
        # but averaging makes the weight of the other loss term independent of the image resolution.
        recon_loss = F.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')

        # KL-divergence between the prior distribution over latent vectors
        # (the one we are going to sample from when generating new images)
        # and the distribution estimated by the encoder for the given image.
        kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return recon_loss + variational_beta * kldivergence
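The KL term in `vae_loss` is the closed form for the divergence between a diagonal Gaussian N(μ, σ²) and the standard normal prior, namely -0.5 × Σ(1 + log σ² - μ² - σ²). As a sanity check, the sketch below compares that expression with the value computed by `torch.distributions`. The random `mu` and `logvar` tensors are made-up inputs, not encoder outputs.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 2)       # hypothetical batch of latent means
logvar = torch.randn(8, 2)   # hypothetical batch of latent log-variances

# Closed form used in vae_loss, summed over batch and latent dimensions.
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference: KL(q || p) with q = N(mu, sigma^2) and p = N(0, 1), element-wise.
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(q, p).sum()

print(kl_closed.item(), kl_reference.item())
```

The two numbers should agree up to floating-point error, which confirms that the sign and the 0.5 factor in the loss are consistent with treating `logvar` as log σ².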

Instantiate a VAE with the architecture above and check how many trainable parameters it has.

    vae = VariationalAutoencoder()

    device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
    vae = vae.to(device)

    num_params = sum(p.numel() for p in vae.parameters() if p.requires_grad)
    print('Number of parameters: %d' % num_params)
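The count printed above can be reproduced by hand, which is a useful check that the layer shapes are what the comments claim. The sketch below uses only plain Python arithmetic. The helper names (`conv_out`, `deconv_out`, `conv_params`, `linear_params`) are hypothetical, and the formulas assume dilation 1 and no output padding, as in the layers defined earlier.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def deconv_out(size, kernel=4, stride=2, padding=1):
    """Spatial output size of a ConvTranspose2d layer (dilation 1, no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel

def conv_params(c_in, c_out, kernel=4):
    """Weights plus biases of a (transposed) convolution."""
    return c_in * c_out * kernel * kernel + c_out

def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

c, latent_dims = 64, 2      # capacity and latent size from the hyperparameters
flat = 2 * c * 7 * 7        # flattened encoder feature size

# Spatial sizes: 28 -> 14 -> 7 in the encoder, 7 -> 14 -> 28 in the decoder.
sizes = [conv_out(28), conv_out(conv_out(28)), deconv_out(7), deconv_out(deconv_out(7))]

encoder = (conv_params(1, c) + conv_params(c, 2 * c)
           + 2 * linear_params(flat, latent_dims))   # fc_mu and fc_logvar
decoder = (linear_params(latent_dims, flat)
           + conv_params(2 * c, c) + conv_params(c, 1))

print(sizes)               # [14, 7, 14, 28]
print(encoder + decoder)   # 308357
```

With capacity = 64 and latent_dims = 2 this gives 157,380 parameters for the encoder and 150,977 for the decoder.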
    
    

Train the VAE:

    optimizer = torch.optim.Adam(params=vae.parameters(), lr=learning_rate, weight_decay=1e-5)

    # set to training mode
    vae.train()

    train_loss_avg = []

    print('Training ...')
    for epoch in range(num_epochs):
        train_loss_avg.append(0)
        num_batches = 0

        for image_batch, _ in train_dataloader:

            image_batch = image_batch.to(device)

            # vae reconstruction
            image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

            # total loss: reconstruction error plus the weighted KL term
            loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()

            # one step of the optimizer (using the gradients from backpropagation)
            optimizer.step()

            train_loss_avg[-1] += loss.item()
            num_batches += 1

        # average of the summed per-batch losses over the epoch
        train_loss_avg[-1] /= num_batches
        print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))
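One thing to keep in mind when reading the numbers printed by the training loop: `vae_loss` sums over all pixels and all samples in a batch, so the reported value scales with the batch size and the image resolution. The sketch below illustrates the size of that effect on made-up data (random `recon` and `target` tensors, not real MNIST batches) and shows one way to turn the sum into a per-sample figure.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_pixels = 128, 784

# Hypothetical targets and reconstructions in (0, 1).
target = torch.rand(batch_size, num_pixels)
recon = torch.rand(batch_size, num_pixels).clamp(1e-3, 1 - 1e-3)

bce_sum = F.binary_cross_entropy(recon, target, reduction='sum')
bce_mean = F.binary_cross_entropy(recon, target, reduction='mean')

# The two reductions differ by a factor of batch_size * num_pixels.
ratio = bce_sum.item() / bce_mean.item()
print(ratio, batch_size * num_pixels)

# Dividing the summed loss by the batch size gives a per-image value
# that stays comparable across different batch sizes.
per_sample = bce_sum / batch_size
print(per_sample.item())
```

If the reconstruction term were averaged instead of summed, `variational_beta` would need to be rescaled accordingly to keep the same balance between the two loss terms.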

Plot the loss curve:

    import matplotlib.pyplot as plt
    plt.ion()

    fig = plt.figure()
    plt.plot(train_loss_avg)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.show()
    
    

If you would rather not train from scratch, you can load a pretrained model instead.

    filename = 'vae_2d.pth'
    # filename = 'vae_10d.pth'
    import urllib.request  # a bare "import urllib" does not reliably expose urllib.request
    if not os.path.isdir('./pretrained'):
        os.makedirs('./pretrained')
    print('downloading ...')
    urllib.request.urlretrieve("http://geometry.cs.ucl.ac.uk/creativeai/pretrained/"+filename, "./pretrained/"+filename)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    vae.load_state_dict(torch.load('./pretrained/'+filename, map_location=device))
    print('done')

    # this is how the VAE parameters can be saved:
    # torch.save(vae.state_dict(), './pretrained/my_vae.pth')
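The commented-out `torch.save` line and the `load_state_dict` call form a round trip, and it can be reassuring to try that round trip on a tiny model before relying on it. This is a minimal sketch. The `nn.Linear` layers are a hypothetical stand-in for the VAE, and the file goes into a temporary directory rather than `./pretrained`.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(2, 3)      # hypothetical stand-in for the trained VAE
restored = nn.Linear(2, 3)   # freshly initialised, different weights

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'toy_model.pth')

    # Save only the parameters, not the whole module object.
    torch.save(model.state_dict(), path)

    # map_location remaps stored tensors to the chosen device on load.
    state = torch.load(path, map_location=torch.device('cpu'))
    restored.load_state_dict(state)

same = all(torch.equal(a, b)
           for a, b in zip(model.parameters(), restored.parameters()))
print(same)  # True
```

Loading a state dict requires the receiving model to have the same architecture, which is why the 2-D checkpoint only fits a model built with latent_dims = 2.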

Evaluate the model on the test set.

复制代码
 # set to evaluation mode

    
 vae.eval()
    
  
    
 test_loss_avg, num_batches = 0, 0
    
 for image_batch, _ in test_dataloader:
    
     
    
     with torch.no_grad():
    
     
    
     image_batch = image_batch.to(device)
    
  
    
     # vae reconstruction
    
     image_batch_recon, latent_mu, latent_logvar = vae(image_batch)
    
  
    
     # reconstruction error
    
     loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)
    
  
    
     test_loss_avg += loss.item()
    
     num_batches += 1
    
     
    
 test_loss_avg /= num_batches
    
 print('average reconstruction error: %f' % (test_loss_avg))
    
    
    
    
    python
    
    
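Calling `vae.eval()` before testing is not just a formality here: `latent_sample` branches on `self.training`, so the model is stochastic during training and deterministic during evaluation. The sketch below isolates that behaviour. `ToySampler` is a hypothetical minimal module that copies only the branching logic, not the real encoder or decoder.

```python
import torch
import torch.nn as nn

class ToySampler(nn.Module):
    """Hypothetical module reproducing only the latent_sample branching."""

    def forward(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            return mu + torch.randn_like(std) * std
        return mu

sampler = ToySampler()
mu, logvar = torch.zeros(3, 2), torch.zeros(3, 2)

# Training mode: every call draws fresh noise.
sampler.train()
z_a, z_b = sampler(mu, logvar), sampler(mu, logvar)
train_differs = not torch.equal(z_a, z_b)

# Evaluation mode: the mean is returned unchanged.
sampler.eval()
eval_is_mu = torch.equal(sampler(mu, logvar), mu)

print(train_differs, eval_is_mu)
```

This is why repeated evaluation passes over the same test images give identical reconstructions, whereas forgetting `eval()` would add sampling noise to the reported test loss.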

Visualize the results:

    import numpy as np
    import matplotlib.pyplot as plt
    plt.ion()

    import torchvision.utils

    vae.eval()

    def to_img(x):
        # clamp pixel values to the displayable range [0, 1]
        x = x.clamp(0, 1)
        return x

    def show_image(img):
        img = to_img(img)
        npimg = img.numpy()
        # make_grid returns C x H x W, imshow expects H x W x C
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

    # This function takes as input the images to reconstruct
    # and the model with which the reconstructions are performed
    def visualise_output(images, model):

        with torch.no_grad():

            images = images.to(device)
            images, _, _ = model(images)
            images = images.cpu()
            images = to_img(images)
            np_imagegrid = torchvision.utils.make_grid(images[1:50], 10, 5).numpy()
            plt.imshow(np.transpose(np_imagegrid, (1, 2, 0)))
            plt.show()

    # iterator.next() is Python 2 only; use the built-in next()
    images, labels = next(iter(test_dataloader))

    # First visualise the original images
    print('Original images')
    show_image(torchvision.utils.make_grid(images[1:50],10,5))
    plt.show()

    # Reconstruct and visualise the images using the vae
    print('VAE reconstruction:')
    visualise_output(images, vae)

Visualize the 2-D latent space:

    # requires a network that was trained with a 2d latent space
    if latent_dims != 2:
        print('Please change the parameters to two latent dimensions.')

    with torch.no_grad():

        # create a sample grid in 2d latent space
        latent_x = np.linspace(-1.5,1.5,20)
        latent_y = np.linspace(-1.5,1.5,20)
        latents = torch.FloatTensor(len(latent_y), len(latent_x), 2)
        for i, lx in enumerate(latent_x):
            for j, ly in enumerate(latent_y):
                latents[j, i, 0] = lx
                latents[j, i, 1] = ly
        latents = latents.view(-1, 2) # flatten grid into a batch

        # reconstruct images from the latent vectors
        latents = latents.to(device)
        image_recon = vae.decoder(latents)

        image_recon = image_recon.cpu()
        fig, ax = plt.subplots(figsize=(10, 10))
        show_image(torchvision.utils.make_grid(image_recon.data[:400],20,5))
        plt.show()
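The nested loops that fill the latent grid can also be written in vectorized form. The sketch below is an alternative under the same grid settings, not the author's code. It builds the grid with `np.meshgrid` and checks that it matches the loop version element by element.

```python
import numpy as np
import torch

latent_x = np.linspace(-1.5, 1.5, 20)
latent_y = np.linspace(-1.5, 1.5, 20)

# Loop version, as in the post: entry [j, i] holds the point (x_i, y_j).
loop_grid = torch.zeros(len(latent_y), len(latent_x), 2)
for i, lx in enumerate(latent_x):
    for j, ly in enumerate(latent_y):
        loop_grid[j, i, 0] = lx
        loop_grid[j, i, 1] = ly

# Vectorized version: with the default 'xy' indexing, xx[j, i] = x_i
# and yy[j, i] = y_j, both of shape (len(latent_y), len(latent_x)).
xx, yy = np.meshgrid(latent_x, latent_y)
vec_grid = torch.from_numpy(np.stack([xx, yy], axis=-1)).float()

# Flatten the grid into a batch of 2-D latent vectors for the decoder.
latents = vec_grid.view(-1, 2)
print(tuple(latents.shape))  # (400, 2)
```

The flattened `latents` tensor can then be moved to the device and passed to `vae.decoder` exactly as in the loop-based version.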
