Advertisement

SAGAN 代码阅读笔记

阅读量:

论文《自注意力生成对抗网络》
链接:https://arxiv.org/abs/1805.08318
代码链接:https://github.com/heykeetae/Self-Attention-GAN

按照代码流程进行记录

默认参数设置

复制代码
    adv_loss         = 'hinge'
    attn_path        = './attn'
    batch_size       = 64
    beta1            = 0.0
    beta2            = 0.9
    d_conv_dim       = 64
    d_iters          = 5
    d_lr             = 0.0004
    dataset          = 'celeb'
    g_conv_dim       = 64
    g_lr             = 0.0001
    g_num            = 5
    image_path       = './data'
    imsize           = 64
    lambda_gp        = 10
    log_path         = './logs'
    log_step         = 10
    lr_decay         = 0.95
    model            = 'sagan'
    model_save_path  = './models'
    model_save_step  = 1.0
    num_workers      = 2
    parallel         = False
    pretrained_model = None
    sample_path      = './samples'
    sample_step      = 100
    total_step       = 1000000
    train            = True
    use_tensorboard  = False
    version          = 'sagan_celeb'
    z_dim            = 128

Discriminator网络结构

判别器网络设定参数为batch size=64, image_size=64, conv_dim=64

假定输入数据为 torch.Size([64, 3, 64, 64])

复制代码
    # layer1
    Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    SpectralNorm()
    LeakyReLU(negative_slope=0.1)

此时变为 torch.Size([64, 64, 32, 32])

复制代码
    # layer2
    Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    SpectralNorm()
    LeakyReLU(negative_slope=0.1)

此时变为 torch.Size([64, 128, 16, 16])

复制代码
    # layer3
    Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    SpectralNorm()
    LeakyReLU(negative_slope=0.1)

此时变为 torch.Size([64, 256, 8, 8])

可以观察到前三层的网络结构显示基本一致,并且随着深度增加(channel数量)持续增加(同时),dimensión逐步减少

前三层完成后,经过一次自注意力层处理后(此处指卷积神经网络中的自注意力机制),输出特征图的空间维度保持不变(仍为 torch.Size([64, 256, 8, 8])),而生成的注意力权重矩阵(attention map)的规模则缩小为 torch.Size([64, 64, 64])

如果输入图像数据的尺寸为64时,还有一个layer4,与前三层结构一致

复制代码
    # layer4
    Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    SpectralNorm()
    LeakyReLU(negative_slope=0.1)

此时变为 torch.Size([64, 512, 4, 4])

第四层处理完毕后,请随后再次进行自注意力计算,并输出第二个注意力图谱torch.Size([64, 16, 16])

复制代码
    # last
    Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))

此时变为 torch.Size([64, 1, 1, 1])

最终,在代码实现中会调用squeeze()函数以去除多余的单维元素,并将模型输出的特征图尺寸限制在torch.Size([64])

Generator网络结构

生成器网络参数设置为 batch_size=64, image_size=64, z_dim=128, conv_dim=64

随后生成一个随机采样值,并对每个图像都具有具有z_dim维空间的噪音样本进行合成操作;同时假设输入数据形式为 torch.Size([64, 128])

先将输入数据变为 torch.Size([64, 128, 1, 1])

复制代码
    repeat_num = int(np.log2(self.imsize)) - 3
    mult = 2 ** repeat_num # 8

计算mult=8

复制代码
    # layer1
    ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1))
    SpectralNorm()
    BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ReLU()

此时变为 torch.Size([64, 512, 4, 4])

复制代码
    # layer2
    ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    SpectralNorm()
    BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ReLU()

此时变为 torch.Size([64, 256, 8, 8])

复制代码
    # layer3
    ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    SpectralNorm()
    BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ReLU()

此时变为 torch.Size([64, 128, 16, 16])

第3层之后,会计算 self-attention,其中map1torch.Size([64, 256, 256])

复制代码
    # layer4
    ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    SpectralNorm()
    BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ReLU()

此时变为 torch.Size([64, 64, 32, 32])

第4层之后,也会有attention层,map2torch.Size([64, 1024, 1024])

复制代码
    # last
    ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    Tanh()

此时变为 torch.Size([64, 3, 64, 64])

损失函数计算

Discriminator

鉴别器的完整损失函数表达式为
L_D = -E_{(x,y)\sim p_{data}}[min(0, -1 + D(x,y))] - E_{z \sim p_z, y \sim p_{data}}[min(0, -1 - D(G(z), y))]

输入真实图像

复制代码
    d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

输入生成图像

复制代码
    d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

Generator

生成器整体的损失函数是
L_G=-E_{z \sim p_{z},y \sim p_{data}}D(G(z),y)

复制代码
    fake_images,_,_ = self.G(z)
    g_out_fake,_,_ = self.D(fake_images)  # batch x n
    g_loss_fake = - g_out_fake.mean()

也就是说,生成器的损失是判别器对生成图像判别的平均值

总结

生成器和判别器中使用了两层self-attention

在生成器中加入光谱归一化后紧接着又添加了一层BatchNorm2d的地方让我感到困惑

学习速率不同,但是学习迭代比例是1:1的

全部评论 (0)

还没有任何评论哟~