Pytorch系列之——数据读取机制&Transforms

DataLoader与DataSet

  • 人民币二分类
  • DataLoader与DataSet

Pytorch系列之——数据读取机制&Transforms_第1张图片根据pytorch张量部分中提到的机器学习模型训练的五大步骤:数据、模型、损失函数、优化器和迭代训练。我们这部分主要介绍模型训练的第一个步骤:数据模块。其中,数据模块通常进一步划分为四个子模块,分别为:数据收集、数据划分、数据读取和数据预处理。
Pytorch系列之——数据读取机制&Transforms_第2张图片这次主要介绍的部分是数据模块中的数据读取子模块DataLoader,其中DataLoader还可以划分为Sampler和DataSet,Sampler是用来生成索引,而DataSet是根据索引获取Img和Label。

DataLoader

torch.utils.data.DataLoader:
功能:构建可迭代的数据装载器
Pytorch系列之——数据读取机制&Transforms_第3张图片

  • dataset:Dataset类,决定数据从哪儿读取及如何读取
  • batchsize:批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
    接下来看看在模型训练中必不可少的epoch、Iteration和Batchsize三者都代表什么:
  • Epoch:所有训练样本都已经输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称之为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration

比如,我们有样本总数为80,其中设置Batchsize为8,那么有:1 Epoch=10 Iteration;如果样本总数为87,这时其无法被Batchsize整除,那么有如下两种方式解决:
(1)1 Epoch=10 Iteration 需要设置drop_last=True(余下的7个样本将被忽略)
(2)1 Epoch=11 Iteration 需要设置drop_last=False

DataSet

torch.utils.data.DataSet:
功能:DataSet抽象类,所有自定义的DataSet需要继承它,并且复写__getitem__(),其中__getitem__()的作用是接收一个索引,返回一个样本。
Pytorch系列之——数据读取机制&Transforms_第4张图片在通过人民币二分类问题进行学习之前,我们需要考虑三个问题:

  • 读哪些数据?
  • 从哪读数据?
  • 怎么读数据?

首先我们是先将人民币图像压缩包文件进行解压,之后在通过python的os模块对图像进行读取,这个过程中还需对数据集进行划分,将数据集进一步划分为训练集、验证集和测试集(train、valid、test),这也就解决了数据读取的第一个问题,即读哪些数据。之后我们再根据划分好的训练集、验证集和测试集依次对人民币图像数据进行读取,这部分是解决了数据读取的第二个问题,即从哪读数据。
在这里插入图片描述最主要的部分是构建MyDataset实例,这里的RMBDataset是需要我们自己构建的,其中传入的两个参数分别为读取数据的文件路径,以及用于数据预处理的transform,这个transform之后会详细介绍。
Pytorch系列之——数据读取机制&Transforms_第5张图片按住Ctrl键再单击RMBDataset可以使程序跳转到具体的实现类:
Pytorch系列之——数据读取机制&Transforms_第6张图片可以看到,在RMBDataset实现类里有用于初始化的__init__()方法,以及__getitem__()函数会根据索引值返回一个图像和对应的标签值,len()函数可以用来查看数据的长度,即样本的数量。这里我们是构建两个Dataset,一个是训练数据的Dataset,另一个是验证数据的Dataset,有了Dataset我们就可以构建数据装载器DataLoader了。在DataLoader中我们需要传入之前创建好的Dataset,同时还有batch_size,也就是批处理的样本量,还有是否将数据进行打乱,在训练集中需要将数据进行shuffle而在验证集中则不用。之后我们采用经典的深度神经网络模型LeNet来解决这个人民币二分类问题,接下来我们通过代码的调试来观察pytorch是如何读取数据的:
首先在循环读取数据起始位置设置断点并进行Debug:
Pytorch系列之——数据读取机制&Transforms_第7张图片之后选择stepinto就能跳转到这个函数中,
Pytorch系列之——数据读取机制&Transforms_第8张图片可以看到首先函数会获取indices也就是索引,以及batch也就是数据,这个函数也正好告诉了我们数据读取问题中读取哪些数据的问题。之后我们将光标放在获取indices这一行代码上并点击Run to Cursor,将代码运行到光标这一行。接着我们再stepinto光标所在行的代码,看看pytorch是如何获取indices的,进入到了Sampler:
Pytorch系列之——数据读取机制&Transforms_第9张图片Sampler就是一个采样器,它用来告诉我们每一个batchsize该读取哪些数据。尤其是在__iter__()函数中,这个函数是用产生迭代索引值的,也就是指定每个step需要读取哪些数据。
之后我们再stepover,stepout跳出这个函数,并将光标放在indices下面那一行,即可看到已获取的索引值:

在这里插入图片描述

