Advertisement

对抗生成网络_白话生成对抗网络GAN及代码实现

阅读量:

他人的智慧结晶

作者: 养生的控制人

地址:https://www.zhihu.com/people/yilan-zhong-shan-xiao-29-98

本文主要是个简单的笔记,参考资料来自下面三部分

  1. Tutorial_HYLee_GAN
  2. Renu Khandelwal 的博客
  3. Jason 的博客

01

神经网络一览

不同类别的人工智能模型(如全连接前馈网络、卷积神经网络和循环神经网络)在本质上其主要特点在于各自有不同的输入输出类型。例如它们的输入输出类型包括向量矩阵或由多个向量组成的序列。

53e82c16851d83ab3fffa99baf3280cf.png

02

GAN的基本思想

GAN是由生成器与判别器组成的:其中生成器本质上也是一个神经网络,并可被视为一个函数。

7a731c74d5baa0a8db2485c4c4c9a42e.png

当给定一个向量时能够生成一幅漫画图像该向量的每个维度对应不同的意义

09c5dcf0d133718670cb9fbcbcbddc1e.png

判别器的本质也是一个神经网络

32756810958ddd398beabdcc4824778c.png

如果给定一张图片,判别器就会告诉你这是不是真实图片

c10032f619342d92411c13ee6435cae3.png

所以GAN的训练本质就是训练两个神经网络。

03

GAN的工作原理

生成器旨在产出与训练数据相似且可欺骗(如伪造)的图像;判别器的任务则是识别这些图像的真实性。 从输入端来看,生成网络通常接收随机噪声作为基础信息源;而整个系统中包含了两个关键路径:一个是真实图片的数据集;另一个是经过生成网络处理后产生的虚假图像。 这个循环系统的运行流程如图所示。

c774792875fcb39da667e93924b15cee.png

每一次迭代过程中:

  1. 优化判别器模型参数。具体而言,在输入生成样本及其对应的标记(上图中的generated example)和真实样本及其对应标注(上图中的real example)的情况下,请指导使判别器能够区分真假样本,并旨在训练一个分类准确率较高的二分类模型。
  2. 在固定判别器模型参数的前提下更新生成模型。具体而言,在输入生成样本及其对应的标记(让判别
    器将这些图像误认为是真实图像)的情况下,请指导通过反向传播误差更新生成模型参数,并从而使
    生成图像更具仿真性。

GAN训练的目标函数如下所示

d00c48fecd9120c143d985c6550c1a4c.png
  • 判别器旨在通过最大化目标函数来区分真实数据与生成数据,在这种情况下,当输入是真实数据时(即D(x)),其值应趋近于1;而当输入是生成的数据时(即D(G(z))),其值应趋近于0。
  • 生成器则通过最小化目标函数来实现对判别器的欺骗效果,具体表现为使生成的数据能够使判别器将其误判为真实数据(即D(G(z))接近1)。

04

GAN的实现

本研究中使用 MNIST 数据集作为实验数据集。经过训练后,我们预期观察到生成器能够生成看似真实的数字样本。导入需要用到的库

复制代码
复制代码
    import numpy as npimport pandas as pdimport matplotlib.pyplot as plt%matplotlib inlineimport kerasfrom keras.layers import Dense, Dropout, Inputfrom keras.models import Model,Sequentialfrom keras.datasets import mnistfrom tqdm import tqdmfrom keras.layers.advanced_activations import LeakyReLUfrom keras.optimizers import Adam

`` 导入数据

复制代码
复制代码
    def load_data():    (x_train, y_train), (x_test, y_test) = mnist.load_data()    x_train = (x_train.astype(np.float32) - 127.5)/127.5    # 将图片转为向量 x_train from (60000, 28, 28) to (60000, 784)     # 每一行 784 个元素    x_train = x_train.reshape(60000, 784)    return (x_train, y_train, x_test, y_test)(X_train, y_train,X_test, y_test)=load_data()print(X_train.shape)

`` 定义优化器

复制代码
复制代码
    def adam_optimizer():    return Adam(lr=0.0002, beta_1=0.5)

`` 这里要采用的生成对抗网络的结构如下图所示

b4edd7b857dcbdf05d86a0a54a17766a.png

定义生成器:输入是 100 维,经过三层隐藏层,输出 784 维的向量(造假的图片)

