Pytorch 实战三 Cifar 10 数据加载

系列文章目录


文章目录

  • 系列文章目录
  • 前言
  • 一、原始数据的读取
  • 二、数据的加载
    • 源代码
    • 标签加工
    • 辅助
    • 核心类


前言

  前两个实战然我们基本学会了Pytorch的使用,本文的 Cifar 10 是经典的人工智能深度学习数据集。将详细介绍数据输入、网络搭建等其中的细致处理,这几个工程下来,我们一定会学会 Pytorch 编程的。

一、原始数据的读取

  Cifar 10 的数据集下载地址(可以在官网,这里我给出网盘地址):
通过网盘分享的文件:cifar-10-python.tar.gz
链接: https://pan.baidu.com/s/1nFIdo2-gFqNWbwhk9IBTyQ?pwd=6ypf 提取码: 6ypf

import torch
import os
import cv2
import numpy as np
import glob
import pickle
# 数据解析
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

save_path = r"D:\TestDemo\PytrochLearning\datasets\cifar10\test"
train_list = glob.glob(r"D:\论文\数据\cifar-10-python\cifar-10-batches-py/test_batch")
label_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]
print("train_list:\n",train_list)

for l in train_list:
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())

    for im_index ,im_data in enumerate(l_dict[b'data']):
        # print(im_index)
        # print(im_data)
        # 保存label 和 name
        im_label = l_dict[b'labels'][im_index]
        im_name = l_dict[b'filenames'][im_index]
        # print(im_label,im_name,im_data)
        im_label_name = label_name[im_label]  # 使用到10个名字的数组,im_lable 是一个索引个
        # 使用numpy 进行张量化
        im_data = np.reshape(im_data,[3,32,32])
        # 转化成Pytorch 的张量
        im_data = np.transpose(im_data,(1,2,0))
        # 可视化
        # cv2.namedWindow("image", cv2.WINDOW_NORMAL)  # 控制打印的大小
        # cv2.imshow("image", im_data)  # 控制打印窗口大小
        # cv2.waitKey(0)

        # 保存文件,我们会为每一个类别创建一个文件夹
        if not os.path.exists("{}/{}".format(save_path,im_label_name)):
            os.makedirs("{}/{}".format(save_path,im_label_name))

        # im2 写入图片
        cv2.imwrite("{}/{}/{}".format(save_path,im_label_name,im_name.decode("utf-8")),im_data)



  数据解析函数和标签是官方提供的,采用键值对表示数据。打印出的数据是为了方便我们初步了解数据,文件路径一定要对。data_batch_* 因为文件很大,分了几个区块加载,* 的含义是选中所有区块。

Pytorch 实战三 Cifar 10 数据加载_第1张图片
图 1.1 原始数据

我们把原始数据打印出来发现是一个数组,所以我们需要使用reshape函数,把数组转换成张量 3 x 32 x 32,但是 Pytorch 处理图片的张量形式为 32 x 32 x 3 ,所以就要把数据维度交换一下。图片的类别是数字,我们需要把数字作为label_name索引,获得其英文名字。

    for im_index ,im_data in enumerate(l_dict[b'data']):
        # print(im_index)
        # print(im_data)
        # 保存label 和 name
        im_label = l_dict[b'labels'][im_index]
        im_name = l_dict[b'filenames'][im_index]
        # print(im_label,im_name,im_data)
        im_label_name = label_name[im_label]  # 使用到10个名字的数组,im_lable 是一个索引个
        # 使用numpy 进行张量化
        im_data = np.reshape(im_data,[3,32,32])
        # 转化成Pytorch 的张量
        im_data = np.transpose(im_data,(1,2,0))
        # 可视化
        # cv2.namedWindow("image", cv2.WINDOW_NORMAL)  # 控制打印的大小
        # cv2.imshow("image", im_data)  # 控制打印窗口大小
        # cv2.waitKey(0)

  这串代码我们从每个batch入手,然后提取图片的标签、名字、数据,把数据转化成Tensor张量。可视化是为了看看图片的模样,把窗口放小一点,非常小才能看清图片的内容哦。
