pytorch数据增强与数据读取

数据增强与数据读取

数据增强

为什么要进行数据增强

深度学习模型的参数很多,模型复杂度很高,如果此时数据集数量不够导致数据集中数据的复杂度没有涵盖所有特征空间,那么模型就会学习到这些数据集的一些独有的特征,这会导致过拟合的问题,因此我们有必要对数据集进行增强操作以减小过拟合的风险。

利用pytorch进行数据增强

在pytorch的torchvision库中提供了很多的数据增强方法,主要可以分为四类:

  • 裁剪
  • 翻转和旋转
  • 图像变换
  • 对transform操作使得数据增强更加灵活
    具体的函数可以参考pytorch官网:torchvision
    若对于图片的增强方法不止一种,我们需要通过torchvision.transform.Compose将几种方法聚合到一起,作为一个属性去初始化我们的数据集。

数据增强并不是真的增多了样本的数量,因为在训练的过程中,数据增强函数会随机对图片进行处理,比如对图片进行裁剪,将图片翻转,这样在多训练几轮之后就相当于增加了原本不属于数据集的图片,也就是实现了数据增强。

除了pytorch自带的torchvision,还有几个别的库也可以用:
imgaug
albumentations

数据读取

利用pytorch进行数据读取分为两步,一是定义自己的dataset,而是利用dataset形成dataloader以便于我们后续的训练、验证与测试。

定义自己的dataset

The most important argument of DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

  • map-style datasets,
  • iterable-style datasets.

这是官网对于数据集类型的表述,dataset是DataLoader最重要的参数,可以分为两类:映射型与迭代型。笔者在视觉类任务中见到的大部分所用的类型是映射型。
一般自定义的数据类需要遵循下面的框架

class myDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):

    def __getitem__(self, index):

    def __len__(self):

自定义的dataset继承自torch.utils.data.dataset,对于映射型,我们需要重写__getitem__()__len__()函数,__getitem__()用于返回给定索引对应的图片与标注,__len__()返回数据集的长度。

  • 重写两个函数时注意代码缩进与对齐,否则可能会出现意想不到的错误。
  • 在进行dataset的构建之前检查一下图片数据集的长度与标注的长度是否相等,因为有可能出现数据缺失的情况,这样后面在进行数据集加载时可能会出现越界的错误。

在自定义的dataet里面还有一个__init__()函数,这用于初始化我们所需要的一些属性,主要是图片的路径信息,图片的标注信息以及图片的预处理方法。图片的路径信息以及标注信息通常都是以列表作为数据结构,然后直接通过index进行相应信息的读取。同时要注意,pytorch中读取图片必须通过PIL来进行而不能通过opencv,如果需要利用opencv中的一些方法对图片进行处理,则需要先对图片类型进行转换。

形成dataloader

完成dataset之后我们就需要构建我们的dataloader了,这是为了我们后面进行训练,将整体数据集划分为几个批次,这样既能保证训练速度,也能较好地收敛,形成dataloader主要通过这个函数:
torch.utils.data.DataLoader
主要有四个参数:

  • dataset:自定义的数据集
  • batch_size:每次训练时读取的样本
  • shuffle:是否将数据集打乱
  • num_works:并行数量
    batch_size为每次训练时读取的样本,通常我们将其定为2的幂指数,同时我们需要将训练集进行打乱,验证集和测试集一般不需要打乱。
train_loader = torch.utils.data.DataLoader(
        SVHNDataset(train_path, train_label,
                   transforms.Compose([
                       transforms.Resize((64, 128)),
                       transforms.ColorJitter(0.3, 0.3, 0.2),
                       transforms.RandomRotation(5),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])), 
    batch_size=10, # 每批样本个数
    shuffle=False, # 是否打乱顺序
    num_workers=10, # 读取的线程个数
)

像这样,我们就完成了dataloader的构建。

你可能感兴趣的:(pytorch,比赛,pytorch,深度学习,神经网络)