
GAN (Generative Adversarial Network)

  1. For newcomers working through how a GAN is trained, an important first step is to pick a simple architecture for both the generator and the discriminator (for example, plain fully-connected layers). To keep training stable and effective, it also helps to add batch normalization to the networks. The generator's last layer is usually activated with tanh (when the images are normalized to [-1, 1]); a sigmoid can be used instead when the images stay in [0, 1], which is what the code below does. A minimal fully-connected sketch illustrating these points follows this list.
  2. The code example below implements such a GAN (with convolutional networks) using the PyTorch framework.
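
As a reference for the first note above, here is a minimal fully-connected generator/discriminator sketch with batch normalization and a tanh output. The layer widths and the flattened 28 x 28 image size are illustrative assumptions only; they are not part of the convolutional model used in the rest of this post.

    import torch
    import torch.nn as nn

    latent_dim = 100  # illustrative latent size

    # Generator: noise vector -> flattened 28x28 image in [-1, 1] (tanh output,
    # which assumes the training images are normalized to [-1, 1]).
    fc_generator = nn.Sequential(
        nn.Linear(latent_dim, 256),
        nn.BatchNorm1d(256),             # batch normalization for training stability
        nn.ReLU(True),
        nn.Linear(256, 28 * 28),
        nn.Tanh(),
    )

    # Discriminator: flattened image -> probability of being real.
    fc_discriminator = nn.Sequential(
        nn.Linear(28 * 28, 256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    )

    z = torch.randn(16, latent_dim)
    fake_flat = fc_generator(z)           # shape: (16, 784)
    p_real = fc_discriminator(fake_flat)  # shape: (16, 1)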

Import the required libraries

    import torch
    import torchvision
    import torch.nn as nn
    import torch.nn.functional as F

    import numpy as np
    import matplotlib.pylab as plt
    from matplotlib import animation
    from IPython.display import HTML

Define some constants

    BATCH_SIZE = 100
    IMG_CHANNELS = 1   # MNIST images are single-channel
    NUM_Z = 100        # dimensionality of the generator's noise input
    NUM_GENERATOR_FEATURES = 64
    NUM_DISCRIMINATOR_FEATURES = 64

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # a fixed batch of noise, reused after every epoch to visualize training progress
    INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, device=DEVICE)
    # INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, device=DEVICE)

Load the dataset (MNIST)

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    # ds = torchvision.datasets.cifar.CIFAR10(root="data", train=True, transform=transform,  download=True)
    ds = torchvision.datasets.mnist.MNIST(root="data", train=True, transform=transform, download=True)
    ds_loader = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
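
If you prefer the tanh output mentioned in the notes at the top, the images should be normalized to [-1, 1] instead of the [0, 1] range produced by ToTensor() alone. A possible transform for that variant (not used in this post) would be:

    transform_tanh = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # MNIST is single-channel, so one mean/std value each; maps [0, 1] -> [-1, 1]
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])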

Inspect a batch of data

    img_batch, lab_batch = next(iter(ds_loader))
    img_batch.shape, lab_batch.shape  # (torch.Size([100, 1, 28, 28]), torch.Size([100]))

Plot images from the dataset

    plt.figure(figsize=(8, 8), dpi=80)

    plt.imshow(torchvision.utils.make_grid(img_batch, nrow=10, padding=2, pad_value=1, normalize=True).permute(1, 2, 0))
    plt.tight_layout()
    plt.axis("off")

Define the generator and discriminator

    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            # ConvTranspose2d output size: o_out = (o_in - 1) * s - 2 * p + k
            self.main = nn.Sequential(
                # 100 x 1 x 1 --> 512 x 4 x 4
                nn.ConvTranspose2d(NUM_Z, NUM_GENERATOR_FEATURES * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 8),
                nn.ReLU(True),
                # 512 x 4 x 4 --> 256 x 8 x 8
                nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 8, NUM_GENERATOR_FEATURES * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 4),
                nn.ReLU(True),
                # 256 x 8 x 8 --> 128 x 16 x 16
                nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 4, NUM_GENERATOR_FEATURES * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 2),
                nn.ReLU(True),
                # 128 x 16 x 16 --> 64 x 14 x 14
                nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 2, NUM_GENERATOR_FEATURES * 1, 1, 1, 1, bias=False),
                nn.BatchNorm2d(NUM_GENERATOR_FEATURES * 1),
                nn.ReLU(True),
                # 64 x 14 x 14 --> 1 x 28 x 28
                nn.ConvTranspose2d(NUM_GENERATOR_FEATURES * 1, IMG_CHANNELS, 2, 2, 0, bias=False),
                # Sigmoid keeps the output in [0, 1], matching the ToTensor() image range
                nn.Sigmoid(),
            )

        def forward(self, x):
            return self.main(x)


    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()

            self.main = nn.Sequential(
                # 1 x 28 x 28 --> 256 x 14 x 14
                nn.Conv2d(IMG_CHANNELS, NUM_DISCRIMINATOR_FEATURES * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # 256 x 14 x 14 --> 128 x 7 x 7
                nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 4, NUM_DISCRIMINATOR_FEATURES * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 2),
                nn.LeakyReLU(0.2, inplace=True),
                # 128 x 7 x 7 --> 64 x 3 x 3
                nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 2, NUM_DISCRIMINATOR_FEATURES * 1, 4, 2, 1, bias=False),
                nn.BatchNorm2d(NUM_DISCRIMINATOR_FEATURES * 1),
                nn.LeakyReLU(0.2, inplace=True),
                # 64 x 3 x 3 --> 1 x 1 x 1
                nn.Conv2d(NUM_DISCRIMINATOR_FEATURES * 1, 1, 3, 1, 0, bias=False),
                # probability that the input image is real
                nn.Sigmoid()
            )

        def forward(self, x):
            return self.main(x).view(-1)
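
The comment at the top of Generator.__init__ is the transposed-convolution output-size formula: for input size o_in, stride s, padding p, and kernel size k, the output size is o_out = (o_in - 1) * s - 2 * p + k. For example, the first layer maps 1 x 1 to (1 - 1) * 1 - 2 * 0 + 4 = 4, and the last layer maps 14 x 14 to (14 - 1) * 2 - 2 * 0 + 2 = 28. A quick sanity check in code:

    def conv_transpose_out(o_in, k, s, p):
        # output size of nn.ConvTranspose2d with default dilation and output_padding
        return (o_in - 1) * s - 2 * p + k

    sizes = [1]
    for k, s, p in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (1, 1, 1), (2, 2, 0)]:
        sizes.append(conv_transpose_out(sizes[-1], k, s, p))
    print(sizes)  # [1, 4, 8, 16, 14, 28]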

Test the defined models

    noise = torch.randn(BATCH_SIZE, NUM_Z, 1, 1)
    generator = Generator()
    fake_img = generator(noise)     # shape: (100, 1, 28, 28)
    discriminator = Discriminator()
    discriminator(fake_img)         # shape: (100,), values in (0, 1)

Network weight initialization function

    # custom weights initialization called on netG and netD
    # (DCGAN-style: conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02), biases 0)
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

Create the generator and discriminator instances, optimizers, loss function, and evaluation metric

    generator = Generator().to(DEVICE).apply(weights_init)
    discriminator = Discriminator().to(DEVICE).apply(weights_init)
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    loss_fn = nn.BCELoss()
    # accuracy of the thresholded (>= 0.5) discriminator outputs against the 0/1 labels
    metrics_fn = lambda y_true, y_pred: torch.mean((y_true == torch.where(y_pred >= 0.5, torch.tensor(1., device=DEVICE), torch.tensor(0., device=DEVICE))).to(torch.float32))
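
The one-line metrics_fn above thresholds the discriminator's outputs at 0.5 and computes the accuracy against the 0/1 labels. An equivalent, more readable sketch of the same metric:

    def accuracy(y_true, y_pred, threshold=0.5):
        # fraction of predictions that land on the correct side of the threshold
        y_hat = (y_pred >= threshold).to(torch.float32)
        return (y_true == y_hat).to(torch.float32).mean()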

Define the training step (the key part)

    def train_step(inputs, labels):
        # labels are only used for their shape; they are overwritten with 1s/0s below
        labels = labels.to(torch.float32)
        inputs_g = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, device=DEVICE)
        # inputs_g = torch.randn(BATCH_SIZE, NUM_Z, device=DEVICE)
        outputs_g = generator(inputs_g)

        # fix generator, unfix discriminator
        for parameter in generator.parameters():
            parameter.requires_grad = False
        for parameter in discriminator.parameters():
            parameter.requires_grad = True
        optimizer_d.zero_grad()

        # real images: the discriminator should output 1
        labels = torch.ones_like(labels)
        outputs = discriminator(inputs)
        loss_real = loss_fn(outputs, labels)
        metrics_real = metrics_fn(labels, outputs)
        loss_real.backward()

        # fake images: the discriminator should output 0
        labels = torch.zeros_like(labels)
        outputs = discriminator(outputs_g.detach())  # note the detach(): no gradients flow back into the generator here
        loss_fake = loss_fn(outputs, labels)
        metrics_fake = metrics_fn(labels, outputs)
        loss_fake.backward()

        loss_d = (loss_real + loss_fake) / 2
        metrics_d = (metrics_real + metrics_fake) / 2
        # loss_d.backward()
        optimizer_d.step()


        # unfix generator, fix discriminator
        for parameter in generator.parameters():
            parameter.requires_grad = True
        for parameter in discriminator.parameters():
            parameter.requires_grad = False
        optimizer_g.zero_grad()

        # the generator wants the discriminator to output 1 on its fakes
        labels = torch.ones_like(labels)
        outputs = discriminator(outputs_g)
        loss_g = loss_fn(outputs, labels)
        metrics_g = metrics_fn(labels, outputs)
        loss_g.backward()
        optimizer_g.step()

        return loss_d.item(), metrics_d.item(), loss_g.item(), metrics_g.item()
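
The detach() flagged in the comment is what keeps the discriminator update from also pushing gradients into the generator. A minimal, standalone illustration of what detach() does:

    x = torch.randn(3, requires_grad=True)
    y = (x * 2).detach()     # y is cut off from the graph that produced it
    print(y.requires_grad)   # False: backward through y would never reach x

Here the detach() also matters for a more mechanical reason: outputs_g is reused for the generator update, and backpropagating through it during the discriminator step would free that part of the graph, so the later loss_g.backward() would fail.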

Test the training step

    train_step(img_batch.to(DEVICE), lab_batch.to(DEVICE))

Define the training loop

    epochs = 8

    loss_d_list, metrics_d_list, loss_g_list, metrics_g_list = [], [], [], []
    grid_img_list = []

    for epoch in range(epochs):

        loss_d_batch = metrics_d_batch = loss_g_batch = metrics_g_batch = 0.0
        num_batch = 0
        for img_batch, lab_batch in ds_loader:
            img_batch = img_batch.to(DEVICE)
            lab_batch = lab_batch.to(DEVICE)
            # the second argument is only used for its shape inside train_step
            loss_d, metrics_d, loss_g, metrics_g = train_step(img_batch, torch.ones_like(lab_batch))
            num_batch += 1
            loss_d_batch, metrics_d_batch = loss_d_batch + loss_d, metrics_d_batch + metrics_d
            loss_g_batch, metrics_g_batch = loss_g_batch + loss_g, metrics_g_batch + metrics_g

        loss_d_batch, metrics_d_batch = loss_d_batch / num_batch, metrics_d_batch / num_batch
        loss_g_batch, metrics_g_batch = loss_g_batch / num_batch, metrics_g_batch / num_batch

        loss_d_list.append(loss_d_batch)
        metrics_d_list.append(metrics_d_batch)
        loss_g_list.append(loss_g_batch)
        metrics_g_list.append(metrics_g_batch)

        print("[%d/%d] loss_discriminator: %.2f, metrics_discriminator: %.2f, loss_generator: %.2f, metrics_generator: %.2f" % (
            epoch + 1, epochs, loss_d_batch, metrics_d_batch, loss_g_batch, metrics_g_batch))

        # generate images from the fixed noise INPUTS_G to track training progress
        with torch.no_grad():
            outputs_g = generator(INPUTS_G)
            outputs_d = discriminator(outputs_g)

            grid_img_list.append(torchvision.utils.make_grid(outputs_g.cpu(), nrow=10, normalize=True, pad_value=1))

            plt.figure(figsize=(20, 2), dpi=80)
            for i, (img, lab) in enumerate(zip(outputs_g[:16], outputs_d[:16])):
                plt.subplot(1, 16, i + 1)
                plt.imshow(img.cpu().permute(1, 2, 0), cmap=plt.cm.binary)
                plt.title("%.2f" % lab.cpu().item())
                plt.axis("off")
            plt.tight_layout()
            plt.show()

Plot the loss values and evaluation metrics

    plt.figure(figsize=(12, 4), dpi=80)

    
    plt.subplot(1, 2, 1)
    plt.plot(loss_d_list, label="discriminator_loss")
    plt.plot(loss_g_list, label="generator_loss")
    plt.title("Loss of discriminator and generator")
    plt.xlabel("epochs")
    plt.ylabel("loss")
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(metrics_d_list, label="discriminator_metrics")
    plt.plot(metrics_g_list, label="generator_metrics")
    plt.title("Metrics of discriminator and generator")
    plt.xlabel("epochs")
    plt.ylabel("metrics")
    plt.legend()
    
    plt.show()

Animate the GAN image-generation process over training

    fig = plt.figure(figsize=(10, 10), dpi=80)

    plt.axis("off")
    
    imgs = [[plt.imshow(np.transpose(img, (1, 2, 0)), animated=True)] for img in grid_img_list]
    ani = animation.ArtistAnimation(fig, imgs, interval=1000, repeat_delay=1000, blit=True)
    
    HTML(ani.to_jshtml())
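
Outside a notebook, HTML(ani.to_jshtml()) will not display anything; one option (assuming the Pillow package is installed) is to save the animation to a GIF instead:

    ani.save("gan_training.gif", writer="pillow", fps=1)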

Plot real images alongside GAN-generated images

    plt.figure(figsize=(20, 10), dpi=80)

    
    plt.subplot(1, 2, 1)
    plt.title("real digits image")
    plt.imshow(torchvision.utils.make_grid(img_batch.cpu(), nrow=10, normalize=True, pad_value=1).permute(1, 2, 0))
    plt.axis("off")
    
    plt.subplot(1, 2, 2)
    plt.title("fake digits image")
    plt.imshow(np.transpose(grid_img_list[-1], (1, 2, 0)))
    plt.axis("off")