保存为训练集

        # 保存文件,我们会为每一个类别创建一个文件夹
        if not os.path.exists("{}/{}".format(save_path,im_label_name)):
            os.makedirs("{}/{}".format(save_path,im_label_name))

        # im2 写入图片
        cv2.imwrite("{}/{}/{}".format(save_path,im_label_name,im_name.decode("utf-8")),im_data)

  保存为训练集,也就是创建10个类别的文件夹,然后写入数据这两步。train文件夹自己创建,如图1.2所示。

Pytorch 实战三 Cifar 10 数据加载_第2张图片
图 1.2 训练集数据

  如法炮制,我们生成测试集的数据。改两个地方文件末尾,test,记得自己创建 test 文件夹,效果图就不展示了。为什么要保存为图片,我们下载的数据是数组,只有跑完上面的代码才能看见效果图,我们保存为图片,更加直观,当然你说你看到数组就能知道这图片是啥,那也就无所谓了,建议参加最强大脑,哈哈哈。

图 1.3 修改路径

二、数据的加载

源代码

  创建一个新的 Python 文件,跟着我起飞,直接上源代码。


import glob
# 这个包用于增强数据
from torchvision import transforms
# 导入数据加载相应的类
from torch.utils.data import DataLoader,Dataset
import os
# 数据读取
from PIL import Image
import numpy as np

# 标签加工
label_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]
# 创建一个空列表
label_dict = {}
for index, name in enumerate(label_name):
    label_dict[name] = index

# 辅助
def default_loader(path):
    return Image.open(path).convert('RGB')

# transform 在训练和测试是不同的,这里我们在训练阶段,拼接多个方法
train_transform = transforms.Compose([
    transforms.RandomCrop((28, 28)),  # Crop 之后尺寸会小一点
    transforms.RandomHorizontalFlip(),# 随机水平反转,默认概率0.5
    transforms.RandomVerticalFlip(), # 垂直翻转  定义反转注意我们的数据是否对反转敏感,敏感的意思是翻转之后的label会不会发生变化,变化重新定义label
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1), # 灰度变化,概率0.1
    transforms.ColorJitter(0.3, 0.3, 0.3,0.3), # 颜色增强
    transforms.ToTensor() # 转化成输入数据
])
# 核心
class MyDataset(Dataset):

    # 初始化函数,完成对数据的读取和简单处理,放在一个列表中
    def __init__(self,im_list,transform=None,loader = default_loader):   # im_list 我们的文本。灵活处理,transform 是数据增强用的
        super(MyDataset, self).__init__()
        images=[]
        for img_item in im_list:
            # D:\TestDemo\PytrochLearning\datasets\cifar10\test\airplane\aeroplane_s_000002.png
            im_label_name = img_item.split('\\')[-2]   # 获取类别
            images.append([img_item,label_dict[im_label_name]])

        # 把这些变成类内变量
        self.transform = transform
        self.loader = loader
        self.images = images
    # 获取元素,定义图片的获取,增强,然后返回图片的label
    def __getitem__(self, index):
        # 根据索引值读取相应的图片
        im_path,im_label = self.images[index]
        # 获取图片路径
        im_data = self.loader(im_path)
        # 有数据增强就使用数据增强
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data,im_label
    # 获取长度
    def __len__(self):
        return len(self.images)

# 调用这个类
# st1.使用global 拿到训练数据和测试数据
im_train_list = glob.glob(r'D:\TestDemo\PytrochLearning\datasets\cifar10\train\*\*.png')
im_test_list = glob.glob(r'D:\TestDemo\PytrochLearning\datasets\cifar10\test\*\*.png')

train_dataset = MyDataset(im_train_list,
                          transform=train_transform)
test_dataset = MyDataset(im_test_list,
                          transform=transforms.ToTensor())
traint_data_loader = DataLoader(dataset=train_dataset,batch_size=64,shuffle=True,num_workers=4)
test_data_loader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=False,num_workers=4)

print("num 0f train dataset:",len(train_dataset))
print("num 0f test dataset:",len(test_dataset))

标签加工

  代码分成四个部分,标签加工、辅助函数、核心类和数据加载。标签加工,我们之前保存图片,把标签转化成了文件夹名字,所以这里我们读取数据后,还需要把标签给贴上数字才行,机器只认数字。所以建立一个空的字典先把标签和数字对应上,比如:“airplane”:0,...,"truck" = 9

