Advertisement

365天深度学习训练营-第T10周:数据增强

阅读量:

我的环境:

  • 语言环境:Python3.11.2
  • 编译器:PyCharm Community Edition 2022.3
  • 深度学习环境:TensorFlow2

一、设置数据

1.1 获取数据集

先初步导入、设置数据

复制代码
 import tensorflow as tf

    
 import matplotlib.pyplot as plt
    
 from tensorflow.keras import layers,models
    
  
    
 data = 'F:/365-7-data'
    
  
    
 img_height = 224
    
 img_width = 224
    
 batch_size = 32
    
  
    
 train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    
     data,
    
     validation_split=0.3,
    
     subset='training',
    
     seed=123,
    
     image_size=(img_height,img_width),
    
     batch_size=batch_size
    
 )
    
 val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    
     data,
    
     validation_split=0.3,
    
     subset='validation',
    
     seed=123,
    
     image_size=(img_height,img_width),
    
     batch_size=batch_size
    
 )
    
    
    
    
    AI写代码

本次导入路径没有使用pathlib库而是直接导入。

与直接使用文件路径相比,使用pathlib库操作文件路径有以下几个优点:

1. 跨平台兼容性更好:pathlib.Path 可以自动适应不同操作系统的文件路径分隔符,无需手动处理。

2. 更加安全:使用 pathlib.Path 可以避免一些常见的路径操作错误,例如路径拼接时忘记添加分隔符、路径中包含特殊字符等。

3. 更加可读性高:使用 pathlib.Path 可以更加清晰地表达路径的含义,例如 pathlib.Path('/home/user/data') 比字符串 '/home/user/data' 更加易读。

4. 更加灵活:pathlib.Path 提供了丰富的方法和属性,可以方便地进行路径操作,例如获取文件名、扩展名、父目录等。

1.2 获取测试集

**** 本次的数据中没有测试集,从验证集中抽取一部分作为测试集。

使用tf.data.experimental.cardinality获取验证集数据的数量。它返回一个 tf.data.experimental.Cardinality 对象,该对象包含了数据集的元素数量信息。

使用take方法获取数据,使用skip方法将测试集数据移出验证集。

复制代码
 val_batches = tf.data.experimental.cardinality(val_ds)

    
 #取整
    
 test_ds = val_ds.take(val_batches//5)
    
 val_ds = val_ds.skip(val_batches//5)
    
  
    
 print('%d'%tf.data.experimental.cardinality(val_ds))
    
 print('%d'%tf.data.experimental.cardinality(test_ds))
    
    
    
    
    AI写代码

1.3 继续配置

复制代码
 AUTOTUNE = tf.data.AUTOTUNE

    
 #归一化
    
 def preprocessing_image(image,label):
    
     return (image/255.0,label)
    
  
    
 train_ds = train_ds.map(preprocessing_image,num_parallel_calls=AUTOTUNE)
    
 val_ds = val_ds.map(preprocessing_image,num_parallel_calls=AUTOTUNE)
    
 test_ds = test_ds.map(preprocessing_image,num_parallel_calls=AUTOTUNE)
    
  
    
 train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
    
 val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    
    
    
    
    AI写代码

这里的map方法是TensorFlow中的,而非python自带的。tf.data.Dataset对象的map()方法可以用于对数据集中的每个元素应用一个函数,返回一个新的数据集。

二、数据增强

使用tf.keras.layers.experimental.preprocessing.RandomFlip与tf.keras.layers.experimental.preprocessing.RandomRotation进行数据增强。

前者将图像在水平和垂直方向上随机反转,或者这是随机反转图像。

复制代码
 data_au = tf.keras.Sequential([

    
     tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal_and_vertical'),
    
     tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)
    
 ])
    
    
    
    
    AI写代码

前者表示水平与垂直随机反转,后者为按照0.2的弧度随机进行反转。

tf.keras.Sequentialtf.keras 模块中的一个类,用于创建顺序模型。它是 tf.keras.models.Sequential 的子类。

复制代码
 for images,label in train_ds.take(1):

    
     for i in range(9):
    
     image = tf.expand_dims(images[i],0)
    
     aug = data_au(image)
    
     ax = plt.subplot(3,3,i+1)
    
     plt.imshow(aug[0])
    
     plt.axis('off')
    
 plt.show()
    
    
    
    
    AI写代码

增强数据可以放在model中,这样在模型进行训练(fit)时,GPU便会帮助加速增强数据。

复制代码
 model = tf.keras.Sequential([

    
     data_au,
    
     layers.Conv2D(64,3,activation='relu'),
    
     layers.Dense(64)
    
 ])
    
    
    
    
    AI写代码

因为data_au是通过tf.keras.Sequential定义的,所有也相当于一个模型,这里直接将模型作为网络层添加到新模型中。

也可以在数据集中使用map进行增强。

复制代码
    val_ds = val_ds.map(lambda x,y: (data_au(x,training=False),y,num_parallel_calls=AUTOTUNE)
    
    AI写代码

num_parallel_calls 是一个用于控制数据预处理并行度的参数。在数据预处理过程中,通常需要进行一些图像增强、数据标准化、数据裁剪等操作,这些操作可能会比较耗时。为了加快数据预处理的速度,可以使用 TensorFlow 的 tf.data.Dataset.map() 函数来对数据进行预处理,并通过 num_parallel_calls 参数来指定并行处理的线程数。

三、训练模型

复制代码
 model = tf.keras.Sequential([

    
     layers.Conv2D(16,3,padding='same',activation='relu'),
    
     layers.MaxPooling2D(),
    
     layers.Conv2D(16,3,padding='same',activation='relu'),
    
     layers.Flatten(),
    
     layers.Dense(32,activation='relu'),
    
     layers.Dense(len(class_name))
    
 ])
    
  
    
 model.compile(
    
     optimizer='adam',
    
     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    
     metrics=['accuracy']
    
 )
    
  
    
 epochs = 10
    
 history = model.fit(
    
     train_ds,
    
     validation_data=val_ds,
    
     epochs=epochs
    
 )
    
    
    
    
    AI写代码

四、总结

本次学习了数据集的增强方式,使用tf.keras.layers.experimental.preprocessing.RandomFlip与tf.keras.layers.experimental.preprocessing.RandomRotation进行数据增强。并在model或使用数据集的map方法来增强数据。有助于提高模型的准确率。

全部评论 (0)

还没有任何评论哟~