可以看到,已经获得了indices索引值,由于我们的batch_size设置的是16,所以获得的索引数量也就为16。在得到索引之后,就可以根据索引获取相应的数据了。
再在data这一行设置断点,并stepinto查看究竟在这个函数中究竟发生了什么:
Pytorch系列之——数据读取机制&Transforms_第10张图片可以看到我们stepinto到了__getitem__()这个函数,并根据索引值获取相应的img和label,而获取数据的关键点就在这里。
stepout跳出这个函数之后,我们可以发现函数collate_fn(),这个函数是对获取到的数据进行整理,整理成一个batch:
Pytorch系列之——数据读取机制&Transforms_第11张图片
经过代码调试之后,我们可以更好地解决之前数据读取的三个问题:
Pytorch系列之——数据读取机制&Transforms_第12张图片读哪些数据?是根据Sampler输出的Index进行读取;从哪读数据?是根据Dataset中的data_dir读取,简单地说就是从相关的硬盘文件读取;怎么读数据?这个就是根据比较重要的Dataset中的getitem函数来读取了。将这一数据读取过程用流程图表示如下:
Pytorch系列之——数据读取机制&Transforms_第13张图片首先我们根据for循环进入DataLoader,之后判断是采用单进程还是多进程进入DataLoaderIter,进入到DataLoaderIter之后我们会使用Sampler去获取index(indices)索引,拿到索引之后给到DatasetFetcher,在DatasetFetcher中会调用Dataset,Dataset根据我们给定的索引在getitem当中从硬盘当中去读取我们实际的图像和标签,在读取了一个batchsize大小的数据之后,通过一个collate_fn将获取的这些数据进行整理,整理成一个batchdata的形式,然后即可输入到我们的模型当中去训练了。

数据预处理模块——transforms

  • transforms运行机制
  • 数据标准化——transforms.normalize

在这部分,我们需要安装计算机视觉工具包——torchvision,其中常用的几个方法包括:

  • torchvision.transforms:常用的图像预处理方法
  • torchvision.datasets:常用数据集的datasets实现,包括MNIST、CIFAR-10、ImageNet等
  • torchvision.model:常用的模型预训练,AlexNet,VGG,ResNet,GoogLeNet等

torchvision.transforms:常用的图像预处理方法,主要包括数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、仿射变换、亮度、饱和度及对比度变换。
Pytorch系列之——数据读取机制&Transforms_第14张图片这里我们依然采用人民币二分类来对transforms进行说明:
Pytorch系列之——数据读取机制&Transforms_第15张图片这里的Compose是将我们的一系列transforms方法进行有序的组合,在具体实现的时候会依次地将这些方法对图像进行操作。在train_transform那里,可以看到首先先对图像进行resize把图像缩放到32x32大小的尺寸,接着对图像进行一个随机裁剪,再使用ToTensor将图像转换成张量形式,同时会进行一个归一化操作,把像素值的区间从0-255归一化到0-1,最后一步是标准化将均值变为0,标准差变为1。
现在依然通过代码调试的方式来查看其运行机制,先通过设置断点并stepinto到如下这个函数:
Pytorch系列之——数据读取机制&Transforms_第16张图片
之后我们在self.transforms处进行Run on Cursor并stepinto,即可进入到transforms.py文件中的__call__()函数:
Pytorch系列之——数据读取机制&Transforms_第17张图片可以看到__call__()函数里是一个for循环,也就是依次地执行对图像的一系列变换操作,这个循环结束后最终返回执行完变换的图像。从这里可以看到,transform是在__getitem__()函数中被调用,再通过__getitem__()返回一个样本,之后不断根据索引总共获取batchsize大小的数据。这就是pytorch数据读取和transform的运行机制。
Pytorch系列之——数据读取机制&Transforms_第18张图片这次把transform加入到数据读取流程图中。

transforms.Normalize:
功能:逐channel的对图像进行标准化
output = (input - mean) / std
Pytorch系列之——数据读取机制&Transforms_第19张图片

  • mean:各通道的均值
  • std:各通道的标准差
  • inplace:是否原地操作
    对数据进行标准化之后可以加快模型的收敛。

transforms图像增强(一)

  • 数据增强
  • transforms——裁剪
  • transforms——翻转和旋转

数据增强(Data Augmentation):数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。
Pytorch系列之——数据读取机制&Transforms_第20张图片
裁剪——transforms.CenterCrop:
功能:从图像中心裁剪图片

  • size:所需裁剪图片尺寸

随机裁剪——transforms.RandomCrop:
功能:从图片中随机裁剪出尺寸为size的图片
Pytorch系列之——数据读取机制&Transforms_第21张图片

  • size:所需裁剪图片尺寸

  • padding:设置填充大小
    当padding=a时,上下左右均填充a个像素
    当padding=(a,b)时,上下填充b个像素,左右填充a个像素
    当padding=(a,b,c,d)时,左,上,右,下分别填充a,b,c,d

  • pad_if_need:若图像小于设定size,则填充

  • padding_mode:填充模式,有4种模式
    1.constant:像素值由fill设定
    2.edge:像素值由图像边缘像素设定
    3.reflect:镜像填充,最后一个像素不镜像
    4.symmetric:镜像填充,最后一个像素镜像

  • fill:constant时,设置填充的像素值
    Pytorch系列之——数据读取机制&Transforms_第22张图片
    transforms——Crop
    RandomResizedCrop:
    功能:随机大小、长宽比裁剪图片
    Pytorch系列之——数据读取机制&Transforms_第23张图片

  • size:所需裁剪图片尺寸

  • scale:随机裁剪面积比例,默认(0.08,1)

  • ratio:随机长宽比,默认(3/4,4/3)

  • interpolation:插值方法
    PIL.Image.NEAREST
    PIL.Image.BILINEAR
    PIL.Image.BICUBIC
    FiveCrop、TenCrop:
    功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或垂直镜像获得10张图片。
    Pytorch系列之——数据读取机制&Transforms_第24张图片

  • size:所需裁剪图片尺寸

  • vertical_flip:是否垂直翻转

