
[Machine Learning] Basic Categories of Machine Learning - Semi-supervised Learning - Semi-supervised GANs


A Semi-supervised Generative Adversarial Network (Semi-supervised GAN, or SGAN) combines a generative adversarial network (GAN) with semi-supervised learning, making it possible to train a classifier from a small amount of labeled data plus a large amount of unlabeled data. It extends the standard GAN architecture so that the discriminator not only distinguishes real samples from generated ones, but also classifies the labeled samples.


The Core Idea of Semi-supervised GANs

1. Generator:
  • Mimics the data distribution and generates samples resembling real data.
  • Takes a noise vector z ~ p_z(z) as input.

2. Discriminator:
  • In a standard GAN, the discriminator is just a binary classifier that separates generated samples from real ones.
  • In an SGAN, the discriminator is extended into a multi-class classifier: besides distinguishing real from fake, it also classifies the real data (the supervised task).

3. Use of unlabeled data:
  • SGAN treats unlabeled data as "real samples"; during training it uses them to improve both the quality of the generator and the classification ability of the discriminator.
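The (K+1)-class target scheme implied by the points above can be sketched with a few lines of NumPy; the array names here are illustrative, not part of any particular implementation:

```python
import numpy as np

K = 10                      # number of real classes (e.g., MNIST digits)
num_outputs = K + 1         # one extra output neuron for "generated"

def one_hot(indices, depth):
    # Build one-hot row vectors by indexing into an identity matrix
    return np.eye(depth)[indices]

y_labeled = np.array([2, 5])                    # true labels of two labeled real images
t_labeled = one_hot(y_labeled, num_outputs)     # supervised targets, shape (2, 11)
t_fake = one_hot([K] * 3, num_outputs)          # generated samples -> class K+1
```

Unlabeled real samples get no class target of their own; they only contribute the constraint that the discriminator should not assign them to the (K+1)-th "generated" class.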


The Structure of a Semi-supervised GAN

Discriminator Output

The output layer of the discriminator D has K + 1 neurons:

  • The first K neurons give the probabilities of the K classes of the labeled samples.
  • The (K+1)-th neuron gives the probability that the sample is a generated one.
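A minimal NumPy sketch (illustrative variable names, not from the code below) of how these K + 1 softmax outputs are read for a single sample:

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over a 1-D logit vector
    e = np.exp(logits - logits.max())
    return e / e.sum()

K = 10
logits = np.zeros(K + 1)   # discriminator logits for one sample
logits[3] = 2.0            # pretend the discriminator favors class 3

probs = softmax(logits)
class_probs = probs[:K]    # probabilities of the K real classes
p_fake = probs[K]          # probability that the sample was generated
```

Here `class_probs.argmax()` is 3 and `p_fake` is small, i.e. the sample is classified as real class 3.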

Loss Functions

Classification loss (supervised):

\mathcal{L}_{\text{supervised}} = -\mathbb{E}_{(x, y) \sim p_{\text{data}}} \left[ \log D_y(x) \right]

where D_y(x) is the probability the discriminator assigns to class y.

Adversarial loss (unsupervised):

\mathcal{L}_{\text{unsupervised}} = -\mathbb{E}_{x \sim p_{\text{data}}} \left[ \log \left( 1 - D_{K+1}(x) \right) \right] - \mathbb{E}_{z \sim p_z} \left[ \log D_{K+1}(G(z)) \right]

where D_{K+1}(x) is the probability the discriminator assigns to "generated sample".

Generator loss:

\mathcal{L}_G = -\mathbb{E}_{z \sim p_z} \left[ \log \left( 1 - D_{K+1}(G(z)) \right) \right]

The total loss is a weighted sum of the terms above.
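As a hedged NumPy sketch (the batch arrays here are made-up stand-ins for discriminator outputs, not part of the implementation below), the three terms can be evaluated like this:

```python
import numpy as np

K = 10
rng = np.random.default_rng(0)

def random_probs(n):
    # Stand-in for discriminator softmax outputs: rows sum to 1
    p = rng.random((n, K + 1))
    return p / p.sum(axis=1, keepdims=True)

def fake_prob(probs):
    # Probability mass on the (K+1)-th "generated" neuron, D_{K+1}
    return probs[:, K]

probs_labeled = random_probs(4)       # D outputs on a labeled real batch
y = np.array([0, 3, 7, 9])            # true labels of that batch
probs_real = random_probs(4)          # D outputs on unlabeled real samples
probs_gen = random_probs(4)           # D outputs on generated samples G(z)

L_supervised = -np.mean(np.log(probs_labeled[np.arange(4), y]))
L_unsupervised = -np.mean(np.log(1 - fake_prob(probs_real))) \
                 - np.mean(np.log(fake_prob(probs_gen)))
L_G = -np.mean(np.log(1 - fake_prob(probs_gen)))

total = L_supervised + L_unsupervised  # discriminator objective (weights set to 1 here)
```

The relative weighting of the supervised and unsupervised terms is a hyperparameter; equal weights are used above purely for illustration.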


Implementing a Semi-supervised GAN

Below is an example SGAN implementation using TensorFlow/Keras:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose, LeakyReLU, Dropout, BatchNormalization
from keras.optimizers import Adam
from keras.losses import CategoricalCrossentropy

# Hyperparameters
latent_dim = 100           # dimension of the random noise vector
image_shape = (28, 28, 1)  # input image shape
num_classes = 11           # 10 real classes + 1 "generated" class

# Build the generator
def build_generator(latent_dim):
    model = Sequential([
        Dense(128 * 7 * 7, activation='relu', input_dim=latent_dim),
        Reshape((7, 7, 128)),
        BatchNormalization(),
        Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', activation='relu'),
        BatchNormalization(),
        Conv2DTranspose(64, kernel_size=4, strides=2, padding='same', activation='relu'),
        BatchNormalization(),
        Conv2D(1, kernel_size=7, activation='tanh', padding='same')  # output shape (28, 28, 1)
    ])
    return model

# Build the discriminator
def build_discriminator(image_shape, num_classes):
    model = Sequential([
        Conv2D(64, kernel_size=3, strides=2, padding='same', input_shape=image_shape),
        LeakyReLU(alpha=0.2),
        Dropout(0.3),
        Conv2D(128, kernel_size=3, strides=2, padding='same'),
        LeakyReLU(alpha=0.2),
        Dropout(0.3),
        Flatten(),
        Dense(num_classes, activation='softmax')  # probability distribution over K+1 classes
    ])
    return model

# Training procedure
def train_sgan(generator, discriminator, latent_dim, X_labeled, y_labeled, X_unlabeled, epochs=10000, batch_size=64):
    # Compile the discriminator
    discriminator.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                          loss=CategoricalCrossentropy(),
                          metrics=['accuracy'])

    # Build the combined generator-discriminator model
    discriminator.trainable = False
    sgan = Sequential([generator, discriminator])
    sgan.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss=CategoricalCrossentropy())

    # One-hot target for the "generated" class (index num_classes - 1)
    fake_targets = np.eye(num_classes)[np.full(batch_size, num_classes - 1)]

    for epoch in range(epochs):
        # Generate fake samples
        noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
        fake_images = generator.predict(noise, verbose=0)

        # Train the generator: push the discriminator toward real classes for G(z)
        target_labels = np.eye(num_classes)[np.random.choice(num_classes - 1, size=batch_size)]
        generator_loss = sgan.train_on_batch(noise, target_labels)

        # Train the discriminator: labeled real images with their true labels (supervised),
        # generated images as the "fake" class. (A full SGAN would add the unsupervised
        # loss term for X_unlabeled; it is omitted here for brevity.)
        idx = np.random.randint(0, X_labeled.shape[0], batch_size)
        d_loss_real = discriminator.train_on_batch(X_labeled[idx], y_labeled[idx])
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_targets)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Log progress
        if epoch % 100 == 0:
            print(
                f"Epoch {epoch}/{epochs} | D Loss: {d_loss[0]:.4f}, D Accuracy: {d_loss[1]:.4f} | G Loss: {generator_loss:.4f}")

# Data loading example (MNIST)
from keras.datasets import mnist
from keras.utils import to_categorical

(X_train, y_train), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5  # normalize to [-1, 1]
X_train = np.expand_dims(X_train, axis=-1)              # reshape to (N, 28, 28, 1)
y_train = to_categorical(y_train, num_classes=num_classes)

# Split into labeled and unlabeled data
X_labeled = X_train[:1000]
y_labeled = y_train[:1000]
X_unlabeled = X_train[1000:]

# Initialize the models
generator = build_generator(latent_dim)
discriminator = build_discriminator(image_shape, num_classes)

# Train the SGAN
train_sgan(generator, discriminator, latent_dim, X_labeled, y_labeled, X_unlabeled)

Partial output

Epoch 0/10000 | D Loss: 2.4036, D Accuracy: 0.0859 | G Loss: 2.4005
...
Epoch 100/10000 | D Loss: 2.3941, D Accuracy: 0.1250 | G Loss: 2.3827

Summary

The core of the semi-supervised GAN is to extend the discriminator into a multi-class classifier and to fully exploit unlabeled data and generated samples through adversarial training, improving the classifier's performance. Compared with a standard GAN or purely supervised learning, an SGAN can achieve better classification results when labeled data is scarce.
