Advertisement

pytorch深度学习:CNN卷积神经网络(三)

阅读量:

最近两期的博客内容中,我深入探讨了卷积神经网络的基础知识,并详细研究了其在数字识别任务中的应用。具体的学习内容可参考前几篇文章中的详细说明:

mnist仅包含单个通道的数据集,在执行卷积操作时使用的滤波器具有二维平面形状;然而,在实际应用中我们通常处理的是彩色图像这类多通道场景;因此具有多个通道值;此时滤波器呈现出类似立方体的结构特征;通过查看下面的图形能够更好地理解这一概念。

在这里插入图片描述

1.问题的提出

这节课我们采用另一个著名的数据集CIFAR-10来进行彩色图像分类任务。CIFAR-10常被用作标准的彩色图像数据集,在该集合中共包含十类物体:airplane、automobile、bird、cat、deer、dog、frog、horse、ship和trunk。每张图片均为3×32×32像素大小,在通道上采用RGB编码(红绿蓝),其分辨率保持在标准水平即为32×32像素。

我们挑选一些图片进行展示。今天的任务就是利用CNN实现彩色图像的10分类目标。

在这里插入图片描述

2.数据处理

数据集的准备仍然依赖于 torchvision 库以获取数据集。初次获取可能会稍显迟缓;若已准备好,则可通过 root 参数进行指定。

仍然继续将数据划分为训练集与测试集,并对这些手据实施归一化处理的同时,则采用了批量数据处理的方式来进行模型训练;见程序:

复制代码
    # 定义对数据的预处理
    transform = transforms.Compose([
       transforms.ToTensor(),  # 转为Tensor
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 归一化
    ])
    
    # 训练集
    train_data = tv.datasets.CIFAR10(
        root='C:\ Users\ hurunqiao\ Downloads',
        train=True,
        download=True,
        transform=transform
    )
    	
    train_loader = torch.utils.data.DataLoader(
        dataset=train_data,
        batch_size=Batch_size,
        shuffle=True,
        num_workers=2
    )
    
    # 测试集
    test_data = tv.datasets.CIFAR10(
        'C:\ Users\ hurunqiao\ Downloads',
        train=False,
        download=True,
        transform=transform
    )
    
    test_loader = torch.utils.data.DataLoader(
        dataset=test_data,
        batch_size=Batch_size,
        shuffle=False,
        num_workers=2
    )
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

还要定义些超参数:Epoch(训练轮数)、Batch_size(批数据处理个数)、LR(学习率)

复制代码
    # 一些超参数的定义
    Epoch = 3
    Batch_size = 4
    LR = 0.01
    
    
      
      
      
      
    
    代码解读

3.CNN网络的搭建

我们这次搭建的CNN网络包含两卷积模块、两池化模块和两全连接模块 。它们的具体参数配置如下:

  • in_channels: 表示输入图像的颜色通道数量(彩色图像为3)。
  • out_channels: 表示输出的颜色通道数量(即卷积核的数量)。
  • kernel_size: 表示卷积核的尺寸(如一个5×5的方阵)。
  • stride: 即移动步长的距离。
  • padding: 在stride=1的情况下, 如果希望经过卷积后图片的高度和宽度保持不变, 则需要对padding进行设置, 其值等于(kernelsize−1)/2.

- 池化层:
使用的是Max Pooling,我们只需要设置池化层Filter的大小。

- 全连接层:

  • 输入层单元数: 处理后数据展平为一维向量时的单元数量。
    • 隐藏层数: 隐藏层的数量:可自由决定。
    • 输出单元层数: 输出层单元数量:分为十类。

代码如下:

