Pytorch如何自定义dataloader的返回格式

例子

例如:自定义的dataset返回的单个样本格式为:(string, tensor)
直接用dataloader(dataset)得到的loader是不能够自动把上述格式转换为batch的。

解决方法:

需要自定义一个collate_function用于返回batch。

def collate_function(data):
	"""
	:data: a list for a batch of samples. [[string, tensor], ..., [string, tensor]]
	"""
    transposed_data = list(zip(*data))
    directorys, imgs = transposed_data[0], transposed_data[1]
    imgs = torch.stack(imgs, 0)
    return (directorys, imgs)

dataloader = torch.utils.data.DataLoader(Dataset(transforms=data_transforms, train=False),
                                         batch_size=2, collate_fn=collate_function, shuffle=True, num_workers=1, pin_memory=True)

参考:

Dataloader的官网源码
PytorchDiscuss

你可能感兴趣的:(学术)