torch.utils.data.DataLoader源码地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py
DataLoader接口主要将上篇博客中自定义的数据读取接口的输出按照batch size封装成Tensor,后续再包装成Variable作为模型的输入(PyTorch 0.4及之后的版本中Variable已与Tensor合并,Tensor可直接作为模型输入)。该接口就像是“快递公司(DataLoader)”,把一个个“物品(数据)”“打包(封装)”成“快递(Tensor)”,然后进入“高速公路(搭建好的网络)”传输。
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0):
1.初始化一个batch_size=4的数据加载器对象,即初始化了一个“装车策略”
# Wrap the transformed dataset in a loader: batches of four samples,
# shuffling enabled, four loader workers.
dataloader = DataLoader(
    transformed_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)
2.由于DataLoader类中定义了__iter__()魔法方法,所以可以直接用for/enumerate对其进行遍历。
此处的sample_batched就是在该“装车策略”下装好的“一辆车”,即一个batch的数据(这里是一个字典,'image'和'landmarks'两个键各对应一个batch的Tensor)。
# Walk the loader batch by batch, printing the index and the sizes of the
# 'image' and 'landmarks' entries of every batch.
for i_batch, sample_batched in enumerate(dataloader):
    image_size = sample_batched['image'].size()
    landmarks_size = sample_batched['landmarks'].size()
    print(i_batch, image_size, landmarks_size)

    # Only the fourth batch (index 3) is plotted; earlier ones are just logged.
    if i_batch != 3:
        continue

    plt.figure()
    show_landmarks_batch(sample_batched)
    plt.axis('off')
    plt.ioff()
    plt.show()
    break
3.show_landmarks_batch 可视化 dataloader返回的batch samples。
# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show the images of one batch side by side with their landmarks overlaid.

    Args:
        sample_batched: Mapping with an ``'image'`` entry and a
            ``'landmarks'`` entry. The indexing below assumes the images are
            a 4-D tensor laid out as (batch, channel, height, width) and the
            landmarks are (batch, point, 2) with x in column 0 and y in
            column 1 -- TODO confirm against the dataset feeding the loader.

    Returns:
        None. The grid and the landmark points are drawn on the current
        matplotlib figure via ``plt``.
    """
    images_batch = sample_batched['image']
    landmarks_batch = sample_batched['landmarks']
    batch_size = len(images_batch)
    # Images are tiled horizontally, so the per-image x offset must be the
    # image *width* (dim 3). Using the height (dim 2) is only correct for
    # square images.
    im_width = images_batch.size(3)
    grid_border_size = 2

    # Pin the grid layout instead of relying on make_grid's defaults
    # (nrow=8, padding=2 per the torchvision docs; `utils` is presumably
    # torchvision.utils -- confirm against the file's imports). nrow keeps
    # every image on a single row even for batches larger than 8, and padding
    # is tied to the same border value used in the offsets below.
    grid = utils.make_grid(images_batch, nrow=batch_size,
                           padding=grid_border_size)
    # imshow expects channels last, hence the (C, H, W) -> (H, W, C) transpose.
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        # Image i starts after i full images plus (i + 1) border strips.
        x_offset = i * im_width + (i + 1) * grid_border_size
        plt.scatter(landmarks_batch[i, :, 0].numpy() + x_offset,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

    # The title only needs to be set once, not on every iteration.
    plt.title('Batch from dataloader')