【PyTorch学习】(四)Dataloader的正确打开方式

torch.utils.data.DataLoader源码地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py


Dataloader接口主要将上篇博客中自定义的数据读取接口的输出按照batch size封装成Tensor,后续再包装成Variable作为模型的输入。该接口就像是“快递公司(dataloader)”,把一个个“物品(数据)”“打包(封装)成快递(Tensor)”,然后进入“高速公路(搭建好的网络)”传输。

 

Dataloader的官方定义:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0):
  • dataset:上一篇介绍的按自定义形式(一般为(image, label)形式)准备好的数据集,也就是“快递物品”;
  • batch_size:关键参数,可以理解为“一车运送多少件物品”;
  • shuffle:每个epoch是否将数据打乱,默认是False即不打乱;
  • sampler:抽样方法,定义一种策略从dataset中返回一个index,即"要运送哪一件物品";
  • batch_sampler:批量抽样方法,定义一种策略从dataset中返回批量的indices,即"要运送哪几件物品";
  • num_workers:使用几个进程(与cpu核数有关)加载数据,默认"0"代表只使用主进程进行数据加载;
  • collate_fn:merges a list of samples to form a mini-batch of Tensors,即“如何把物品装箱,横着放还是竖着放”;
  • pin_memory:设置为True时,pytorch会采取一些列骚操作(源码很长...我昏了)把数据拷贝到GPU进行加速,即“用快车(GPU)运还是慢车(CPU)运”;
  • drop_last:len(dataset) / batch_size (总运送物品/每车运送的物品数)如果整除万事大吉,但如果不能整除,最后剩下的不足一车的物品如何处理,默认是False即“即使最后仅剩一件,也派一辆车送”,设为True则“丢弃最后剩下的”;
  • timeout:即“每次装完一车物品休息多久”;

以官方的人脸的例子为例:

1.初始化一个batch_size=4的数据加载器对象,即初始化了一个“装车策略

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

2.由于dataloader类中定义了__iter__()魔法函数,所以可以使用for/enumerate进行调用

此处的sample_batched就是在该“装车策略下装好的一辆车”,即一个batch的Tensor。

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

3.show_landmarks_batch 可视化 dataloader返回的batch samples。

# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

【PyTorch学习】(四)Dataloader的正确打开方式_第1张图片

你可能感兴趣的:(PyTorch学习)