Task2:Pytorch数据预处理--数据读取与数据扩充

目录

    • 一、前言
    • 二、数据扩充
      • 2.1opencv对图像数据扩充常用Method
      • 2.2 Torchvision对图像数据进行扩充常用Method
      • 2.3 GAN生成扩充数据
    • 三、Pytorch实现数据读取与扩充

一、前言

在深度学习模型的训练过程中,为了保证模型的泛化能力,通常需要做数据扩充(Data augment),即在原有数据基础上,通过数据变换来生成新的数据从而扩大训练集,提高模型的泛化能力。

二、数据扩充

使用Python对图像数据处理来说,常用的Package有:

  • opencv-python
  • Pillow
  • imgaug
  • albumentations
  • torchvision(Pytorch官方提供的数据扩充库,可以和Torch无缝对接)

数据扩充一般是从颜色空间、尺度空间和样本空间根据不同任务进行相应的扩充,对于图像分类,数据扩充不会改变Label值;对于物体检测,数据扩充不会改变物体坐标位置;对于图像分割,数据扩充会改变像素标签

2.1opencv对图像数据扩充常用Method

1、图像处理
不改变图像整体特征的情况下,生成样本空间没有的新图像。
方法:

Method Augment effect
flip 将图像水平或垂直翻转
rotate 图像旋转
resize 图像放大或缩小
crop 图像剪裁,即从原图随机采样部分数据
warpAffine 仿射变换
Gsussian noise 添加随机噪声
Value/contrast 图像亮度和对比度变换

对于旋转、缩放等步骤来说,图像扩充后图像尺寸或维度发生 变化,需要使用插值(Interpolation)来填补为原尺寸,根据情况,可有以下几种插值方式:

  • 常亮填充:单色背景下,可使用0填充、1填充等方式得到一块单色区域
  • 边缘扩展:用图像边缘值填充边界,适用于平移图像填充
  • 反射:沿图像边缘值反射填充,一般适用连续或自然背景

Task2:Pytorch数据预处理--数据读取与数据扩充_第1张图片

2.2 Torchvision对图像数据进行扩充常用Method

Method Augment affect
transforms.CenterCrop 中心区域裁剪
transforms.RandomCrop 随机区域裁剪
transforms.RandomRotation 随机旋转
transforms.RandomHorizontalFlip 水平翻转
transforms.RandomVerticalFlip 垂直翻转
transforms.Resize 缩放
transforms.Normalize 标准化
transforms.ToTensor 转为tensor,归一化至[0-1] 区间
transforms.ToPILImage 转换为PILImage
transforms.Grayscale 灰度化
transforms.ColorJitter 亮度、对比度调整
transforms.LinearTransformation() 线性变换
transforms.RandomChoice(transforms) 选一个transforms进行操作

等等,效果图如下:
Task2:Pytorch数据预处理--数据读取与数据扩充_第2张图片

2.3 GAN生成扩充数据

利用对抗式神经网络在原有数据基础上通过训练可以生成新的数据。这就需要一个专门训练生成数据的网络。

三、Pytorch实现数据读取与扩充

Pytorch配合torchvision可同时完成数据读取与扩充:

# Python库加载
import os, sys, glob, shutil, json
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms

class SVHNDataset(Dataset):
	# dataset初始化
	def __init__(self, img_path, img_label, transform=None):
		self.img_path = img_path
		self.img_label = img_label
		if transform is not None:
			self.transform = transform
		else:
			self.transform = None
	
	# 读数据 
	def __getitem__(self, index):
		img = Image.open(self.img_path[index]).convert('RGB')
		if self.transform is not None:
			img = self.transform(img)
		
		# 原始SVHN中类别10为数字0
		lbl = np.array(self.img_label[index], dtype=np.int)
		lbl = list(lbl) + (5 - len(lbl)) * [10]
		
		return img, torch.from_numpy(np.array(lbl[:5]))
	
	# 数据集长度
	def __len__(self):
		return len(self.img_path)

# 从文件夹中写入数据:假设数据集保存在 input文件夹下
train_path = glob.glob('../input/train/*.png')
train_path.sort()
train_json = json.load(open('../input/train.json'))
train_label = [train_json[x]['label'] for x in train_json]

data = SVHNDataset(train_path, train_label,
			transforms.Compose([
				# 缩放到固定尺寸
	 			transforms.Resize((64, 128)),
			
			  	# 随机颜色变换
	 		  	transforms.ColorJitter(0.2, 0.2, 0.2),
	
			  	# 加⼊随机旋转
	 	     	transforms.RandomRotation(5),
	 	   
	 	     	# 将图⽚转换为pytorch 的tesntor
		     	# transforms.ToTensor(),
		   
		     	# 对图像像素进⾏归⼀化
		     	# transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
	    	 ]))

通过以上过程,就可以将input中的train、train.json图片和对应Label同时加载进dataset、

为了方便在后续的操作中批量训练train datasets了,在PyTorch里,我们还要在datasets基础上构建DataLoder方便迭代读取,它和dataset的区别是:

  • Datasets只是完成对数据集的封装,可以提供Index对样本读取
  • DataLoder是对Datasets的封装,可以提供批量读取
# 这里只需要对上面data = SVHNDataset(train_path, train_label,
#							transforms.Compose(...));
# 封装一下就OK了,修改代码如下:

# ....读取过程一样....
train_loader = torch.utils.data.DataLoader(
					SVHNDataset(train_path, train_label,
								transforms.Compose([
									transforms.Resize((64, 128)),
									transforms.ColorJitter(0.3, 0.3, 0.2),
									transforms.RandomRotation(5),
									transforms.ToTensor(),
									transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
								])),
					batch_size=10, # 每批样本个数
					shuffle=False, # 是否打乱顺序
 					num_workers=10, # 读取的线程个数
)
for data in train_loader:
	break

通过DataLoder对Data进行封装后,每次调用Dataset读取单个样本进行拼接。此时data的格式为:

torch.Size([10, 3, 64, 128]), torch.Size([10, 6])  
# [10, 3, 64, 128]表示图像样本数据,四维向量分别是:batchsize * chanel * height * width
# torch.Size([10, 6]表示图像样本的Label

通过以上步骤就完成了Pytorch对datasets的读取,同时完成了对数据集的扩充。

你可能感兴趣的:(Pytorch)