辅助

  辅助类,第一个默认加载器,把图片转化成统一的 RGB 三通道格式。第二训练的 transform,用于数据增强,这两个辅助都会在核心中使用到。有数据增强,数据不敏感就直接抄。

  • transforms.RandomCrop((28, 28)), # Crop 之后尺寸会小一点
  • transforms.RandomHorizontalFlip(), # 随机水平反转,默认概率0.5
  • transforms.RandomVerticalFlip(), # 垂直翻转 定义反转注意我们的数据是否对翻转敏感,敏感的意思是翻转之后的label会不会发生变化,变化需要重新定义label
  • transforms.RandomRotation(90), # 前面说了
  • transforms.RandomGrayscale(0.1), # 灰度变化,概率0.1发生
  • transforms.ColorJitter(0.3, 0.3, 0.3,0.3), # 颜色增强
  • transforms.ToTensor() # 转化成输入数据

核心类

   核心类class MyDataset(Dataset) ,其中有三个核心方法,初始化函数,获取元素的函数,获取长度的函数,最后一个最简单了。第一个函数:def __init__(self,im_list,transform=None,loader = default_loader):,实际上传入参数只需要传入两个,im_list,transform, loader 我们自己写死了,这句 self.loader = loader ,传递参数就与 default_loader 绑定了。如果少了 transformer,连数据增强都可以不要,Python 编程方式好是灵活,换作 Java 写一堆去了。

        images=[]
        for img_item in im_list:
            # D:\TestDemo\PytrochLearning\datasets\cifar10\test\airplane\aeroplane_s_000002.png
            im_label_name = img_item.split('\\')[-2]   # 获取类别
            images.append([img_item,label_dict[im_label_name]])
               # 把这些变成类内变量
        self.images = images
        # print("数据增强")
        self.transform = transform
        self.loader = loader

初始化函数有这么一段代码,按照路径分割,windows这里一定要 \\ 代表就是 \ ,我们读入的是整个数据集,按照每个类别建立索引。是一个地址列和数字化的类别结合起来的列表,昨晚调试了两个小时才搞明白。后面三句把这个类的成员变量赋值,图片,transform 数据增强的工具,数据加载器。 注意:这里只是选择数据加载和transform 的类型,并没有去执行,真正要执行的是在索引和求长度的时候。

 # 获取元素,定义图片的获取,增强,然后返回图片的label
    def __getitem__(self, index):
        print("执行了getitem")
        # 根据索引值读取相应的图片
        im_path, im_label = self.images[index]
        # 获取图片路径
        im_data = self.loader(im_path)
        # 有数据增强就使用数据增强
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_label

    # 获取长度
    def __len__(self):
        print("调用了len\n")
        return len(self.images)

这两个函数以及上一个 __init__ 都是python的内置函数,自动触发,不需要显示调用。当使用索引时

print("train_dataset:\n",train_dataset[0][0].shape)

就会自动运行__getitem__然后根据images 去加载数据,判断transformer,我们初始化的时候已经传入了,进行数据增强。

print("num 0f test dataset:", len(test_dataset))

这句话,len 也是python的内置函数,由于我们在类中重写了,所以调用的是类中的len。

im_train_list = glob.glob(r'D:\TestDemo\PytrochLearning\datasets\cifar10\train\*\*.png')
im_test_list = glob.glob(r'D:\TestDemo\PytrochLearning\datasets\cifar10\test\*\*.png')

train_dataset = MyDataset(im_train_list,
                          transform=train_transform)
test_dataset = MyDataset(im_test_list,
                          transform=transforms.ToTensor())
traint_data_loader = DataLoader(dataset=train_dataset,batch_size=64,shuffle=True,num_workers=4)
test_data_loader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=False,num_workers=4)

glob 获取地址路径,MyDataset 初始化操作,内容你要看,根据索引去查,就会自动调用数据加载,动态的。你不调用就是数据的文件路径和类别名字组成的变量images.
哪里不懂?打断点调式,print 输出提示,自己过一遍,费时间但是绝对是最快入门的技巧。只有了解数据流向,当我们设计程序的时候才会调用对应的函数去控制流向,看懂了不代表会编写自己的代码。最后希望大家帮助关注一下小店,增加浏览量。

你可能感兴趣的:(Pytorch,pytorch,人工智能,python)