Advertisement

Pytorch 图像数据集增强(训练集扩展)图像提高泛化容错率

阅读量:

在处理有限的图像数据集时,过拟合和泛化能力差等问题是常见的挑战。为了缓解这些问题,可以利用PyTorch的transforms模块对图像进行增强处理以扩展训练集。具体方法包括使用Resize调整图像尺寸、RandomCrop随机裁剪以学习图像特性、RandomHorizontalFlip和RandomVerticalFlip随机翻转以增加数据多样性、RandomInvert颜色倒置器、RandomRotation旋转图片以及ColorJitter、RandomEqualize等随机调整HSV颜色空间的操作。这些方法有助于提升模型的泛化能力并改善训练效果。例如可以通过Compose将多个transform组合起来应用到同一张图片上从而生成多样化的训练样本。

当处理图像数据时,由于数据资源有限,可能会导致模型出现过拟合现象,并且其预测效果可能欠佳。

该研究基于Pytorch 技术对有限规模的图像数据集实施了图像增强处理过程,并旨在显著提升训练数据量。


环境

Pillow 8.4.0
torch 1.8.2+cu111

PyTorch 的 torchvision.transforms 内置了大量用于图像处理的函数,在这些中常用的包括

Resize

复制代码
 from torchvision import transforms as tfs

    
 import matplotlib.pyplot as plt
复制代码
 im_aug = tfs.Compose([

    
     tfs.Resize(1024),
    
     tfs.ToTensor(),
    
     ])
    
  
    
 # --- load original image
    
 im = Image.open('dataset/IMG_4329.jpg')
    
 # --- transform image
    
 x = im_aug(im)

RandomCrop随机裁剪

RandomCrop有助于对照片进行随机剪切处理。这有助于模型更好地学习图像的特征。此处注意transform的顺序

复制代码
 im_aug = tfs.Compose([

    
     tfs.Resize(1024),
    
     tfs.RandomCrop(512),
    
     tfs.ToTensor(),
    
     ])
    
  
    
 # --- load original image
    
 im = Image.open('dataset/IMG_4329.jpg')
    
 # --- transform image
    
 x = im_aug(im)

RandomHorizontalFlip(p) 随机水平翻转

RandomVerticalFlip(p) 随机垂直翻转

p为发生概率

复制代码
 im_aug = tfs.Compose([

    
     tfs.Resize(1024),
    
     tfs.RandomVerticalFlip(0.5),
    
     tfs.RandomHorizontalFlip(0.5),
    
     tfs.ToTensor(),
    
     ])
    
  
    
 # --- load original image
    
 im = Image.open('dataset/IMG_4329.jpg')
    
 # --- transform image
    
 x = im_aug(im)

RandomInvert(p) 颜色倒置器

复制代码
 im_aug = tfs.Compose([

    
     tfs.Resize(1024),
    
     tfs.RandomInvert(0.5),
    
     tfs.ToTensor(),
    
     ])
    
  
    
 # --- load original image
    
 im = Image.open('dataset/IMG_4329.jpg')
    
 # --- transform image
    
 x = im_aug(im)

对图像进行随机旋转处理后,在原有训练集的基础上能够有效地进行模型拟合

x,y = 旋转角度的范围

复制代码
 im_aug = tfs.Compose([

    
     tfs.Resize(1024),
    
     tfs.RandomRotation((0, 360)),
    
     tfs.ToTensor(),
    
     ])
    
  
    
 # --- load original image
    
 im = Image.open('dataset/IMG_4329.jpg')
    
 # --- transform image
    
 x = im_aug(im)

ColorJitter(brightness=, contrast=, hue=) 随机调整HSV

复制代码
 im_aug = tfs.Compose([

    
     tfs.Resize(1024),
    
     tfs.RandomApply(torch.nn.ModuleList([tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)]), p=1),
    
     tfs.ToTensor(),
    
     ])
    
  
    
 # --- load original image
    
 im = Image.open('dataset/IMG_4329.jpg')
    
 # --- transform image
    
 x = im_aug(im)

RandomEqualize(p) 随机均衡直方图

复制代码
 im_aug = tfs.Compose([

    
     tfs.Resize(1024),
    
     tfs.RandomEqualize(1),
    
     tfs.ToTensor(),
    
     ])
    
  
    
 # --- load original image
    
 im = Image.open('dataset/IMG_4329.jpg')
    
 # --- transform image
    
 x = im_aug(im)

原histogram

RandomEqualize后的histogram


综合使用

复制代码
 im_aug = tfs.Compose([

    
     tfs.Resize(1024),
    
     tfs.RandomCrop(512, padding=2),
    
     tfs.RandomHorizontalFlip(1),
    
     tfs.RandomVerticalFlip(1),
    
     tfs.RandomRotation((0, 360)),
    
     tfs.RandomInvert(1),
    
     tfs.RandomApply(torch.nn.ModuleList([tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)]), p=0.5),
    
     tfs.RandomEqualize(1),
    
     tfs.ToTensor(),
    
     # tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    
     ])
    
  
    
 # --- load original image
    
 im = Image.open('dataset/IMG_4329.jpg')
    
 # --- transform image
    
 x = im_aug(im)

全部评论 (0)

还没有任何评论哟~