PyTorch中DataLoader与Dataset的使用、关系&原理

DataLoader(torch.utils.data.DataLoader)

功能:构建可迭代的数据装载器

类中的几个主要变量定义功能介绍如下,除此之外还有11个参数
dataset:Dataset类,决定数据从哪里读取以及如何读取
batchsize:批处理的大小
num_works:是否多进程读取数据
shuffle:每个epoch是否乱序 (true or false)
drop_last:当样本数不能被batchsize整除时,是否丢弃最后一批数据。(true or false)

Dataset(torch.utils.data.Dataset)

功能:Dataset是一个抽象类,所有自定义的Dataset都需要继承这个父类,并且重载其中的__getitem__()函数。
getitem:输入是索引,输出是样本(加标签),也就是定义了索引到样本的映射规则。

代码示例

以图像的分类任务为例(代码中具体是人民币图像识别1元与100元,进行面值二分类任务),重点观察dataset与dataloader如何完成数据读取的任务。
首先主函数中有如下定义

train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True

你可能感兴趣的:(pytorch学习笔记,python,机器学习,深度学习)