transforms——翻转、旋转

1.RandomHorizontalFlip
2.RandomVerticalFlip
功能:依概率水平(左右)或垂直(上下)翻转图片

  • p:翻转概率
    在这里插入图片描述
    3.RandomRotation
    功能:随机旋转图片
    Pytorch系列之——数据读取机制&Transforms_第25张图片

  • degrees:旋转角度
    当degrees=a时,在(-a,a)之间选择旋转角度
    当degrees=(a,b)时,在(a,b)之间选择旋转角度

  • resample:重采样方法

  • expand:是否扩大图片,以保持原图信息

  • center:旋转点位置,默认中心旋转

transforms图像增强(二)

1.Pad
功能:对图片边缘进行填充
在这里插入图片描述

  • padding:设置填充大小
    当padding=a时,上下左右均填充a个像素
    当padding=(a,b)时,上下填充b个像素,左右填充a个像素
    当padding=(a,b,c,d)时,左,上,右,下分别填充a,b,c,d

  • padding_mode:填充模式,有4种模式,constant,edge,reflect和symmetric

  • fill:constant时,设置填充的像素值(R,G,B) 或(Gray)

2.ColorJitter
功能:调整亮度、对比度、饱和度和色相
Pytorch系列之——数据读取机制&Transforms_第26张图片

  • brightness:亮度调整因子
    当brightness=a时,从[max(0,1-a),1+a]中随机选择
    当brightness=(a,b)时,从[a,b]中选择

  • contrast:对比度参数,同brightness

  • saturation:饱和度参数,同brightness

  • hue:色相参数
    当hue=a时,从[-a,a]中选择参数,注:0<=a<=0.5
    当hue=(a,b)时,从[a,b]中选择参数,注:-0.5<=a<=b<=0.5

3.Grayscale
4.RandomGrayscale

功能:依概率将图片转换为灰度图
Pytorch系列之——数据读取机制&Transforms_第27张图片

  • num_output_channels:输出通道数,只能设1或3
  • p:概率值,图像被转换为灰度图的概率

5.RandomAffine
功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
Pytorch系列之——数据读取机制&Transforms_第28张图片

  • degrees:旋转角度设置
  • translate:平移区间设置,如(a,b),a设置宽(width),b设置高(height),图像在宽维度平移的区间为-img_width * a< dx
  • scale:缩放比例(以面积为单位)
  • fill_color:填充颜色设置
  • shear:错切角度设置有水平错切和垂直错切
  • resample:重采样方式,有NEAREST,BILINEAR,BICUBIC

6.RandomErasing
功能:对图像进行随机遮挡(接收张量,是对Tensor进行操作的)
Pytorch系列之——数据读取机制&Transforms_第29张图片

  • p:概率值,执行该操作的概率
  • scale:遮挡区域的面积
  • ratio:遮挡区域的长宽比
  • value:设置遮挡区域的像素值,(R,G,B)或(Gray)

7.transforms.Lambda
功能:用户自定义lambda方法
在这里插入图片描述

  • lambda:lambda匿名函数
    lambda [arg1,[,arg2,…,argn]]:expression

transforms的操作

1.transforms.RandomChoice
功能:从一系列transforms方法中随机挑选一个
在这里插入图片描述
2.transforms.RandomApply
功能:依据概率执行一组transforms操作
在这里插入图片描述
3.transforms.RandomOrder
功能:对一组transforms操作打乱顺序
在这里插入图片描述

自定义transforms

自定义transforms要素:

  • 仅接收一个参数,返回一个参数
  • 注意上下游的输出与输入
    Pytorch系列之——数据读取机制&Transforms_第30张图片通过类实现多参数传入:
    Pytorch系列之——数据读取机制&Transforms_第31张图片
    椒盐噪声:
    椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点,白点称为盐噪声,黑色称为椒噪声。
    信噪比(Signal-Noise Rate,SNR)是衡量噪声的比例,图像中为图像像素的占比。

transforms methods总结:
Pytorch系列之——数据读取机制&Transforms_第32张图片

数据增强实战应用
原则:让训练集与测试集更接近

  • 空间位置:平移
  • 色彩:灰度图,色彩抖动
  • 形状:仿射变换
  • 上下文场景:遮挡、填充
  • … …

参考资料

深度之眼训练营——Pytorch课程

你可能感兴趣的:(PyTorch,深度学习)