前两个实战然我们基本学会了Pytorch的使用,本文的 Cifar 10 是经典的人工智能深度学习数据集。将详细介绍数据输入、网络搭建等其中的细致处理,这几个工程下来,我们一定会学会 Pytorch 编程的。
Cifar 10 的数据集下载地址(可以在官网,这里我给出网盘地址):
通过网盘分享的文件:cifar-10-python.tar.gz
链接: https://pan.baidu.com/s/1nFIdo2-gFqNWbwhk9IBTyQ?pwd=6ypf 提取码: 6ypf
import torch
import os
import cv2
import numpy as np
import glob
import pickle
# 数据解析
def unpickle(file):
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
save_path = r"D:\TestDemo\PytrochLearning\datasets\cifar10\test"
train_list = glob.glob(r"D:\论文\数据\cifar-10-python\cifar-10-batches-py/test_batch")
label_name = ["airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"]
print("train_list:\n",train_list)
for l in train_list:
print(l)
l_dict = unpickle(l)
print(l_dict.keys())
for im_index ,im_data in enumerate(l_dict[b'data']):
# print(im_index)
# print(im_data)
# 保存label 和 name
im_label = l_dict[b'labels'][im_index]
im_name = l_dict[b'filenames'][im_index]
# print(im_label,im_name,im_data)
im_label_name = label_name[im_label] # 使用到10个名字的数组,im_lable 是一个索引个
# 使用numpy 进行张量化
im_data = np.reshape(im_data,[3,32,32])
# 转化成Pytorch 的张量
im_data = np.transpose(im_data,(1,2,0))
# 可视化
# cv2.namedWindow("image", cv2.WINDOW_NORMAL) # 控制打印的大小
# cv2.imshow("image", im_data) # 控制打印窗口大小
# cv2.waitKey(0)
# 保存文件,我们会为每一个类别创建一个文件夹
if not os.path.exists("{}/{}".format(save_path,im_label_name)):
os.makedirs("{}/{}".format(save_path,im_label_name))
# im2 写入图片
cv2.imwrite("{}/{}/{}".format(save_path,im_label_name,im_name.decode("utf-8")),im_data)
数据解析函数和标签是官方提供的,采用键值对表示数据。打印出的数据是为了方便我们初步了解数据,文件路径一定要对。data_batch_*
因为文件很大,分了几个区块加载,* 的含义是选中所有区块。
我们把原始数据打印出来发现是一个数组,所以我们需要使用reshape
函数,把数组转换成张量 3 x 32 x 32,但是 Pytorch 处理图片的张量形式为 32 x 32 x 3 ,所以就要把数据维度交换一下。图片的类别是数字,我们需要把数字作为label_name
索引,获得其英文名字。
for im_index ,im_data in enumerate(l_dict[b'data']):
# print(im_index)
# print(im_data)
# 保存label 和 name
im_label = l_dict[b'labels'][im_index]
im_name = l_dict[b'filenames'][im_index]
# print(im_label,im_name,im_data)
im_label_name = label_name[im_label] # 使用到10个名字的数组,im_lable 是一个索引个
# 使用numpy 进行张量化
im_data = np.reshape(im_data,[3,32,32])
# 转化成Pytorch 的张量
im_data = np.transpose(im_data,(1,2,0))
# 可视化
# cv2.namedWindow("image", cv2.WINDOW_NORMAL) # 控制打印的大小
# cv2.imshow("image", im_data) # 控制打印窗口大小
# cv2.waitKey(0)
这串代码我们从每个batch入手,然后提取图片的标签、名字、数据,把数据转化成Tensor张量。可视化是为了看看图片的模样,把窗口放小一点,非常小才能看清图片的内容哦。
保存为训练集
# 保存文件,我们会为每一个类别创建一个文件夹
if not os.path.exists("{}/{}".format(save_path,im_label_name)):
os.makedirs("{}/{}".format(save_path,im_label_name))
# im2 写入图片
cv2.imwrite("{}/{}/{}".format(save_path,im_label_name,im_name.decode("utf-8")),im_data)
保存为训练集,也就是创建10个类别的文件夹,然后写入数据这两步。train文件夹自己创建,如图1.2所示。
如法炮制,我们生成测试集的数据。改两个地方文件末尾,test,记得自己创建 test 文件夹,效果图就不展示了。为什么要保存为图片,我们下载的数据是数组,只有跑完上面的代码才能看见效果图,我们保存为图片,更加直观,当然你说你看到数组就能知道这图片是啥,那也就无所谓了,建议参加最强大脑,哈哈哈。
创建一个新的 Python 文件,跟着我起飞,直接上源代码。
import glob
# 这个包用于增强数据
from torchvision import transforms
# 导入数据加载相应的类
from torch.utils.data import DataLoader,Dataset
import os
# 数据读取
from PIL import Image
import numpy as np
# 标签加工
label_name = ["airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"]
# 创建一个空列表
label_dict = {}
for index, name in enumerate(label_name):
label_dict[name] = index
# 辅助
def default_loader(path):
return Image.open(path).convert('RGB')
# transform 在训练和测试是不同的,这里我们在训练阶段,拼接多个方法
train_transform = transforms.Compose([
transforms.RandomCrop((28, 28)), # Crop 之后尺寸会小一点
transforms.RandomHorizontalFlip(),# 随机水平反转,默认概率0.5
transforms.RandomVerticalFlip(), # 垂直翻转 定义反转注意我们的数据是否对反转敏感,敏感的意思是翻转之后的label会不会发生变化,变化重新定义label
transforms.RandomRotation(90),
transforms.RandomGrayscale(0.1), # 灰度变化,概率0.1
transforms.ColorJitter(0.3, 0.3, 0.3,0.3), # 颜色增强
transforms.ToTensor() # 转化成输入数据
])
# 核心
class MyDataset(Dataset):
# 初始化函数,完成对数据的读取和简单处理,放在一个列表中
def __init__(self,im_list,transform=None,loader = default_loader): # im_list 我们的文本。灵活处理,transform 是数据增强用的
super(MyDataset, self).__init__()
images=[]
for img_item in im_list:
# D:\TestDemo\PytrochLearning\datasets\cifar10\test\airplane\aeroplane_s_000002.png
im_label_name = img_item.split('\\')[-2] # 获取类别
images.append([img_item,label_dict[im_label_name]])
# 把这些变成类内变量
self.transform = transform
self.loader = loader
self.images = images
# 获取元素,定义图片的获取,增强,然后返回图片的label
def __getitem__(self, index):
# 根据索引值读取相应的图片
im_path,im_label = self.images[index]
# 获取图片路径
im_data = self.loader(im_path)
# 有数据增强就使用数据增强
if self.transform is not None:
im_data = self.transform(im_data)
return im_data,im_label
# 获取长度
def __len__(self):
return len(self.images)
# 调用这个类
# st1.使用global 拿到训练数据和测试数据
im_train_list = glob.glob(r'D:\TestDemo\PytrochLearning\datasets\cifar10\train\*\*.png')
im_test_list = glob.glob(r'D:\TestDemo\PytrochLearning\datasets\cifar10\test\*\*.png')
train_dataset = MyDataset(im_train_list,
transform=train_transform)
test_dataset = MyDataset(im_test_list,
transform=transforms.ToTensor())
traint_data_loader = DataLoader(dataset=train_dataset,batch_size=64,shuffle=True,num_workers=4)
test_data_loader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=False,num_workers=4)
print("num 0f train dataset:",len(train_dataset))
print("num 0f test dataset:",len(test_dataset))
代码分成四个部分,标签加工、辅助函数、核心类和数据加载。标签加工,我们之前保存图片,把标签转化成了文件夹名字,所以这里我们读取数据后,还需要把标签给贴上数字才行,机器只认数字。所以建立一个空的字典先把标签和数字对应上,比如:“airplane”:0,...,"truck" = 9
。
辅助类,第一个默认加载器,把图片转化成统一的 RGB 三通道格式。第二训练的 transform,用于数据增强,这两个辅助都会在核心中使用到。有数据增强,数据不敏感就直接抄。
transforms.RandomCrop((28, 28)),
# Crop 之后尺寸会小一点transforms.RandomHorizontalFlip(),
# 随机水平反转,默认概率0.5transforms.RandomVerticalFlip(),
# 垂直翻转 定义反转注意我们的数据是否对翻转敏感,敏感的意思是翻转之后的label会不会发生变化,变化需要重新定义labeltransforms.RandomRotation(90),
# 前面说了transforms.RandomGrayscale(0.1),
# 灰度变化,概率0.1发生transforms.ColorJitter(0.3, 0.3, 0.3,0.3),
# 颜色增强transforms.ToTensor()
# 转化成输入数据
核心类class MyDataset(Dataset)
,其中有三个核心方法,初始化函数,获取元素的函数,获取长度的函数,最后一个最简单了。第一个函数:def __init__(self,im_list,transform=None,loader = default_loader):
,实际上传入参数只需要传入两个,im_list,transform, loader 我们自己写死了,这句 self.loader = loader
,传递参数就与 default_loader 绑定了。如果少了 transformer,连数据增强都可以不要,Python 编程方式好是灵活,换作 Java 写一堆去了。
images=[]
for img_item in im_list:
# D:\TestDemo\PytrochLearning\datasets\cifar10\test\airplane\aeroplane_s_000002.png
im_label_name = img_item.split('\\')[-2] # 获取类别
images.append([img_item,label_dict[im_label_name]])
# 把这些变成类内变量
self.images = images
# print("数据增强")
self.transform = transform
self.loader = loader
初始化函数有这么一段代码,按照路径分割,windows这里一定要 \\
代表就是 \ ,我们读入的是整个数据集,按照每个类别建立索引。是一个地址列和数字化的类别结合起来的列表,昨晚调试了两个小时才搞明白。后面三句把这个类的成员变量赋值,图片,transform 数据增强的工具,数据加载器。 注意:这里只是选择数据加载和transform 的类型,并没有去执行,真正要执行的是在索引和求长度的时候。
# 获取元素,定义图片的获取,增强,然后返回图片的label
def __getitem__(self, index):
print("执行了getitem")
# 根据索引值读取相应的图片
im_path, im_label = self.images[index]
# 获取图片路径
im_data = self.loader(im_path)
# 有数据增强就使用数据增强
if self.transform is not None:
im_data = self.transform(im_data)
return im_data, im_label
# 获取长度
def __len__(self):
print("调用了len\n")
return len(self.images)
这两个函数以及上一个 __init__
都是python的内置函数,自动触发,不需要显示调用。当使用索引时
print("train_dataset:\n",train_dataset[0][0].shape)
就会自动运行__getitem__
然后根据images 去加载数据,判断transformer,我们初始化的时候已经传入了,进行数据增强。
print("num 0f test dataset:", len(test_dataset))
这句话,len 也是python的内置函数,由于我们在类中重写了,所以调用的是类中的len。
im_train_list = glob.glob(r'D:\TestDemo\PytrochLearning\datasets\cifar10\train\*\*.png')
im_test_list = glob.glob(r'D:\TestDemo\PytrochLearning\datasets\cifar10\test\*\*.png')
train_dataset = MyDataset(im_train_list,
transform=train_transform)
test_dataset = MyDataset(im_test_list,
transform=transforms.ToTensor())
traint_data_loader = DataLoader(dataset=train_dataset,batch_size=64,shuffle=True,num_workers=4)
test_data_loader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=False,num_workers=4)
glob 获取地址路径,MyDataset 初始化操作,内容你要看,根据索引去查,就会自动调用数据加载,动态的。你不调用就是数据的文件路径和类别名字组成的变量images.
哪里不懂?打断点调式,print 输出提示,自己过一遍,费时间但是绝对是最快入门的技巧。只有了解数据流向,当我们设计程序的时候才会调用对应的函数去控制流向,看懂了不代表会编写自己的代码。最后希望大家帮助关注一下小店,增加浏览量。