Self Attention Generative Adversarial Network
本文介绍了基于Self-Attention的GAN模型(SAGAN)及其在图像生成任务中的应用。该模型通过将生成器G和判别器D联合训练,并结合自注意力机制来增强模型的生成能力。具体而言,SAGAN采用编码器-解码器结构,在无监督学习中通过特征提取和分类任务优化判别器D,在有监督学习中通过最小化误差函数优化生成器G。文章详细描述了SAGAN的网络结构、损失函数、优化算法以及训练流程,并通过实验验证了其有效性。该模型展示了自注意力机制在提升生成质量方面的潜力,并为后续研究提供了新的方向。
作者:禅与计算机程序设计艺术
1.简介
微软亚洲研究院于2017年发布了一种基于Self-Attention机制的生成对抗网络(GAN),命名为SAGAN(Self-Attention Generative Adversarial Networks)。该研究通过联合训练判别器、生成器以及自注意机制的优化过程,成功地降低了各组件之间的依赖关系,并显著提升了生成模型的泛化能力。作为补充内容,在本文中我们首先对Self-Attention相关背景进行了深入探讨,并详细阐述了SAGAN的具体架构设计与训练流程;随后通过一系列实验数据验证了该方法的有效性和实际性能。
一、Self-Attention概述
1.1 Attention是什么?
Attention mechanism是一种在序列数据中关注特定元素的技术。它实际上是一种计算权重的方法,在这种机制下使得网络能够根据输入数据的不同部分的重要程度进行相应的调整。换句话说,在这种机制下会赋予输入一定的权重层次结构从而体现出各要素的重要性差异。如图所示在左边展示了一个注意力机制的例子而在右边展示了一个没有使用注意力机制的传统神经网络的情况:
注意力机制本质上是一种特征选择机制,在模型训练过程中能够综合捕捉输入数据的全局信息与局部细节信息,并通过这两者的融合生成更加高效的表征形式。在机器翻译领域中,该方法通常会基于输入文本中的单词或短语来选择最相关的内容进行处理;对于图像识别问题,则依赖于网络对关键区域的识别能力来提取具有代表性的特征;在自然语言处理方面,则利用其特性对文本内容进行高层次解析并重点关注核心词汇和段落结构。
2.2 Attention模块
在深度学习领域中广泛运用着注意力机制,在长期记忆中存储和分析数据的传统上来看待这一过程可以被划分为两个主要阶段:第一部分是编码过程第二部分是解码过程我们将这一整体称为Encoder-Decoder模型其中编码器部分通常包含若干个层级每一层都包含若干个子模块这些子模块包括但不仅限于多头注意力机制和位置编码等核心组件它们共同作用于输入序列信息完成特征提取任务解码器则基于编码器提取的信息进行后续处理最终完成信息传递与结果生成这一流程体现了从信息接收者到信息传递者再到结果生成者的完整转化路径
为了实现自注意力机制, 我们可以采用权重共享的方式构建自注意力机制模块. 该模块的输入由查询Q、键K以及值V构成, 其维度分别为q_dim、k_dim和v_dim. 其中, 查询向量Q用于计算注意力分布, 键向量K用于描述输入序列中的各个元素, 值向量V用于保存各个元素对应的潜在表征信息. 通过Softmax函数可以对注意力分布进行归一化处理, 并且权重矩阵W可以被多个不同的查询向量共享使用以提高模型效率. 该模块的核心输出即为:
Attention(Q, K, V)=softmax(\frac{QK^T}{\sqrt{d_k}})V
d_k为模型的隐藏单元大小。
注意力机制在现代深度学习模型中发挥着关键作用,并已得到了广泛应用,在多个领域内得到了应用与实践支持,包括图像分类、文本生成以及机器翻译等主要涉及的任务。
二、Self-Attention GAN
本文所采用的SAGAN生成式模型由生成器G和判别器D两个主要组件构成。其中,生成器G旨在从潜在空间中提取特征并合成图像;判别器D则负责判断合成图像的真实性与否。具体而言,生成器G主要包含编码子网络和解码子网络两部分:编码子网络接收随机噪声z作为输入信号,并运用一系列自注意机制、位置编码层以及卷积层对其进行特征提取;解码子网络则基于编码后的中间表示运用自注意机制、位置编码层以及反卷积层来重建图像输出序列。
判别器D主要由一个特征提取器和一个分类器组成。特征提取器从生成图片x开始处理信息,在经过多层自注意力机制、位置编码网络以及卷积核网络等处理后得到图片特征表示。随后分类器会对这些特征进行分析并完成判别任务。
SAGAN的优点主要有:
该系统通过其自注意力机制实现对全局与局部模式的有效识别,并显著提升其在复杂信息处理中的性能表现。
2.SAGAN的encoder-decoder结构,能够让生成器G生成更逼真的图片。
SAGAN采用了一种兼具无监督与有监督特点的学习方式,在其初始阶段既包含了无监督元素又融入了有 supervision 的指导机制。具体而言,在无 supervision 学习阶段初期(即未完成知识预设之前),判别器 D 值将趋近于 0.5,并在此过程中促使生成器 G 进一步优化以实现更为合理的数据分布预测。当系统逐步进入有 supervision 学习阶段后(即基于标签信息进行分类的任务完成后),判别器 D 值将趋近于 0.9,在此背景下生成器 G 将能够充分发展成为一种纯粹有效的GAN模型。
在SAGAN模型的训练阶段中,判别器D的主要职责是通过评估生成器G所生成图像与真实图像之间的差异程度来优化生成器G的参数设置.为了提升整体模型性能,在优化过程中,生成器G旨在最小化定义为... 的误差函数.
min_G max_{D} E_{x~p_{data}(x)}[\log D(x)]+E_{z~p_z(z)}[\log (1-D(G(z)))]
其中x\sim p_{data}表示输入的图像;而z\sim p_z则表示输入的随机噪声。判别器D则通过计算并最小化负对数似然损失函数来更新其参数。
三、SAGAN的训练
下面我们详细介绍SAGAN的训练过程。SAGAN的训练分为以下几个步骤:
1.准备数据集
2.定义网络结构
3.定义损失函数
4.配置优化器
5.训练模型
6.评估模型性能
3.1 数据集准备
我们采用了CelebA数据库这一多模态人脸图像数据库进行相关研究工作,在该数据集中包含了具有128×128像素的名人照片样本共计1万零七百七十七张。其中包含了8百张图片样本用于训练生成器G这一模型模块;另有两百张图片样本则被分配用于评估生成器G模型的能力表现
CelebA数据集的准备工作如下:
1.下载数据集
!wget http://mmlab.ie.cuhk.edu.hk/projects/CelebA.zip
!unzip CelebA.zip -d celeba
!ls celeba/Img/img_align_celeba | wc -l
10177
2.划分数据集
import os
from PIL import Image
import numpy as np
DATASET_DIR = 'celeba'
def load_image(filename):
img = Image.open(os.path.join(DATASET_DIR+'/Img', filename))
return img
class CelebADataset():
def __init__(self, dataset_dir=DATASET_DIR, image_size=(128,128), mode='train'):
self.dataset_dir = dataset_dir
self.image_size = image_size
if mode == 'train':
annotations_file = open(os.path.join(dataset_dir,'Anno','list_attr_celeba.txt'), 'r')
images_file = open(os.path.join(dataset_dir,'Img','list_eval_partition.txt'), 'r')
lines = [line.strip().split() for line in annotations_file]
data = {}
count = len([name for name in os.listdir(os.path.join(dataset_dir,'Img','img_align_celeba'))])
for i, line in enumerate(lines):
name, *attrs = line[0], list(map(int, line[1:]))
if not attrs and mode=='train':
continue
elif not attrs and mode=='test':
continue
else:
data[name] = {
'class': attrs[-1],
'image': load_image(f'{name}.jpg').resize(image_size)
}
split_ratio = 0.8
keys = list(data.keys())[:count*split_ratio//1]
train_keys = set(np.random.choice(keys, size=count*split_ratio//1, replace=False).tolist())
test_keys = set(keys) - train_keys
print("Number of training examples:",len(train_keys))
print("Number of testing examples:",len(test_keys))
self.images = []
self.labels = []
for key in sorted(train_keys):
self.images.append(data[key]['image'])
self.labels.append(data[key]['class'] / 2 - 1)
def __getitem__(self, idx):
image = self.images[idx].convert('RGB')
label = self.labels[idx]
return image, label
def __len__(self):
return len(self.images)
3.数据集的可视化
import matplotlib.pyplot as plt
%matplotlib inline
fig = plt.figure(figsize=(16,8))
grid = plt.GridSpec(4,4,wspace=0.0,hspace=0.0)
for i in range(16):
ax = fig.add_subplot(grid[i])
image, label = ds[i]
image = np.array(image)/255
ax.imshow(image)
ax.set_title(label)
ax.axis('off')
plt.show()
其中,/255的目的是将像素值缩放到0-1之间。
3.2 模型结构
SAGAN的网络结构如图所示:
图中所示的Encoder接收随机噪声z作为输入,并经过自注意力机制、位置编码机制以及卷积核机制等多种处理步骤完成特征提取过程以生成最终的隐式表示h。随后经由判别器D执行识别判断操作以获得最终的分类判定结果y其中该判别器D的设计架构基于批次大小为batch、类别数量为M的一组标准训练样本集合。具体而言该鉴别网络体系由两个主要组成部分构成一方面由特征提取子网络负责从输入图像x出发提取关键特征信息另一方面则利用全连接层等结构实现最终的情感或类别判定过程其对应的优化目标函数采用的是二元交叉熵损失函数与sigmoid激活相结合的形式即BCEWithLogitsLoss的形式
在编码过程中(原文:生成器G的结构相较于encoder略有不同),与编码器相比(原文: encoder), 生成器G架构(原文: generator G)有所差异(原文:略有不同)。在编码阶段(原文: G首先通过多个自注意力模块), 生成器G接收一个由随机噪声z构成(原文: G的输入是由随机噪声z组成的batch x L的噪声输入)的一个batch x L维度的数据块作为输入(新增)。随后,在解码阶段(原文: G将h送入一个解码器decoder,并输出生成的图片x), 为了优化性能(原文: 在训练过程中), 生成器G会根据判别器D反馈的信息——即生成图像与真实图像之间的差异——来更新自身的参数(新增)。在这种训练机制下(原文: 通过比较判别...能力), 该模型能够逐步提升其对复杂模式的学习能力并实现高质量图像合成目标(新增)。该模型采用BCEWithLogitsLoss作为损失函数(原文: 生成器...损失函数是BCEWithLogitsLoss)。
SAGAN通过Gradient Clipping技术抑制梯度爆炸。其中最大的裁剪阈值设定为5.0。
3.3 Loss Function
它是由sigmoid激活函数与交叉熵损失函数结合而成的损失函数。它不仅能够处理回归问题,并且其输出结果为单一数值。
3.4 Optimizer
Adam optimizer是广泛使用的优化算法,在深度学习领域占据重要地位。该方法通过结合了两个核心操作实现逐步提升的效果——通过动态调整学习速率来提高训练效率与准确性。具体而言,在模型性能不断改善的过程中动态调节学习速率能够帮助算法以更快捷的方式实现更好的训练效果。在优化器设置方面:动量项通常建议设置为0.9(或类似值),这种参数选择有助于加速收敛并提升模型稳定性;同时可以考虑引入指数衰减的学习率策略(如gamma=0.99),这种指数衰减机制不仅能够防止过拟合还能进一步提升模型泛化能力
optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2), weight_decay=weight_decay)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2), weight_decay=weight_decay)
3.5 Training Loop
SAGAN的训练循环如下:
loss_g_total = 0
loss_d_total = 0
for epoch in range(num_epochs):
running_loss_g = 0
running_loss_d = 0
for batch_id, (imgs, labels) in enumerate(dataloader):
imgs = imgs.to(device)
noise = torch.randn((batch_size, latent_dim)).to(device)
real_labels = torch.ones((batch_size, ), dtype=torch.float32).unsqueeze(-1).to(device)
fake_labels = torch.zeros((batch_size, ), dtype=torch.float32).unsqueeze(-1).to(device)
'''
Update discriminator network parameters using backpropagation
'''
optimizer_d.zero_grad()
outputs = discriminator(imgs.reshape((-1,) + input_shape))
err_real = criterion(outputs, real_labels)
err_real.backward()
z = torch.randn((batch_size, latent_dim)).to(device)
gen_imgs = generator(z)
outputs = discriminator(gen_imgs.detach().reshape((-1,) + input_shape))
err_fake = criterion(outputs, fake_labels)
err_fake.backward()
gradient_penalty = calc_gradient_penalty(discriminator, imgs.reshape((-1,) + input_shape),
gen_imgs.detach().reshape((-1,) + input_shape))
gradient_penalty.backward()
optimizer_d.step()
'''
Update generator network parameters using backpropagation
'''
optimizer_g.zero_grad()
z = torch.randn((batch_size, latent_dim)).to(device)
gen_imgs = generator(z)
outputs = discriminator(gen_imgs.reshape((-1,) + input_shape))
err_g = criterion(outputs, real_labels)
err_g.backward()
optimizer_g.step()
running_loss_g += err_g.item()
running_loss_d += err_fake.item()+err_real.item()
'''
Print statistics
'''
avg_loss_g = running_loss_g/(batch_id+1)
avg_loss_d = running_loss_d/(batch_id+1)
scheduler_d.step()
scheduler_g.step()
if epoch%1==0:
print(f"Epoch:{epoch}/{num_epochs}, Generator Loss: {avg_loss_g:.4f}, Discriminator Loss: {avg_loss_d:.4f}")
在每一轮训练结束时(或每一轮训练结束后),系统将输出统计数据,并详细记录生成器与判别器的损失值变化情况。
四、Experiment Results
最后,我们通过几个实验验证SAGAN的有效性和效果。
4.1 Inception Score
该指标通常被用作评估生成图像质量的标准,在实际应用中具有重要的参考价值。
当该指标值越高时,则表示生成图像的质量越佳。
与传统的GAN模型相比,在计算复杂度方面有所优化后得到的SAGAN模型同样能够通过该指标来进行评估。
采用inception score作为评价标准时,默认假设生成的所有图片均为无标签图像。其计算方法即为:
\mu_{\theta}(x) \approx \mathbb{E}_{x\sim p_{ \theta}} [\log D(x)] + \log 1+\log n
where it represents \theta, where it represents
在本研究中,在以下公式推导中使用了以下变量:θ代表Inception v3网络的参数;其中D表示一张输入图片对应某个类别的预测概率值;n代表生成的数据样本总数;而log则表示自然对数函数。
4.2 FID score
Frechet Inception Distance(FID score),也被简称为"鞍点距离"。该指标也可被视为另一个用于评估生成图像质量的方法。当计算得到的FID分数值最小时,通常意味着生成图像的质量较高。通过计算其对应的FID分数值大小即可评估SAGAN模型所生成图像的质量水平。
采用FID score作为评估图像生成质量的标准时,在生成的所有图像均为无标签的前提下,则其计算公式如下:
该模型通过 Frechet Inception Distance (FID) 计算生成图像与真实图像之间的差异性指标。
具体而言,
FI D(x)=?
其中,
第一项
||μθ在点X处的差向量||₂²
表示两个均值向量 μ_θ 和 μ_θ’ 在 X 处取值之差的 L2 范数平方,
反映了两种分布之间的距离度量。
第二项
$Tr(Σ_{
}
)
则衡量了协方差矩阵 Σ_{
} 在 X 处的变化情况。
Among other things, θ represents the parameter of the Inception v3 network. Furthermore, θ' denotes the second distribution. Additionally, Tr stands for the trace operator.
在当前情境中,在Inception v3网络中,变量θ和θ’分别表示为两个分布的参数;其中符号Tr被定义为trace运算符
4.3 Visual Comparison
通过对比真实图片和生成图片,来直观地评价生成图片质量。
5. Future Work
有很多地方需要改进,比如:
更多的实验数据集,比如ImageNet、Places、LSUN等。
更高级的网络结构,比如ResNet-based、MobileNet-based等。
跨域迁移训练,即将生成器和判别器从源域迁移到目标域。
除了更为复杂的自注意力机制模块之外还有可变形状卷积和基于导数的注意力机制模块
另外,在探索生成模型时,还可以考虑采用诸如DCGAN或WGAN-GP等不同的架构设计;通过对比分析这些模型的特点与优势劣势关系,在实际应用中能够帮助我们找到最适合该场景的最佳解决方案。
