1、Dataset+DataLoader实现自定义数据集读取方法
创建自己的数据集需要继承父类torch.utils.data.Dataset,同时需要重载两个私有成员函数:def __len__(self)和def __getitem__(self, index) 。 def __len__(self)应该返回数据集的大小;def __getitem__(self, index)接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。
基于下面的模板写类和方法 来加载自己的图像
class MyDataset(torch.utils.data.Dataset):#需要继承torch.utils.data.Dataset
def __init__(self):
#对继承自父类的属性进行初始化(好像没有这句也可以??)
super(MyDataset,self).__init__()
# TODO
#1、初始化一些参数和函数,方便在__getitem__函数中调用。
#2、制作__getitem__函数所要用到的图片和对应标签的list。
#也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
pass
def __getitem__(self, index):
# TODO
#1、根据list从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
#2、预处理数据(例如torchvision.Transform)。
#3、返回数据对(例如图像和标签)。
#这里需要注意的是,这步所处理的是index所对应的一个样本。
pass
def __len__(self):
#返回数据集大小
return len()
下面是一个 按照上面的模板 写的例子
第一步:先收集几张图片作为自己的数据集,然后自己手动创建一个txt文件(关于txt文件,通过python和matlab可以很容易的创建)储存图片对应的label。
图像照片
txt文件
第二步:按照上面的模板,制作自己的数据集类。
import torch
import torchvision
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
#路径是自己电脑里所对应的路径
datapath = r'E:\Python\DeepLearning\Datasets\testdata'
txtpath = r'E:\Python\DeepLearning\Datasets\testdata\label.txt'
class MyDataset(Dataset):
def __init__(self,txtpath):
#创建一个list用来储存图片和标签信息
imgs = []
#打开第一步创建的txt文件,按行读取,将结果以元组方式保存在imgs里
datainfo = open(txtpath,'r')
for line in datainfo:
line = line.strip('\n')
words = line.split()
imgs.append((words[0],words[1]))
self.imgs = imgs
#返回数据集大小
def __len__(self):
return len(self.imgs)
#打开index对应图片进行预处理后return回处理后的图片和标签
def __getitem__(self, index):
pic,label = self.imgs[index]
pic = Image.open(datapath+'\\'+pic)
pic = transforms.ToTensor()(pic)
return pic,label
#实例化对象
data = MyDataset(txtpath)
#将数据集导入DataLoader,进行shuffle以及选取batch_size
data_loader = DataLoader(data,batch_size=2,shuffle=True,num_workers=0)
#Windows里num_works只能为0,其他值会报错
代码很简单,你可以根据自己的实际数据集情况,修改上面的代码。
查看一下data_loader:
for pics,label in data_loader:
print(pics,label)
输出如下
tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]]) ('4', '2')
tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]]) ('2', '1')
tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]]) ('0', '1')
tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]]) ('4', '0')
tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]]) ('3', '3')
data_loader这个迭代器里存储的是每2个一组(batch_size)的图片像素信息以及对应的标签信息,也就是我们后续要导入到神经网络里的数据。
只是一个最简单的用来加深理解的例子,实际应用时会比这复杂很多,比如图像的渲染、变换等,但是基本流程都是一样的。
1.3、txt文件的生成
这个 很简单,用python脚本,写一个就行。
2、ImageFolder+DataLoader实现本地数据导入(这个很方便)
在pytorch中提供了torchvision.datasets.ImageFolder让我们训练自己的图像。ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
它主要有四个参数:
root:在root指定的路径下寻找图片
transform:对loader读取图片的返回对象进行转换操作(ToTensor等)
target_transform:对label的转换
loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象
文件夹严格按照如下方式保存:
.
├──train
| ├──类别1
| | ├──*.jpg
| | ├──*.jpg
| | └──...
| ├──类别2
| | ├──*.jpg
| | ├──*.jpg
| | └──...
| └──...
└──test
├──类别1
| ├──*.jpg
| ├──*.jpg
| └──...
├──类别2
| ├──*.jpg
| ├──*.jpg
| └──...
└──...
实现代码如下
transform = transforms.ToTensor()
root = r'E:\Python\DeepLearning\Datasets\mymnist\train'
# 使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
train_data = torchvision.datasets.ImageFolder(root, transform=transform)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=256, shuffle=True, num_workers=0)
test_data = torchvision.datasets.ImageFolder(root, transform=transform)
test_iter = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=0)
参考:
(87条消息) Pytorch创建自己的数据集(一)_生活所迫^_^的博客-CSDN博客_pytorch构建自己的数据集