Dataset和DataLoader

目录

  • 加载PyTorch数据集
  • 创建自定义数据集
  • 使用DataLoader处理数据集

用于处理数据样本的代码可能会变得混乱且难以维护。理想情况下,我们希望我们的数据集代码与模型训练代码分离,以获得更好的可读性和模块化。PyTorch 提供了两个数据基元: torch.utils.data.DataLoadertorch.utils.data.Dataset,它允许您使用预加载的数据集或者自己的数据。 Dataset存储样本及其相应的标签, DataLoader使用Dataset创建可迭代对象,快速访问数据。

PyTorch 自己的库提供了许多预加载的数据集(Dataset),都是torch.utils.data.Dataset的子类。还提供了一些针对特定数据集的方法函数。这些数据集可用于模型设计和基准测试。您可以在此处找到它们:图像数据集,文本数据集和音频数据集。

加载PyTorch数据集

下面以导入TorchVision的Fashion-MNIST数据集为例。
Fashion-MNIST是Zalando文章中的图像数据集,由60000个训练示例和10000个测试示例组成。每个示例包括一个 28×28 灰度图像和一个标签(共有十个类别)。

我们使用以下参数加载数据集:

  • root:是存储训练/测试数据的路径
  • train:True代表加载训练数据,False加载测试数据。
  • download:True从互联网上下载数据到root,如果root有数据就不会下载。
  • transform:输入特征的变换(默认None)
  • target_transform:标签的变换(默认None)
from torchvision import datasets
from torchvision.transforms import ToTensor

# 导入训练数据
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

# 导入测试数据
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Dataset类对象就像list一样可以通过索引获取对应位置数据。

# 获取第一个样本数据
img, label = training_data[0]

创建自定义数据集

一个自定义的数据集类需要重载Dataset类,而且必须从写以下三个函数:_init_ 、_len_ 和 _getitem_。
__init__在Dataset对象初始化的时候调用。一般需要初始化样本和标签的存储路径,根据不同的数据结构有不同的初始化方式。
__len__返回数据数量。
__getitem__返回一个样本从数据集通过传入的索引idx

使用DataLoader处理数据集

Dataset每次只能提取一个样本。但在训练模型的时候,常常需要minibatch,每次传入多个数据;每一个epoch结束后需要shuffle(打乱)数据,来防止过拟合。所以,DataLoader应运而生,可以轻松实现以上功能。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

注意:DataLoader是一个可迭代对象,可以使用next访问或者循环,不能使用索引访问数据。

# 使用next访问
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

# 使用for循环
for data in train_dataloader:
	imgs,labels = data

你可能感兴趣的:(PyTorch快速入门,深度学习,pytorch,python)