复制代码
    class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=2
        )
        self.conv2 = nn.Conv2d(
            in_channels=16,
            out_channels=32,
            kernel_size=5,
            stride=1,
            padding=2
        )
        self.fc1 = nn.Linear(32*8*8,64)
        self.fc2 = nn.Linear(64,10)
    
    def forward(self,x):
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(x.size(0),-1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
    
        return x
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

4.数据维度的变化

以CIFAR-10数据集中的样本作为输入数据的基础之上,在每个样本中其尺寸配置为(3, 32, 32)。其中第一个数字表示通道数量(channel),后两个数字分别对应图片的高度和宽度。

在经过第一层卷积层后,在经过第一层卷积层后

经过池化层处理后,数据维度转换为(16,16,16)。高度保持不变;长宽均被压缩一半长度(这是因为池化核的尺寸设定为2\times 2)。

再经过一层卷积层,数据维度变为(32,16,16),32是由这层的卷积核个数决定。

再经过池化层,数据格式变为(32,8,8)

将其展成一维即为全连接层的输入数据维度。
总结整个网络的数据维度变化过程。

复制代码
    (3,32,32)-->(16,32,32)-->(16,16,16)-->(32,16,16)-->(32,8,8)-->32*8*8
    
    
      
    
    代码解读

5.结果展示

在训练过程中存在单一化现象,在此不做详细展示。我们可以观察网络训练的结果,在经过10^3组测试数据后发现实验结果显示准确率达到65%,明显优于随机分类方法(仅达10%),这表明CNN网络训练模型具有良好的效果。

从这4个样本中进行结果展示,在表格中可以看到其中一行展示了模型的实际分类情况,在另一行则显示了预测的结果

复制代码
    实际分类:    cat  ship  ship plane
    预测结果:    cat   car  ship plane
    
    
      
      
    
    代码解读

完整代码放在了下边:

复制代码
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision as tv
    import torchvision.transforms as transforms
    from torchvision.transforms import ToPILImage
    
    class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=2
        )
        self.conv2 = nn.Conv2d(
            in_channels=16,
            out_channels=32,
            kernel_size=5,
            stride=1,
            padding=2
        )
        self.fc1 = nn.Linear(32*8*8,64)
        self.fc2 = nn.Linear(64,10)
    
    def forward(self,x):
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(x.size(0),-1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
    
        return x
    
    
    if __name__ == '__main__':
    
    show = ToPILImage()  # 可以将Tensor转成Image,方便可视化
    
    # 一些超参数的定义
    Epoch = 3
    Batch_size = 4
    LR = 0.01
    
    # 定义对数据的预处理
    transform = transforms.Compose([
        transforms.ToTensor(),  # 转为Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 归一化
    ])
    
    # 训练集
    train_data = tv.datasets.CIFAR10(
        root='C:\ Users\ hurunqiao\ Downloads',
        train=True,
        download=True,
        transform=transform
    )
    
    train_loader = torch.utils.data.DataLoader(
        dataset=train_data,
        batch_size=Batch_size,
        shuffle=True,
        num_workers=2
    )
    
    test_data = tv.datasets.CIFAR10(
        'C:\ Users\ hurunqiao\ Downloads',
        train=False,
        download=True,
        transform=transform
    )
    
    test_loader = torch.utils.data.DataLoader(
        dataset=test_data,
        batch_size=Batch_size,
        shuffle=False,
        num_workers=2
    )
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    (data, label) = train_data[100]
    print(classes[label])
    
    # 显示一下示例图片
    # (data+1)是为了还原被归一化的数据
    show((data + 1) / 2).resize((100, 100)).show()
    
    net = CNN()
    optimizer = torch.optim.SGD(net.parameters(), lr=LR)
    loss_func = nn.CrossEntropyLoss()
    
    for epoch in range(Epoch):
        running_loss = 0.0
        for step, (b_x, b_y) in enumerate(train_loader):
    
            prediction = net(b_x)
            loss = loss_func(prediction, b_y)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            running_loss += loss.item()
            if step % 2000 == 1999:  # 每2000个batch打印一次训练状态
                print('[%d, %5d] loss: %.3f' \
                      % (epoch + 1, step + 1, running_loss / 2000))
                running_loss = 0.0
    
    ## 检验网络效果
    dataiter = iter(test_loader)
    images, labels = dataiter.next()  # 一个batch返回4张图片
    print('实际分类: ', ' '.join( \
        '%05s' % classes[labels[j]] for j in range(4)))
    show(tv.utils.make_grid(images / 2 - 0.5)).resize((400, 100)).show()
    
    # 计算网络预测的label
    outputs = net(images)
    _, predicted = torch.max(outputs.data, 1)
    print('预测结果: ', ' '.join('%5s' \
                             % classes[predicted[j]] for j in range(4)))
    # 测试准确率
    correct = 0
    total = 0
    for data in test_loader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
    
    print('1000张测试集中的准确率为: %d  %%' % (100 * correct / total))
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

全部评论 (0)

还没有任何评论哟~