Advertisement

对抗生成网络_生成对抗网络 | 实验

阅读量:

上期我们介绍了

生成对抗网络 | 原理及训练过程

同样地,我们依旧通过实验来巩固我们刚刚所学的知识点。本次实验是基于Jupyer Notebook、Anaconda Python3.7与Keras环境。数据集是利用Minst手写体图像数据集。
8be2d3f8adbe30b8c5356f27fcd465c7.png5.3.1 代码

复制代码
  # chapter5/5_3_GAN.ipynb2.  import random  3.  import numpy as np  4.  from keras.layers import Input  5.  from keras.layers.core import Reshape,Dense,Dropout,Activation,Flatten  6.  from keras.layers.advanced_activations import LeakyReLU  7.  from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D, Deconv2D, UpSampling2D  8.  from keras.regularizers import *  9.  from keras.layers.normalization import *  10.  from keras.optimizers import *  11.  from keras.datasets import mnist  12.  import matplotlib.pyplot as plt  13.  from keras.models import Model  14.  from tqdm import tqdm  15.  from IPython import display  
    

1. 读取数据集

复制代码
  img_rows, img_cols = 28, 28  2.    3.  # 数据集的切分与混洗(shuffle) 4.  (X_train, y_train), (X_test, y_test) = mnist.load_data()  5.    6.  X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)  7.  X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)  8.  X_train = X_train.astype('float32')  9.  X_test = X_test.astype('float32')  10.  X_train /= 255  11.  X_test /= 255  12.    13.  print(np.min(X_train), np.max(X_train))  14.  print('X_train shape:', X_train.shape)  15.  print(X_train.shape[0], 'train samples')  16.  print(X_test.shape[0], 'test samples')  
    

0.0 1.0

X_train shape: (60000, 1, 28,28)

60000 train samples

10000 test samples

2. 超参数设置

复制代码
  shp = X_train.shape[1:]  2.  dropout_rate = 0.25  3.    4.  # 优化器  5.  opt = Adam(lr=1e-4)  6.  dopt = Adam(lr=1e-5)  
    

3. 定义生成器

复制代码
  K.set_image_dim_ordering('th')  # 用theano的图片输入顺序  2.  # 生成1 * 28 * 28的图片  3.  nch = 200  4.  g_input = Input(shape=[100])  5.  H = Dense(nch*14*14, kernel_initializer='glorot_normal')(g_input)  6.  H = BatchNormalization()(H)  7.  H = Activation('relu')(H)  8.  H = Reshape( [nch, 14, 14] )(H)  # 转成200 * 14 * 14  9.  H = UpSampling2D(size=(2, 2))(H)  10.  H = Convolution2D(100, (3, 3), padding="same", kernel_initializer='glorot_normal')(H)  11.  H = BatchNormalization()(H)  12.  H = Activation('relu')(H)  13.  H = Convolution2D(50, (3, 3), padding="same", kernel_initializer='glorot_normal')(H)  14.  H = BatchNormalization()(H)  15.  H = Activation('relu')(H)  16.  H = Convolution2D(1, (1, 1), padding="same", kernel_initializer='glorot_normal')(H)  17.  g_V = Activation('sigmoid')(H)  18.  generator = Model(g_input,g_V)  19.  generator.compile(loss='binary_crossentropy', optimizer=opt)  20.  generator.summary()  
    

4. 定义辨别器

复制代码
  # 辨别是否来自真实训练集  2.  d_input = Input(shape=shp)  3.  H = Convolution2D(256, (5, 5), activation="relu", strides=(2, 2), padding="same")(d_input)  4.  H = LeakyReLU(0.2)(H)  5.  H = Dropout(dropout_rate)(H)  6.  H = Convolution2D(512, (5, 5), activation="relu", strides=(2, 2), padding="same")(H)  7.  H = LeakyReLU(0.2)(H)  8.  H = Dropout(dropout_rate)(H)  9.  H = Flatten()(H)  10.  H = Dense(256)(H)  11.  H = LeakyReLU(0.2)(H)  12.  H = Dropout(dropout_rate)(H)  13.  d_V = Dense(2,activation='softmax')(H)  14.  discriminator = Model(d_input,d_V)  15.  discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)  16.  discriminator.summary()  
    

5. 构造生成对抗网络

复制代码
  # 冷冻训练层  2.  def make_trainable(net, val):  3.      net.trainable = val  4.      for l in net.layers:  5.          l.trainable = val  6.  make_trainable(discriminator, False)  7.    8.  # 构造GAN  9.  gan_input = Input(shape=[100])  10.  H = generator(gan_input)  11.  gan_V = discriminator(H)  12.  GAN = Model(gan_input, gan_V)  13.  GAN.compile(loss='categorical_crossentropy', optimizer=opt)  14.  GAN.summary()
    
93ad892ad51b0c9518da370b14ba3d38.png

6. 训练

