GAN生成对抗网络
发布时间
阅读量:
阅读量
- 对于刚接触GAN训练的新手,一个重要建议是先为生成器和判别器选择结构简单的基础网络(例如全连接网络,或本文使用的DCGAN风格卷积网络)。为了提高训练的稳定性和效果,建议在网络中使用批量归一化(Batch Normalization)。此外,生成器最后一层通常使用双曲正切函数(tanh)作为激活函数,此时需要把真实图像归一化到[-1, 1];也可以改用sigmoid函数(本文代码即如此),此时真实图像应保持在[0, 1]范围内(即`ToTensor()`的默认输出)。无论选择哪种激活函数,生成图像与真实图像的取值范围必须一致。
- 下面的代码示例基于PyTorch框架实现了一个DCGAN风格的GAN模型,并在MNIST手写数字数据集上进行训练。
导入相关库
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pylab as plt
from matplotlib import animation
from IPython.display import HTML
设置用到的一些常量
# Hyper-parameters and shared objects used throughout the notebook.
BATCH_SIZE = 100                    # images per mini-batch (DataLoader batch_size)
IMG_CHANNELS = 1                    # channels of the real / generated images
NUM_Z = 100                         # channels of the latent noise fed to the generator
NUM_GENERATOR_FEATURES = 64         # base channel width of the generator
NUM_DISCRIMINATOR_FEATURES = 64     # base channel width of the discriminator
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# Fixed latent batch of shape (BATCH_SIZE, NUM_Z, 1, 1), re-used after every
# epoch so that generated samples can be compared across epochs.
INPUTS_G = torch.randn(BATCH_SIZE, NUM_Z, 1, 1, device=DEVICE)
加载MNIST数据集
# ToTensor() turns the images into float tensors in [0, 1]; no further
# normalisation is applied, which matches the generator's sigmoid output range.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# NOTE: both networks are sized for 28 x 28 inputs, so switching to a dataset
# with another resolution (e.g. CIFAR-10, 32 x 32) also requires changing the
# architectures, not just IMG_CHANNELS.
ds = torchvision.datasets.MNIST(
    root="data",
    train=True,
    transform=transform,
    download=True,
)
ds_loader = torch.utils.data.DataLoader(dataset=ds, batch_size=BATCH_SIZE, shuffle=True)
查看数据
# Take a single shuffled batch to inspect tensor shapes (also plotted below).
batch_iter = iter(ds_loader)
img_batch, lab_batch = next(batch_iter)
(img_batch.shape, lab_batch.shape)
绘制数据集图像
plt.figure(figsize=(8, 8), dpi=80)
# Tile the batch into a 10-column grid separated by 2-pixel padding of value 1.
grid = torchvision.utils.make_grid(img_batch, nrow=10, padding=2, pad_value=1, normalize=True)
# make_grid returns (C, H, W); matplotlib expects (H, W, C).
plt.imshow(grid.permute(1, 2, 0))
plt.tight_layout()
plt.axis("off")
定义生成器和判别器
class Generator(nn.Module):
    """DCGAN-style generator that maps latent noise to a 28 x 28 image.

    Args:
        num_z: channels of the latent input. Defaults to ``NUM_Z``.
        num_features: base channel width ``nf``. Defaults to
            ``NUM_GENERATOR_FEATURES``.
        img_channels: channels of the generated image. Defaults to
            ``IMG_CHANNELS``.

    Shapes:
        input:  (N, num_z, 1, 1)
        output: (N, img_channels, 28, 28), values in [0, 1] (sigmoid output).
    """

    def __init__(self, num_z=None, num_features=None, img_channels=None):
        super(Generator, self).__init__()
        # Resolve defaults lazily so ``Generator()`` builds the same network as before.
        num_z = NUM_Z if num_z is None else num_z
        nf = NUM_GENERATOR_FEATURES if num_features is None else num_features
        img_channels = IMG_CHANNELS if img_channels is None else img_channels
        # ConvTranspose2d output size: (i - 1) * s - 2 * p + k
        # Channel counts in brackets are for the default nf = 64.
        self.main = nn.Sequential(
            # num_z x 1 x 1 --> nf*8 (512) x 4 x 4
            nn.ConvTranspose2d(num_z, nf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(nf * 8),
            nn.ReLU(True),
            # nf*8 (512) x 4 x 4 --> nf*4 (256) x 8 x 8
            nn.ConvTranspose2d(nf * 8, nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.ReLU(True),
            # nf*4 (256) x 8 x 8 --> nf*2 (128) x 16 x 16
            nn.ConvTranspose2d(nf * 4, nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 2),
            nn.ReLU(True),
            # nf*2 (128) x 16 x 16 --> nf (64) x 14 x 14
            # (k=1, p=1: (16 - 1) * 1 - 2 + 1 = 14, i.e. trims one pixel per side)
            nn.ConvTranspose2d(nf * 2, nf, 1, 1, 1, bias=False),
            nn.BatchNorm2d(nf),
            nn.ReLU(True),
            # nf (64) x 14 x 14 --> img_channels x 28 x 28
            nn.ConvTranspose2d(nf, img_channels, 2, 2, 0, bias=False),
            # Sigmoid keeps pixels in [0, 1], the same range ToTensor() gives real images.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return generated images for the latent batch ``x``."""
        return self.main(x)
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores 28 x 28 images as real (1) or fake (0).

    Args:
        img_channels: channels of the input image. Defaults to ``IMG_CHANNELS``.
        num_features: base channel width ``nf``. Defaults to
            ``NUM_DISCRIMINATOR_FEATURES``.

    Shapes:
        input:  (N, img_channels, 28, 28)
        output: (N,), sigmoid probabilities in [0, 1].
    """

    def __init__(self, img_channels=None, num_features=None):
        super(Discriminator, self).__init__()
        img_channels = IMG_CHANNELS if img_channels is None else img_channels
        # Width comes from the constant declared for the discriminator itself.
        nf = NUM_DISCRIMINATOR_FEATURES if num_features is None else num_features
        # Conv2d output size: floor((i + 2 * p - k) / s) + 1
        # Channel counts in brackets are for the default nf = 64.
        self.main = nn.Sequential(
            # img_channels x 28 x 28 --> nf*4 (256) x 14 x 14
            nn.Conv2d(img_channels, nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # nf*4 (256) x 14 x 14 --> nf*2 (128) x 7 x 7
            nn.Conv2d(nf * 4, nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # nf*2 (128) x 7 x 7 --> nf (64) x 3 x 3
            nn.Conv2d(nf * 2, nf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf),
            nn.LeakyReLU(0.2, inplace=True),
            # nf (64) x 3 x 3 --> 1 x 1 x 1
            nn.Conv2d(nf, 1, 3, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one real/fake probability per image, flattened to shape (N,)."""
        return self.main(x).view(-1)
测试定义的模型
# Smoke test on CPU: push one random latent batch through freshly built models
# to confirm the generator's output shape is accepted by the discriminator.
noise = torch.randn(BATCH_SIZE, NUM_Z, 1, 1)
generator, discriminator = Generator(), Discriminator()
fake_img = generator(noise)
discriminator(fake_img)
网络参数初始化函数
# Custom weight initialisation, applied to the generator and discriminator via Module.apply.
def weights_init(m):
    """Initialise one module in place (intended for ``nn.Module.apply``).

    Modules whose class name contains ``'Conv'`` get weights ~ N(0, 0.02).
    Modules whose class name contains ``'BatchNorm'`` get weights ~ N(1, 0.02)
    and a zero bias. Every other module is left untouched. Parameters that do
    not exist (e.g. ``BatchNorm2d(affine=False)``, or a container whose class
    name merely contains ``'Conv'``) are skipped instead of raising.

    Args:
        m: the module to initialise.
    """
    classname = type(m).__name__
    # nn.Module.__getattr__ raises AttributeError for unknown names, so
    # getattr's default covers modules without these parameters.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            # nn.init.* already runs under no_grad, so ``.data`` is unnecessary.
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)
定义生成器和判别器对象,优化器,损失函数和评估标准
# Build both networks on DEVICE and apply the custom initialisation
# (Module.apply mutates in place and returns the same module).
generator = Generator().to(DEVICE)
generator.apply(weights_init)
discriminator = Discriminator().to(DEVICE)
discriminator.apply(weights_init)
# One Adam optimizer per network, lr = 2e-4 and betas = (0.5, 0.999) for both.
optimizer_g = torch.optim.Adam(params=generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
# Binary cross-entropy on the discriminator's sigmoid probabilities.
loss_fn = nn.BCELoss()
def metrics_fn(y_true, y_pred):
    """Binary accuracy of probabilities ``y_pred`` against 0/1 targets ``y_true``.

    A prediction counts as class 1 when ``y_pred >= 0.5``. Returns a 0-d
    float32 tensor on the inputs' device.
    """
    # Threshold and cast to the target dtype, then compare element-wise. No
    # constant tensors are allocated and no global device is consulted.
    y_hat = (y_pred >= 0.5).to(y_true.dtype)
    return (y_hat == y_true).to(torch.float32).mean()
定义训练步骤(重点)
def _set_requires_grad(module, flag):
    """Turn gradient tracking on or off for every parameter of ``module``."""
    for parameter in module.parameters():
        parameter.requires_grad_(flag)


def train_step(inputs, labels):
    """Run one discriminator update followed by one generator update.

    Args:
        inputs: batch of real images, already on DEVICE, fed to ``discriminator``.
        labels: tensor with one entry per image in ``inputs``; only its shape
            is used, since the targets are rebuilt as all-ones / all-zeros.

    Returns:
        ``(loss_d, metrics_d, loss_g, metrics_g)`` as Python floats. The
        discriminator values are the mean of the real and fake halves.
    """
    labels = labels.to(torch.float32)
    # One latent vector per real image, so fake scores line up with ``labels``
    # even when a batch is smaller than BATCH_SIZE.
    inputs_g = torch.randn(inputs.size(0), NUM_Z, 1, 1, device=DEVICE)
    outputs_g = generator(inputs_g)

    # ---- Discriminator update ----------------------------------------------
    # The fake batch is detached below, so no gradient reaches the generator
    # during this update and it does not need to be frozen.
    optimizer_d.zero_grad()
    # Real images should be scored as 1.
    labels = torch.ones_like(labels)
    outputs = discriminator(inputs)
    loss_real = loss_fn(outputs, labels)
    metrics_real = metrics_fn(labels, outputs)
    loss_real.backward()
    # Fake images should be scored as 0.
    labels = torch.zeros_like(labels)
    outputs = discriminator(outputs_g.detach())
    loss_fake = loss_fn(outputs, labels)
    metrics_fake = metrics_fn(labels, outputs)
    loss_fake.backward()
    # Both halves were back-propagated separately; the mean is for reporting.
    loss_d = (loss_real + loss_fake) / 2
    metrics_d = (metrics_real + metrics_fake) / 2
    optimizer_d.step()

    # ---- Generator update --------------------------------------------------
    # Freeze D via ``requires_grad`` (an attribute named ``require_grad`` is a
    # silent no-op) so backward only fills the generator's gradients.
    _set_requires_grad(discriminator, False)
    try:
        optimizer_g.zero_grad()
        # The generator wants its fakes to be scored as real (1).
        labels = torch.ones_like(labels)
        outputs = discriminator(outputs_g)
        loss_g = loss_fn(outputs, labels)
        metrics_g = metrics_fn(labels, outputs)
        loss_g.backward()
        optimizer_g.step()
    finally:
        # Never leave the discriminator frozen for the next step or any caller.
        _set_requires_grad(discriminator, True)
    return loss_d.item(), metrics_d.item(), loss_g.item(), metrics_g.item()
测试定义的训练步骤
# Smoke-test a single optimisation step on the batch fetched earlier.
train_step(inputs=img_batch.to(DEVICE), labels=lab_batch.to(DEVICE))
定义训练循环
epochs = 8
# Per-epoch means of the values returned by train_step.
loss_d_list, metrics_d_list, loss_g_list, metrics_g_list = [], [], [], []
# One image grid of INPUTS_G samples per epoch (used for the animation later).
grid_img_list = []
for epoch in range(epochs):
    # Running sums over the epoch; divided by num_batch into means below.
    loss_d_batch = metrics_d_batch = loss_g_batch = metrics_g_batch = .0
    num_batch = 0
    for img_batch, lab_batch in ds_loader:
        img_batch = img_batch.to(DEVICE)
        lab_batch = lab_batch.to(DEVICE)
        # Only the labels' shape matters: train_step builds its own targets.
        loss_d, metrics_d, loss_g, metrics_g = train_step(img_batch, torch.ones_like(lab_batch))
        num_batch += 1
        loss_d_batch, metrics_d_batch = loss_d_batch + loss_d, metrics_d_batch + metrics_d
        loss_g_batch, metrics_g_batch = loss_g_batch + loss_g, metrics_g_batch + metrics_g
    loss_d_batch, metrics_d_batch = loss_d_batch / num_batch, metrics_d_batch / num_batch
    loss_g_batch, metrics_g_batch = loss_g_batch / num_batch, metrics_g_batch / num_batch
    loss_d_list.append(loss_d_batch)
    metrics_d_list.append(metrics_d_batch)
    loss_g_list.append(loss_g_batch)
    metrics_g_list.append(metrics_g_batch)
    # Epochs are reported 1-based so the final line reads "[8/8]".
    print("[%d/%d] loss_discriminator: %.2f, metrics_discriminator: %.2f, loss_generator: %.2f, metrics_generator: %.2f" % (
        epoch + 1, epochs, loss_d_batch, metrics_d_batch, loss_g_batch, metrics_g_batch))
    with torch.no_grad():
        # The fixed noise INPUTS_G makes samples comparable from epoch to epoch.
        outputs_g = generator(INPUTS_G)
        outputs_d = discriminator(outputs_g)
    grid_img_list.append(torchvision.utils.make_grid(outputs_g.cpu(), nrow=10, normalize=True, pad_value=1))
    plt.figure(figsize=(20, 2), dpi=80)
    # First 16 samples, each titled with the discriminator's probability of "real".
    for i, (img, score) in enumerate(zip(outputs_g[:16], outputs_d[:16])):
        plt.subplot(1, 16, i + 1)
        plt.imshow(img.cpu().permute(1, 2, 0), cmap=plt.cm.binary)
        plt.title("%.2f" % score.cpu().item())
        plt.axis("off")
    plt.tight_layout()
    plt.show()
绘制损失值和评估指标
plt.figure(figsize=(12, 4), dpi=80)
# Left panel: per-epoch losses; right panel: per-epoch accuracy metrics.
history_panels = (
    (loss_d_list, "discriminator_loss", loss_g_list, "generator_loss",
     "Loss of discriminator and generator", "loss"),
    (metrics_d_list, "discriminator_metrics", metrics_g_list, "generator_metrics",
     "Metrics of discriminator and generator", "metrics"),
)
for panel_idx, (d_values, d_label, g_values, g_label, panel_title, y_name) in enumerate(history_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.plot(d_values, label=d_label)
    plt.plot(g_values, label=g_label)
    plt.title(panel_title)
    plt.xlabel("epochs")
    plt.ylabel(y_name)
    plt.legend()
plt.show()

绘制动态的GAN图像生成过程
# Animate the per-epoch sample grids collected in grid_img_list during training.
fig = plt.figure(figsize=(10, 10), dpi=80)
plt.axis("off")
# make_grid produced (C, H, W) tensors; imshow needs (H, W, C). Tensor.permute
# is used because numpy is not imported in this notebook (np.transpose would
# raise NameError).
imgs = [[plt.imshow(img.permute(1, 2, 0), animated=True)] for img in grid_img_list]
ani = animation.ArtistAnimation(fig, imgs, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
绘制真实图片和GAN生成图片
# Side by side: the last real batch seen in training vs. the final generated grid.
plt.figure(figsize=(20, 10), dpi=80)
plt.subplot(1, 2, 1)
plt.title("real digits image")
plt.imshow(torchvision.utils.make_grid(img_batch.cpu(), nrow=10, normalize=True, pad_value=1).permute(1, 2, 0))
plt.axis("off")
plt.subplot(1, 2, 2)
plt.title("fake digits image")
# grid_img_list holds (C, H, W) tensors; permute to (H, W, C) for imshow
# (numpy is not imported in this notebook, so np.transpose is unavailable).
plt.imshow(grid_img_list[-1].permute(1, 2, 0))
plt.axis("off")

全部评论 (0)
还没有任何评论哟~
