Pytorch读取数据集

Pytorch数据加载

  • 1. torch.nn.utils.data
    • 1.1 torch.utils.data.Dataset
    • 1.2 torch.utils.data.DataLoader
  • 2. label为number的数据集读取(eg:分类)
  • 3. label为image的数据集读取(eg:分割)
  • 4. torchvision.transforms

数据集是数据的集合,Dataset表示数据集的抽象类

1. torch.nn.utils.data

官方文档:https://pytorch.org/docs/stable/data.html

让pytorch读取数据集,主要用到 torch.utils.data.Dataset 和 torch.utils.data.DataLoader

1.1 torch.utils.data.Dataset

表示Dataset的抽象类。所有其他数据集都应该进行子类化。所有子类应该重写__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
官方代码:

class Dataset(Generic[T_co]):
    """
    An abstract class representing a:class:'Dataset'.

    All datasets that represent a map from keys to data samples should subclass it.
    All subclasses should overwrite:meth:'__getitem__', supporting fetching a data 
    sample for a given key.Subclasses could also optionally overwrite:meth:'__len__',
    which is expected to return the size of the dataset by many:class:
    '~torch.utils.data.Sampler' implementations and the default options of :class:
    '~torch.utils.data.DataLoader'.

    note:
    :class:`~torch.utils.data.DataLoader` by default constructs a index
    sampler that yields integral indices.  To make it work with a map-style
    dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index)
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError
    def __add__(self, other: 'Dataset[T_co]')
        return ConcatDataset([self, other])

1.2 torch.utils.data.DataLoader

数据加载器:

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, 
      sampler=None, num_workers=0, drop_last=False)

参数:
dataset (Dataset) – 加载数据的数据集。
batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

2. label为number的数据集读取(eg:分类)

第一步:定义自己的Dateset,继承 'torch.utils.data.Dataset ’ ,同时重写__getitem__和__len__。
类中的def__getitem__(self, index):接收一个index,然后通过index返回图片数据和标签。这个index指的是一个list的index,这个list中的每个元素包含了图片数据的路径和标签信息。
(1)构建一个上述的list
(2)重写__getitem__,返回list[index]中的数据
(3)重写__len__,返回list的长度
第二步:通过’torch.utils.data.DataLoader’ 读取自己的数据集。

具体细节:
:list是如何构建?
:首先需要一个txt(txt中是数据集的信息),通过读取txt构建list。
再问:txt什么样子呢?
再答
是这样:
Pytorch读取数据集_第1张图片
这个是mnist数据集,txt中每行是每张图片的信息(前半部分为图片位置,后半部分为图片的label)。注意: 往往图片的信息不是具体位置(往下看!)。

或者这样:
Pytorch读取数据集_第2张图片
这个是voc数据集,可以看到,不同的是,每行是由图片的名字和label组成。
再注意: 根据自己txt中的内容构建list,要会灵活变通!

举个栗子八!(按照mnist那个txt读取的哦):

from PIL import Image
import torch
import torchvision.transforms as transforms

# 创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root, datatxt, transform=None, target_transform=None):
        # 初始化一些需要传入的参数。transform对图片进行预处理(减均值,除标准差,随机裁剪,旋转,翻转,放射变换,等操作)
        dataset_dir = open(root + datatxt, 'r')  # 打开这个txt,并读取内容
        data_list = []  # 创建一个空列表,用来存图片的路径和label信息
        for line in dataset_dir:  # 按行读取txt中的内容
            line = line.strip()  # 删除本行开头和末尾的指定字符(eg:空格),type为str
            line_list = line.split()  # 通过指定分隔符(eg:空格)对字符串进行切片,type为list
            data_list.append((line_list[0], int(line_list[1])))  # 把txt里的内容读入list, line_list[0]是图片信息,line_list[1]是lable, 组成数组,作为list中的一个元素
        self.data_list = data_list
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):  # 用于按照index读取每个元素的具体内容
        img_file, label = self.data_list[index] # img_file:data_list[0],即图片路径。label:data_list[1],即标签信息
        img = Image.open(img_file).convert('RGB') # 按照路径读取图片
        if self.transform is not None: # 是否进行transform
            img = self.transform(img) 
        return img, label

    def __len__(self):  # 返回的是数据集的长度
        return len(self.imgs)

root = "D:/MNIST/mnist_dataset/MNIST/raw/"
train_data = MyDataset(root, 'train.txt', transform=transforms.ToTensor())
test_data = MyDataset(root, 'test.txt', transform=transforms.ToTensor())

#最有就是调用DataLoader和刚刚创建的数据集,来创建dataloader
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=64)

3. label为image的数据集读取(eg:分割)

做法和上一小节是一样的,只不过这次我们的label不是数字啦,而是图片!
直接上栗子:
首先看看我们数据集的txt!
Pytorch读取数据集_第3张图片
咦?咋没有label信息呢?
答: 这是因为我们的图片和label都是同一个名称,只是在不同的文件夹中,我们就可以只通过名字来读取啦,看下边这张图!
Pytorch读取数据集_第4张图片
本例采用voc数据集:voc数据集中有很多文件夹,下面例子只用到ImageSets,JPEGImages和SegmentationClass文件夹。其中ImageSets:是各种不同txt文件,JPEGImages:是原图 ,SegmentationClass:是label图片。
通过txt读名称,通过名称分别对JPEGImages和SegmentationClass进行访问,读取对应的图片构建list
Pytorch读取数据集_第5张图片

OK!上代码!

代码是https://github.com/wkentaro/pytorch-fcn中的dataset

下图是加了标注!

class VOCClassSegBase(data.Dataset):

    class_names = np.array([
        'background',
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'potted plant',
        'sheep',
        'sofa',
        'train',
        'tv/monitor',
    ])

    '''
    这里的[104.00698793, 116.66876762, 122.67891434]指的是数据集里面所有训练集数据的
    平均像素值,三个值对应图像的BGR,用于后边对图片进行减均值的预处理。
    '''
    mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434])

    def __init__(self, root, split='train', transform=False):
        self.root = root
        self.split = split
        self._transform = transform

        # 此路径下是ImageSets,JPEGImages和SegmentationClass文件夹
        dataset_dir = osp.join(self.root, 'VOC/VOCdevkit/VOC2012')

        # Python中通过Key访问字典,当Key不存在时,会引发‘KeyError’异常。为了避免这种情况的
        #发生,可以使用collections类中的defaultdict()方法来为字典提供默认值。
        self.files = collections.defaultdict(list)

        # 构建list
        for split in ['train', 'val']:
            imgsets_file = osp.join(
                dataset_dir, 'ImageSets/Segmentation/%s.txt' % split) # txt路径
            for did in open(imgsets_file): # 按行读取txt中的内容
                did = did.strip()  # txt中只有名字!
                img_file = osp.join(dataset_dir, 'JPEGImages/%s.jpg' % did) # 原图路径
                lbl_file = osp.join(
                    dataset_dir, 'SegmentationClass/%s.png' % did)  # label路径
                self.files[split].append({
                    'img': img_file,
                    'lbl': lbl_file,
                }) # list中每个元素为字典{'img': img_file,'lbl': lbl_file,}

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        data_file = self.files[self.split][index] # {'img': img_file,'lbl': lbl_file,}
        # load image
        img_file = data_file['img']
        img = PIL.Image.open(img_file)       # RGB
        img = np.array(img, dtype=np.uint8) # 将图片转为np数组
        # load label
        lbl_file = data_file['lbl']
        lbl = PIL.Image.open(lbl_file)
        lbl = np.array(lbl, dtype=np.int32)
        lbl[lbl == 255] = -1 # 将label中像素为255的像素点置为-1,为什么这样做后边有解答!
        if self._transform:
            return self.transform(img, lbl)
        else:
            return img, lbl

    def transform(self, img, lbl):
        img = img[:, :, ::-1]  # RGB -> BGR (img中最后一维的维度为3,及RGB,因此只是最后一维中的变化,及,[1,2,3]--->[3,2,1] )
        img = img.astype(np.float64)
        img -= self.mean_bgr      # 去均值
        img = img.transpose(2, 0, 1)   # 转置  eg: 图像宽:232,高:309, shape为(309,232,3)--> (3,309,232)
        img = torch.from_numpy(img).float()  # 转为 Tensor
        lbl = torch.from_numpy(lbl).long()
        return img, lbl

解答:

  1. lbl[lbl == 255] = -1 # 将label中像素为255的像素点置为-1
    语义分割:像素级的分类,对每个像素进行分类即可得到分割图像。所以label图片里的像素值取值范围是0 ~ n - 1 ,即像素点可能的类别共有 n 类 。
    对于21个类别,灰度值为0—20来说,label图片看上去都是黑色的(label原来是几通道的无所谓,只要读取的时候转化成灰度图就行),只有白色的轮廓而已,而白色的轮廓的灰度值是255!
    但是这些轮廓在计算损失值的时候是不作为有效值的,也就是对于灰度值=255的点是忽略的。
  2. 去均值:是为了对图像进行标准化,可以移除图像的平均亮度值 (intensity)。在每个样本上减去数据的统计平均值可以移除共同的部分,凸显个体差异。
  3. transpose函数的理解:下边这个知乎大佬讲的很清晰

https://zhuanlan.zhihu.com/p/61203757?utm_source=wechat_session&utm_medium=social&utm_oi=872195207066718208&utm_campaign=shareopn

4. torchvision.transforms

https://pytorch.org/vision/stable/transforms.html
官方教程,有各种transforms。

在第二节中用到了 transform=transforms.ToTensor()
train_data = MyDataset(root, 'train.txt', transform=transforms.ToTensor())
简单介绍一下torchvision.transforms.ToTensor():

  1. transforms.ToTensor()具有归一化功能:将像素值从[0,255]转化为[0,1] 。
  2. transforms.ToTensor()具有维度转换功能:将shape (H x W x C)转换为 shape (C x H x W),即img = torch.from_numpy(pic.transpose((2, 0, 1)))

你可能感兴趣的:(python,深度学习,pytorch)