Advertisement

Generating Faces using Generative Adversarial Networks

阅读量:

作者:禅与计算机程序设计艺术

1.简介

在这个项目中,我们将利用基于深度学习的生成对抗网络(GAN)来进行图像生成.GAN是一种无监督式的机器学习模型,它由两个主要组件构成:一个是生成器(Generator),另一个是判别器(Discriminator).其中,生成器是一个能够产出逼真的面部图像模型,而判别器则是一个能够辨识输入图片是否合法的模型.这两者通过相互竞争的方式不断优化以生产更加逼真且细节丰富的图像.

本项目的源码主要依赖于PyTorch框架,并附带了经过训练的模型文件。运行预训练的模型文件即可生成图像。不过如果希望自行训练模型,则可参考本指南中的相关说明。

2.相关概念

2.1 生成对抗网络

生成对抗网络(Generative Adversarial Network, GAN),是 Ian Goodfellow 和 Ilya Sutskever 于2014年创建的一种深度学习模型。该方法属于无监督机器学习范畴,并由两个相互竞争以捕获样本空间中的关键特征的网络组成。这两个网络分别扮演着不同的角色——一个是生成器(G),另一个是鉴别器(D)。它们的目标是通过游戏机制共同优化,在此过程中实现对高质量数据的有效合成。

2.1.1 概念

GAN最初于2014年由Ian Goodfellow和Ilya Sutskever建立。它是一种生成模型,并主要包含两个神经网络模块。

  • 生成器(G):其目标是从潜在空间中生成与输入数据分布一致的样本。通过作用于G输出的误差函数使其难以分辨真实样例与生成样例。
  • 判别器(D):它被设计为一种分类模型,在接收任意输入时能够识别该输入是否为真实数据或被生成的数据。

GAN模型的特点包括:

  • 无监督学习: GAN无需任何先验知识或标签信息,在生成器与判别器之间的持续对抗中产出一系列样本。
  • 生成模型: 其生成的数据通常呈现出高度多样的特性和复杂性。
  • 对抗训练: 基于对抗训练机制的模型中, 两个网络不断竞争协作以提升其性能.

2.1.2 两种网络结构

2.1.2.1 DCGAN

DCGAN简称为Deep Convolutional GAN(DCGAN),是一种基于卷积神经网络构建而成的生成对抗网络(GAN)模型。该系统由编码器和解码器两个关键组件构成,并通过共用参数矩阵实现信息传递。其中,在线性代数运算框架下,编码器通过卷积操作提取高阶抽象特征向量作为表征;而解码器则通过反向传播过程重构输入数据的空间结构。尽管该方法在生成能力方面展现出显著优势,在实际应用中仍面临两个主要局限性:一是系统规模较大导致计算效率较低;二是其输出图像质量普遍偏低。

2.1.2.2 CycleGAN

该生成对抗网络模型属于无监督学习范畴,在2017年首次提出。该模型的核心优势在于能够同步训练两个不同领域间的互逆映射关系,并通过循环反馈机制实现高质量的数据转换效果。其基本架构如图所示:

CycleGAN不仅能够实现不同领域间的转换,还通过持续优化来增强生成质量,并提高鉴别器的鉴别水平。

2.1.3 模型流程

生成对抗网络的训练过程分为两个阶段,即训练生成器和训练判别器。

2.1.3.1 训练生成器

生成器的训练通过两种途径完成, 包括预测误差与训练误差. 预测误差用于评估生成器的表现, 而训练误差则被用来优化生成器性能.

2.1.3.1.1 预测损失

为了最大限度地欺骗判别器,在训练过程中希望其将所有假图片误分为真实类别。与此同时,在评估阶段使用预测损失作为衡量标准来判断生成图片是否为真实类别。