复制代码
  # 描绘损失收敛过程  2.  def plot_loss(losses):  3.          display.clear_output(wait=True)  4.          display.display(plt.gcf())  5.          plt.figure(figsize=(10,8))  6.          plt.plot(losses["d"], label='discriminitive loss')  7.          plt.plot(losses["g"], label='generative loss')  8.          plt.legend()  9.          plt.show()  10.            11.            12.  #  描绘生成器生成图像          13.  def plot_gen(n_ex=16,dim=(4,4), figsize=(10,10) ):  14.      noise = np.random.uniform(0,1,size=[n_ex,100])  15.      generated_images = generator.predict(noise)  16.    17.      plt.figure(figsize=figsize)  18.      for i in range(generated_images.shape[0]):  19.          plt.subplot(dim[0],dim[1],i+1)  20.          img = generated_images[i,0,:,:]  21.          plt.imshow(img)  22.          plt.axis('off')  23.      plt.tight_layout()  24.      plt.show()  25.    26.  # 抽取训练集样本  27.  ntrain = 10000  28.  trainidx = random.sample(range(0,X_train.shape[0]), ntrain)  29.  XT = X_train[trainidx,:,:,:]    30.    31.  # 预训练辨别器  32.  noise_gen = np.random.uniform(0,1,size=[XT.shape[0],100])  33.  generated_images = generator.predict(noise_gen)  # 生成器产生样本  34.  X = np.concatenate((XT, generated_images))    35.  n = XT.shape[0]  36.  y = np.zeros([2*n,2])  # 构造辨别器标签 one-hot encode  37.  y[:n,1] = 1  38.  y[n:,0] = 1  39.    40.  make_trainable(discriminator,True)  41.  discriminator.fit(X,y, epochs=1, batch_size=32)  42.  y_hat = discriminator.predict(X)  
    
78c20bf71f5a677a4bd2894e21ea3290.png
复制代码
  #  计算辨别器的准确率  2.  y_hat_idx = np.argmax(y_hat,axis=1)  3.  y_idx = np.argmax(y,axis=1)  4.  diff = y_idx-y_hat_idx  5.  n_total = y.shape[0]  6.  n_right = (diff==0).sum()  7.    8.  print( "(%d of %d) right"  % (n_right, n_total)) 
    
89f11f8a4d238a37ad37f282bd0ef68a.png
复制代码
  def train_for_n(nb_epoch=5000, plt_frq=25,BATCH_SIZE=32):  2.      for e in tqdm(range(nb_epoch)):    3.            4.          # 生成器生成样本  5.          image_batch = X_train[np.random.randint(0,X_train.shape[0],size=BATCH_SIZE),:,:,:]      6.          noise_gen = np.random.uniform(0,1,size=[BATCH_SIZE,100])  7.          generated_images = generator.predict(noise_gen)  8.            9.          # 训练辨别器  10.          X = np.concatenate((image_batch, generated_images))  11.          y = np.zeros([2*BATCH_SIZE,2])  12.          y[0:BATCH_SIZE,1] = 1  13.          y[BATCH_SIZE:,0] = 1  14.            15.          # 存储辨别器损失loss  16.          make_trainable(discriminator,True)  17.          d_loss  = discriminator.train_on_batch(X,y)  18.          losses["d"].append(d_loss)    19.        20.          # 生成器生成样本  21.          noise_tr = np.random.uniform(0,1,size=[BATCH_SIZE,100])  22.          y2 = np.zeros([BATCH_SIZE,2])  23.          y2[:,1] = 1  24.            25.          # 存储生成器损失loss  26.          make_trainable(discriminator,False)  # 辨别器的训练关掉  27.          g_loss = GAN.train_on_batch(noise_tr, y2)  28.          losses["g"].append(g_loss)  29.            30.          # 更新损失loss图  31.          if e%plt_frq == plt_frq-1:  32.              plot_loss(losses)  33.              plot_gen()  34.  train_for_n(nb_epoch=1000, plt_frq=10,BATCH_SIZE=128)  
    
c647b8cd13146b8b56a277b22bbd3192.png

5.3.2 结果分析

从模型输出的loss我们可以知道生成器与辨别器两者拟合的loss并不是特别地好,因此我们可以通过调参来解决。主要调参方向有以下四点:

1. batch size

2. adam优化器的learning rate

3. 迭代次数nb_epoch

4. 生成器generator和辨别器discriminator的网络结构

5.4 小结

好了,到这里,我们就已经将生成对抗网络(GAN)的知识点讲完了。大家在掌握了整个流程之后,就可以将笔者的代码修改成自己所需要的场景,进而训练自己的GAN模型了。

最后,笔者在本章介绍的GAN只是2014年的开山之作,后面有很多人基于GAN提出了许多有趣的实验,但是所用的网络原理都差不多,这里就不一一赘述了。而且GAN的应用范围非常广阔,比如市面上很火的“换脸”软件,大多都是基于GAN的原理去做的。甚至我们也可以利用GAN去做数据增强,比如在我们缺少训练集的时候,可以考虑用GAN去生成一些数据,扩充我们的训练样本。
8be2d3f8adbe30b8c5356f27fcd465c7.png

下一期,我们将讲授

李宏毅老师的无监督学习讲座PPT总结

敬请期待~
de613a5ce3a7155ddf1369bd08f2ef21.gif

关注我的微信公众号不定期更新相关专业知识
f6d13fe44229cb200fb3ed535e811035.png
00878f9ef29b04c77aebc41963ef0950.gif

内容 |阿力阿哩哩

编辑 | 阿璃
a501717dc0db57c1d3940077086073a3.png点个“在看”,作者高产似那啥~8b67e3b233abcaa6944be59e0adc1dfc.gif

全部评论 (0)

还没有任何评论哟~