Advertisement

CIFAR-10图像识别TensorFlow 2 实现

阅读量:

1、下载数据集

CIFAR-10是一个用于识别普适物 体的小型数据集,它包含了10个类 别的RGB彩色图片

复制代码
 import matplotlib.pyplot as plt

    
 import tensorflow as tf
    
 from tensorflow import keras
    
 from tensorflow.keras import layers, models
    
 import numpy as np
    
  
    
 cifar10=tf.keras.datasets.cifar10
    
 (Xtrain, Ytrain), (Xtest, Ytest) = cifar10.load_data()

2、数据预处理

复制代码
 Xtrain_normalize=Xtrain.astype("float32")/255.0

    
 Xtest_normalize=Xtrain.astype("float32")/255.0
    
  
    
 Ytrain_ohe=keras.utils.to_categorical(Ytrain)
    
 Ytest_ohe=keras.utils.to_categorical(Ytest)

3、建立卷神经网络CNN模型

图像的特征提取:通过卷积层1,降采样层1,卷积层2以及降采样层2的处理,提取图像的特征

全连接神经网络:全连接层、输出层所组成的网络结构

复制代码
 model = tf.keras.models.Sequential()

    
  
    
 model.add(layers.Conv2D(filters=32, 
    
                     kernel_size=(3, 3), 
    
                     input_shape=(32, 32, 3),
    
                     activation='relu',
    
                     padding='same'))
    
  
    
  
    
 model.add(tf.keras.layers.Dropout(rate=0.3))
    
  
    
 model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))
    
  
    
 model.add(tf.keras.layers.Conv2D(filters=64, 
    
                               kernel_size=(3, 3), 
    
                               activation= 'relu',
    
                               padding='same'))
    
  
    
  
    
 model.add(tf.keras.layers.Dropout(rate=0.3))
    
  
    
 model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2)))
    
  
    
 model.add(tf.keras.layers.Flatten())
    
  
    
 model.add(tf.keras.layers.Dense(10,activation='softmax'))

4、模型摘要

复制代码
    model.summary()

5、设置模型训练超参数

复制代码
 train_epochs=5

    
 batch_size=100

6、设置模型训练模式

复制代码
 model.compile(optimizer='adam',

    
           loss='categorical_crossentropy',
    
           metrics=['accuracy'])

7、模型训练

复制代码
 model.fit(Xtrain_normalize,Ytrain_ohe,

    
                 validation_split = 0.2,
    
                 epochs = train_epochs,
    
                 batch_size=batch_size,
    
                 verbose = 2)

8、评估模型及预测

复制代码
 # 评估模型

    
 test_loss, test_acc = model.evaluate(Xtest_normalize, Ytest_ohe)
    
 print('Test accuracy:', test_acc)
    
  
    
 # 进行预测
    
 predictions = model.predict(Xtest_normalize)
    
  
    
 # 定义标签字典 每一个数字所代表的图像类别的名称
    
 label_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 
    
           5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}

9、可视化预测结果

复制代码
 # 定义显示图像数据及其对应标签的函数

    
 def plot_images_labels_prediction(images,      # 图像列表
    
                               labels,      # 标签列表
    
                               prediction,  # 预测值列表
    
                               index,       # 从第index个开始显示
    
                               num = 5 ):   # 缺省一次显示5幅
    
     fig = plt.gcf() # 获取当前图表,Get Current Figure
    
     fig.set_size_inches(12, 6)  # 1英寸等于 2.54 cm
    
     if num > 10: 
    
     num = 10            # 最多显示10个子图
    
     for i in range(0, num):
    
     ax = plt.subplot(2, 5, i + 1) # 获取当前要处理的子图
    
     
    
     ax.imshow(images[index],  # 显示第index个图像
    
               cmap = 'binary')
    
         
    
     title = str(i) + ',' + label_dict[np.argmax(labels[index])] # 构建该图上要显示的title信息
    
     if len(prediction) > 0:
    
         title += ' => ' + label_dict[np.argmax(predictions[index])]
    
         
    
     ax.set_title(title,fontsize = 10)   # 显示图上的title信息        
    
     index += 1 
    
     plt.show()
    
  
    
 plot_images_labels_prediction(Xtest_normalize,
    
                           Ytest_ohe,
    
                           predictions,0,10)

全部评论 (0)

还没有任何评论哟~