\mathbb{E}_{x\sim p_{data}(x)}\left[\log D(x)\right] + \mathbb{E}_{z\sim p_z(z)}\left[\log (1 - D(G(z))\right]

其中 p_{data}(x) 表示真实图片的分布;判别器网络 D 和生成器网络 G 被用来训练生成对抗网络;噪声 z 被设计为服从标准正态分布。

2.1.3.1.2 训练损失

训练损失的目标是帮助生成器产出更高质量的图像,并非旨在使判别器误判。为了实现这一目标,生成器必须通过调整自身参数来增强对判别器输出的反驳能力。通过计算训练损失, 我们可以评估生成图像是否被判别器正确识别为真实图像

\min _{\theta_{\text {generator}}}\max _{\theta_{\text {discriminator}}}\mathbb{E}_{x\sim p_{data}(x)}[\log D(x) + \log (1-\tilde{D}(G(z))) + \frac{1}{2}\|x-G(z)\|_2^2] \quad \text{s.t.} \quad \|\nabla_x D(x)\|_2 \leqslant 1, \quad \|\nabla_z G(z)\|_2 \leqslant 1

其中\theta_{\text {gen }}\theta_{\text {dis }}分别代表生成器和判别器的参数;\tilde{D}则用于表示一副假图片;而绝对值运算符则用于计算数值。

2.1.3.2 训练判别器

判别器的训练分为两种途径进行:一是采用真实损失作为优化目标;二是采用假想损失来降低模型对虚假数据的学习能力。

2.1.3.2.1 真实损失

真实损失的目的是为了让判别器更准确地识别出所有真实图片。

\min _{\theta_{\text {dis }}}\mathbb{E}_{x\sim p_{data}(x)}[(\log D(x)+\log (1-\tilde{D}(x)))]

2.1.3.2.2 假设损失

假设损失的目的是为了使判别器更难以正确地识别假图片。

为了最大化判别器参数\theta_{discrete}, 在数据分布p_g(x)下计算生成器能够实现的最大化期望值。具体来说,在真实样本y_{fake}与判别器估计概率\tilde{D}(x)之间取差值的基础上,并加入一个基于梯度范数的正则项以防止过拟合。

其中 p_g(x) 代表生成模型所推断的数据分布,在GAN框架中;当 y_{fake}=0 时,则表明判别器将真实图像归类为虚假样本;而当 y_{fake}=1 时,则表明判别器将虚假图像归类为真实样本;其中 \beta 被视为调节参数以平衡模型性能与稳定性之间的关系

以上为生成对抗网络(GAN)的相关概念。

3.具体方案与实施

3.1 数据集准备

3.1.1 CelebA

Celeba 是一个专门收集面部属性数据的数据库,在其庞大的存储量中包含了超过二十万张不同名人面部照片,并按类别进行了细致划分。每个类别都对应着独特的属性特征(如外貌特征、眼距大小以及笑容的程度等)。值得注意的是,在这些分类中有一些共同点也存在差异性特点(例如相似的笑容深度可能在不同类别中体现得并不完全相同)。此外,在这一丰富的数据资源下研究人员可从中提取出大量具有参考价值的样本数据从而实现多维度的数据分析工作(包括但不限于图像识别动作识别目标检测图像生成风格迁移等方面的具体应用)。

CelebA 包含以下几种类别:

以下是基于给定规则对原文的改写内容

3.1.2 数据集准备

由于CelebA数据集过大,所以我们只选择其中一部分作为训练集。

复制代码
    import os
    import random
    from shutil import copyfile
    
    def split_dataset(src_dir='./', dest_dir='./'):
    train_path = os.path.join(dest_dir, 'train')
    valid_path = os.path.join(dest_dir, 'valid')
    test_path = os.path.join(dest_dir, 'test')
    
    # Create folders if not exist
    os.makedirs(train_path, exist_ok=True)
    os.makedirs(valid_path, exist_ok=True)
    os.makedirs(test_path, exist_ok=True)
    
    # Split data into training and testing sets randomly with ratio of 0.8:0.1:0.1
    img_files = [os.path.join(src_dir, f) for f in sorted(os.listdir(src_dir))]
    num_imgs = len(img_files)
    idx_list = list(range(num_imgs))
    random.shuffle(idx_list)
    start_idx = int((num_imgs*0.8)//1)*1
    end_idx = int((num_imgs*(0.8+0.1))//1)*1
    train_idx = idx_list[:start_idx]
    val_idx = idx_list[start_idx:end_idx]
    test_idx = idx_list[end_idx:]
    print('Number of images:', num_imgs)
    print('Training set size:', len(train_idx), '\nValidation set size:', len(val_idx), '\nTesting set size:', len(test_idx))
    
    # Move files to corresponding folder
    for i in range(len(img_files)):
        if i in train_idx:
        elif i in val_idx:
        else:
        copyfile(img_files[i], dst_path)
    
    split_dataset()
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

3.2 训练模型

3.2.1 配置环境与导入包

首先正确设置相应的环境参数,例如安装Anaconda程序并创建虚拟环境后进行激活。随后导入所需的软件包。

首先正确设置相应的环境参数,例如安装Anaconda程序并创建虚拟环境后进行激活.随后导入所需的软件包.

复制代码
    !pip install torch torchvision matplotlib numpy tqdm
    %matplotlib inline
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision.transforms as transforms
    import torchvision.utils as vutils
    from torch.autograd import Variable
    from PIL import Image
    import matplotlib.pyplot as plt
    import numpy as np
    from tqdm import trange, tqdm
    
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

3.2.2 加载数据

加载已划分好的训练集、验证集、测试集。定义数据预处理函数。

复制代码
    transform = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    batch_size = 32
    
    trainset = torchvision.datasets.ImageFolder('./train/', transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    
    validset = torchvision.datasets.ImageFolder('./valid/', transform)
    validloader = DataLoader(validset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    testset = torchvision.datasets.ImageFolder('./test/', transform)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

3.2.3 创建网络结构

定义生成器(Generator)网络和判别器(Discriminator)网络。

复制代码
    class Generator(nn.Module):
    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.ConvTranspose2d(in_channels=nz, out_channels=nc, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(nc),
            nn.ReLU(inplace=True),
    
            nn.ConvTranspose2d(in_channels=nc, out_channels=int(nc/2), kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(int(nc/2)),
            nn.ReLU(inplace=True),
    
            nn.ConvTranspose2d(in_channels=int(nc/2), out_channels=int(nc/4), kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(int(nc/4)),
            nn.ReLU(inplace=True),
    
            nn.ConvTranspose2d(in_channels=int(nc/4), out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
    
    def forward(self, input):
        return self.main(input)
    
    class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=int(ndf/4), kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
    
            nn.Conv2d(in_channels=int(ndf/4), out_channels=int(ndf/2), kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(int(ndf/2)),
            nn.LeakyReLU(0.2, inplace=True),
    
            nn.Conv2d(in_channels=int(ndf/2), out_channels=int(ndf), kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(int(ndf)),
            nn.LeakyReLU(0.2, inplace=True),
    
            nn.Conv2d(in_channels=int(ndf), out_channels=1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )
    
    def forward(self, input):
        return self.main(input)
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

3.2.4 定义优化器和损失函数

定义优化器和损失函数。

复制代码
    lr = 0.0002
    betas = (0.5, 0.999)
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=betas)
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=betas)
    criterion = nn.BCELoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    
      
      
      
      
      
      
    
    代码解读

3.2.5 测试模型

随机测试一下模型的运行情况。

复制代码
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)
    real_batch = next(iter(trainloader))
    real_images = real_batch[0].to(device)
    plt.figure(figsize=(15,15))
    plt.subplot(1,2,1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(np.transpose(vutils.make_grid(real_images[0:64], normalize=True).cpu(),(1,2,0)))
    
    fake = netG(fixed_noise).detach().cpu()
    plt.subplot(1,2,2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(vutils.make_grid(fake, normalize=True),(1,2,0)))
    plt.show()
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

3.2.6 训练模型

训练模型,保存模型参数。

复制代码
    for epoch in trange(num_epochs):
    for i, data in enumerate(trainloader, 0):
    
        # Configure input
        real_images = data[0].to(device)
    
        # ---------------------
        #  Train Discriminator
        # ---------------------
    
        optimizerD.zero_grad()
    
        # Sample noise as generator input
        z = torch.randn(batch_size, nz, 1, 1, device=device)
    
        # Generate a batch of images
        fake_images = netG(z).to(device)
    
        # Real images
        label = torch.full((batch_size,), real_label, device=device)
        output = netD(real_images).view(-1)
        errD_real = criterion(output, label)
        d_x = output.mean().item()
    
        # Fake images
        label.fill_(fake_label)
        output = netD(fake_images.detach()).view(-1)
        errD_fake = criterion(output, label)
        d_G_z1 = output.mean().item()
    
        # Combined loss and calculate gradients
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()
    
    
        # -----------------
        #  Train Generator
        # -----------------
    
        optimizerG.zero_grad()
    
        # Generate a batch of images
        fake_images = netG(z).to(device)
        label.fill_(real_label)
    
        # Loss measures generator's ability to fool the discriminator
        output = netD(fake_images).view(-1)
        errG = criterion(output, label)
    
        # Calculate gradients for G
        errG.backward()
        optimizerG.step()
    
        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())
    
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i==len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
    
        iters += 1
    
    # Save the trained model parameters
    torch.save({
    'netG': netG.state_dict(),
    'netD': netD.state_dict(),
    }, './checkpoint/faces_checkpoint.pth')
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

3.3 测试模型

测试模型,并展示生成的图像。

复制代码
    with torch.no_grad():
    fake = netG(fixed_noise).detach().cpu()
    plt.figure(figsize=(15,15))
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(vutils.make_grid(fake, padding=2, normalize=True),(1,2,0)))
    plt.show()
    
      
      
      
      
      
      
    
    代码解读

4.总结与思考

4.1 总结

本文旨在介绍生成对抗网络(GAN)的基本概念及其实现机制,并以‘生成人脸图像’为例详细阐述了GAN的工作原理。文章从模型架构、训练方法以及数据处理这几个方面深入探讨了GAN的工作机制。

本文在应用领域上实现了突破,在图像处理方面引入了Generative Adversarial Networks(GAN)。该研究同时探索了适用于特定场景的数据集构建方法,并采用具有独特特性的面部图像数据集。相较于现有的生成图像技术而言,在提升生成效果的同时也为生成效果的提升提供了新的思路。研究者通过协调训练自动生成器与鉴别器模型,并模仿真实数据分布特性构建了独特的数据样本库。

作者系统地研究了GAN的基本概念、构成模块以及优化方法,并全面探讨了其在数据预处理和输出图像质量上的应用。

作者提供的代码实现、可读性强、注释详细、易于理解,对初学者非常友好。

4.2 心得与感悟

本文围绕基础概念、模型结构以及训练策略展开全面阐述GAN的工作原理,并将其成功应用于人脸图像生成领域。其语言生动且流畅,在内容上条理清晰,并通过图文并举的方式使读者易于理解;整篇文章层次分明且内容详实,特别适合非专业人士阅读。

作者通过简明扼要且条理清晰的文字叙述,精确阐述了GAN的基本原理、结构以及训练策略,并对其在图像领域的发展前景进行了深入探讨。

通过阅读这篇文章,我深切体会到了阐述抽象理论与复杂体系的重要性——它不仅仅局限于堆积成山的理论知识,而关键在于掌握这些理论背后所蕴含的科学技术原理.这不仅有助于深化我们对自然规律的认识,也有助于提升我们对实际世界的认知能力.

全部评论 (0)

还没有任何评论哟~