复制代码
复制代码
    def create_generator():    generator=Sequential()    generator.add(Dense(units=256,input_dim=100))    generator.add(LeakyReLU(0.2))    generator.add(Dense(units=512))    generator.add(LeakyReLU(0.2))    generator.add(Dense(units=1024))    generator.add(LeakyReLU(0.2))    generator.add(Dense(units=784, activation='tanh'))    generator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())    return generatorg=create_generator()g.summary()

定义判别器:其输入包括真实图片或由生成器生成的假图片(784维),通过三层隐含层进行处理后输出类别(1维)。

复制代码
复制代码
    def create_discriminator():    discriminator=Sequential()    discriminator.add(Dense(units=1024,input_dim=784))    discriminator.add(LeakyReLU(0.2))    discriminator.add(Dropout(0.3))    discriminator.add(Dense(units=512))    discriminator.add(LeakyReLU(0.2))    discriminator.add(Dropout(0.3))    discriminator.add(Dense(units=256))    discriminator.add(LeakyReLU(0.2))    discriminator.add(Dense(units=1, activation='sigmoid'))    discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())    return discriminatord =create_discriminator()d.summary()

`` 定义生成对抗网络

复制代码
复制代码
    def create_gan(discriminator, generator):    discriminator.trainable=False    # 这是一个链式模型:输入经过生成器、判别器得到输出    gan_input = Input(shape=(100,))    x = generator(gan_input)    gan_output= discriminator(x)    gan= Model(inputs=gan_input, outputs=gan_output)    gan.compile(loss='binary_crossentropy', optimizer='adam')    return gangan = create_gan(d,g)gan.summary()

`` 定义画图函数来可视化图片的生成

复制代码
复制代码
    def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):    noise= np.random.normal(loc=0, scale=1, size=[examples, 100])    generated_images = generator.predict(noise)    generated_images = generated_images.reshape(100,28,28)    plt.figure(figsize=figsize)    for i in range(generated_images.shape[0]):        plt.subplot(dim[0], dim[1], i+1)        plt.imshow(generated_images[i], interpolation='nearest')        plt.axis('off')    plt.tight_layout()    plt.savefig('gan_generated_image %d.png' %epoch)

`` 生成对抗网络的训练函数

复制代码
复制代码
    def training(epochs=1, batch_size=128):    #导入数据    (X_train, y_train, X_test, y_test) = load_data()    batch_count = X_train.shape[0] / batch_size    # 定义生成器、判别器和GAN网络    generator= create_generator()    discriminator= create_discriminator()    gan = create_gan(discriminator, generator)    for e in range(1,epochs+1 ):        print("Epoch %d" %e)        for _ in tqdm(range(int(batch_count))):            #产生噪声喂给生成器            noise= np.random.normal(0,1, [batch_size, 100])            # 产生假图片            generated_images = generator.predict(noise)            # 一组随机真图片            image_batch =X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]            # 真假图片拼接             X= np.concatenate([image_batch, generated_images])            # 生成数据和真实数据的标签            y_dis=np.zeros(2*batch_size)            y_dis[:batch_size]=0.9            # 预训练,判别器区分真假            discriminator.trainable=True            discriminator.train_on_batch(X, y_dis)            # 欺骗判别器 生成的图片为真的图片            noise= np.random.normal(0,1, [batch_size, 100])            y_gen = np.ones(batch_size)            # GAN的训练过程中判别器的权重需要固定             discriminator.trainable=False            # GAN的训练过程为交替“训练判别器”和“固定判别器权重训练链式模型”            gan.train_on_batch(noise, y_gen)        if e == 1 or e % 50 == 0:            # 画图 看一下生成器能生成什么            plot_generated_images(e, generator)training(400,256)

`` 经过训练后生成的图片 一个epoch后生成器还是个小学生

5e023aa70057f4a4659284edb498a6ab.png

100个epoch后生成器已经有点样子了

b40b75dd8d7bb3980411f7f933d8c724.png

400个epoch后生成器可以出师了

f76293f35babc4775ddad4a62ef7c7af.png

是不是已经掌握了生成式模型的核心原理?这样就可以借助生成器的机制来创造逼真的图像了。

本文旨在促进学术交流,并不表明本公众号对此持赞同态度;对内容真实性和准确性负责。文章的所有权归原创者所有;如若发生侵权行为,请告知以便及时处理。

c275deb268ac6d1b1b9011e38d1778a1.png

分享、点赞、在看,给个三连击呗!

全部评论 (0)

还没有任何评论哟~