官方文档:https://pytorch.org/docs/stable/data.html
让pytorch读取数据集,主要用到 torch.utils.data.Dataset 和 torch.utils.data.DataLoader
表示Dataset的抽象类。所有其他数据集都应该进行子类化。所有子类应该重写__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
官方代码:
class Dataset(Generic[T_co]):
"""
An abstract class representing a:class:'Dataset'.
All datasets that represent a map from keys to data samples should subclass it.
All subclasses should overwrite:meth:'__getitem__', supporting fetching a data
sample for a given key.Subclasses could also optionally overwrite:meth:'__len__',
which is expected to return the size of the dataset by many:class:
'~torch.utils.data.Sampler' implementations and the default options of :class:
'~torch.utils.data.DataLoader'.
note:
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index)
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]')
return ConcatDataset([self, other])
数据加载器:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
sampler=None, num_workers=0, drop_last=False)
参数:
dataset (Dataset) – 加载数据的数据集。
batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
第一步:定义自己的Dateset,继承 'torch.utils.data.Dataset ’ ,同时重写__getitem__和__len__。
类中的def__getitem__(self, index):接收一个index,然后通过index返回图片数据和标签。这个index指的是一个list的index,这个list中的每个元素包含了图片数据的路径和标签信息。
(1)构建一个上述的list
(2)重写__getitem__,返回list[index]中的数据
(3)重写__len__,返回list的长度
第二步:通过’torch.utils.data.DataLoader’ 读取自己的数据集。
具体细节:
问:list是如何构建?
答:首先需要一个txt(txt中是数据集的信息),通过读取txt构建list。
再问:txt什么样子呢?
再答:
是这样:
这个是mnist数据集,txt中每行是每张图片的信息(前半部分为图片位置,后半部分为图片的label)。注意: 往往图片的信息不是具体位置(往下看!)。
或者这样:
这个是voc数据集,可以看到,不同的是,每行是由图片的名字和label组成。
再注意: 根据自己txt中的内容构建list,要会灵活变通!
举个栗子八!(按照mnist那个txt读取的哦):
from PIL import Image
import torch
import torchvision.transforms as transforms
# 创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(self, root, datatxt, transform=None, target_transform=None):
# 初始化一些需要传入的参数。transform对图片进行预处理(减均值,除标准差,随机裁剪,旋转,翻转,放射变换,等操作)
dataset_dir = open(root + datatxt, 'r') # 打开这个txt,并读取内容
data_list = [] # 创建一个空列表,用来存图片的路径和label信息
for line in dataset_dir: # 按行读取txt中的内容
line = line.strip() # 删除本行开头和末尾的指定字符(eg:空格),type为str
line_list = line.split() # 通过指定分隔符(eg:空格)对字符串进行切片,type为list
data_list.append((line_list[0], int(line_list[1]))) # 把txt里的内容读入list, line_list[0]是图片信息,line_list[1]是lable, 组成数组,作为list中的一个元素
self.data_list = data_list
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index): # 用于按照index读取每个元素的具体内容
img_file, label = self.data_list[index] # img_file:data_list[0],即图片路径。label:data_list[1],即标签信息
img = Image.open(img_file).convert('RGB') # 按照路径读取图片
if self.transform is not None: # 是否进行transform
img = self.transform(img)
return img, label
def __len__(self): # 返回的是数据集的长度
return len(self.imgs)
root = "D:/MNIST/mnist_dataset/MNIST/raw/"
train_data = MyDataset(root, 'train.txt', transform=transforms.ToTensor())
test_data = MyDataset(root, 'test.txt', transform=transforms.ToTensor())
#最有就是调用DataLoader和刚刚创建的数据集,来创建dataloader
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=64)
做法和上一小节是一样的,只不过这次我们的label不是数字啦,而是图片!
直接上栗子:
首先看看我们数据集的txt!
咦?咋没有label信息呢?
答: 这是因为我们的图片和label都是同一个名称,只是在不同的文件夹中,我们就可以只通过名字来读取啦,看下边这张图!
本例采用voc数据集:voc数据集中有很多文件夹,下面例子只用到ImageSets,JPEGImages和SegmentationClass文件夹。其中ImageSets:是各种不同txt文件,JPEGImages:是原图 ,SegmentationClass:是label图片。
通过txt读名称,通过名称分别对JPEGImages和SegmentationClass进行访问,读取对应的图片构建list 。
OK!上代码!
代码是https://github.com/wkentaro/pytorch-fcn中的dataset
下图是加了标注!
class VOCClassSegBase(data.Dataset):
class_names = np.array([
'background',
'aeroplane',
'bicycle',
'bird',
'boat',
'bottle',
'bus',
'car',
'cat',
'chair',
'cow',
'diningtable',
'dog',
'horse',
'motorbike',
'person',
'potted plant',
'sheep',
'sofa',
'train',
'tv/monitor',
])
'''
这里的[104.00698793, 116.66876762, 122.67891434]指的是数据集里面所有训练集数据的
平均像素值,三个值对应图像的BGR,用于后边对图片进行减均值的预处理。
'''
mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])
def __init__(self, root, split='train', transform=False):
self.root = root
self.split = split
self._transform = transform
# 此路径下是ImageSets,JPEGImages和SegmentationClass文件夹
dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012')
# Python中通过Key访问字典,当Key不存在时,会引发‘KeyError’异常。为了避免这种情况的
#发生,可以使用collections类中的defaultdict()方法来为字典提供默认值。
self.files = collections.defaultdict(list)
# 构建list
for split in ['train', 'val']:
imgsets_file = osp.join(
dataset_dir, 'ImageSets/Segmentation/%s.txt' % split) # txt路径
for did in open(imgsets_file): # 按行读取txt中的内容
did = did.strip() # txt中只有名字!
img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) # 原图路径
lbl_file = osp.join(
dataset_dir, 'SegmentationClass/%s.png' % did) # label路径
self.files[split].append({
'img': img_file,
'lbl': lbl_file,
}) # list中每个元素为字典{'img': img_file,'lbl': lbl_file,}
def __len__(self):
return len(self.files[self.split])
def __getitem__(self, index):
data_file = self.files[self.split][index] # {'img': img_file,'lbl': lbl_file,}
# load image
img_file = data_file['img']
img = PIL.Image.open(img_file) # RGB
img = np.array(img, dtype=np.uint8) # 将图片转为np数组
# load label
lbl_file = data_file['lbl']
lbl = PIL.Image.open(lbl_file)
lbl = np.array(lbl, dtype=np.int32)
lbl[lbl == 255] = -1 # 将label中像素为255的像素点置为-1,为什么这样做后边有解答!
if self._transform:
return self.transform(img, lbl)
else:
return img, lbl
def transform(self, img, lbl):
img = img[:, :, ::-1] # RGB -> BGR (img中最后一维的维度为3,及RGB,因此只是最后一维中的变化,及,[1,2,3]--->[3,2,1] )
img = img.astype(np.float64)
img -= self.mean_bgr # 去均值
img = img.transpose(2, 0, 1) # 转置 eg: 图像宽:232,高:309, shape为(309,232,3)--> (3,309,232)
img = torch.from_numpy(img).float() # 转为 Tensor
lbl = torch.from_numpy(lbl).long()
return img, lbl
解答:
lbl[lbl == 255] = -1 # 将label中像素为255的像素点置为-1
https://zhuanlan.zhihu.com/p/61203757?utm_source=wechat_session&utm_medium=social&utm_oi=872195207066718208&utm_campaign=shareopn
https://pytorch.org/vision/stable/transforms.html
官方教程,有各种transforms。
在第二节中用到了 transform=transforms.ToTensor()
train_data = MyDataset(root, 'train.txt', transform=transforms.ToTensor())
简单介绍一下torchvision.transforms.ToTensor():
img = torch.from_numpy(pic.transpose((2, 0, 1)))