PyTorch深度学习入门笔记(六)torchvision 中的数据集使用

课程学习笔记,课程链接
学习笔记同步发布在我的个人网站上,欢迎来访查看。

文章目录

  • 一、torchvision
  • 二、CIFAR数据集
    • 2.1 下载数据集
    • 2.2 数据集的使用
    • 2.3 transforms的使用
    • 2.3 其他数据集的使用

目的:

  1. 如何把数据集和 Transforms结合在一起
  2. 介绍科研中使用的一些标准数据集和下载、查看、使用方法

一、torchvision

pytorch官网:https://pytorch.org/
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第1张图片
点进 torchvision:
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第2张图片
torchvision 文档列出了很多科研或者毕设常用的一些数据集,如入门数据集MNIST,用于手写文字。这些数据集位于 torchvision.datasets模块,可以通过该模块对数据集进行下载,转换等操作。
torchvision还有 io模块,但不常用
torchvision.models会提供一些训练好的神经网络模型,在之后会用到。
torchvision.transforms之前已经学习过了,主要提供一些数据处理的工具。
接下来主要讲解如何联合使用 torchvision.datasetstorchvision.transforms

二、CIFAR数据集

用到的数据集是 CIFIAR,点击官网文档进行查看:
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第3张图片
具有一些参数,如root,设置数据集所在目录路径等。
数据集官网:https://www.cs.toronto.edu/~kriz/cifar.html
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第4张图片
这个数据集包含了 60000张32*32像素的10个类别的彩色图片,每个种类6000张图片,其中60000张中50000张是训练图片,10000张是测试图片。

2.1 下载数据集

代码:

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set  = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

然后运行,torchvision 就会自动进行 CIFAR10 数据集的下载:
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第5张图片

这里分别下载训练集和测试集,下载好后会放到所设置的路径下,这里下载的数据集会被放带当前目录的 dataset目录下。

2.2 数据集的使用

查看下测试数据集中每个数据包含什么:

print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)

输出:
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第6张图片
即一个数据单元里包含输入图片和对应的 tag,这里用数字进行映射,数字 3 也就是表示 cat ,用 img.show() 查看下图片:
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第7张图片
因为这个数据集比较小,只有百MB,图片像素只有32*32,所以挺模糊的,但这里是将其分类为猫。

2.3 transforms的使用

由于数据集的图片类型是PIL Image,torch无法直接使用,所以要先转为tensor,通过 transforms 实现。
这里用transforms将图片转为 tensor类型后,用tensorboard进行查看:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])


train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()

结果:
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第8张图片
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第9张图片

这里查看了索引为[0]~[9]的十张图片

2.3 其他数据集的使用

使用其他数据集的方法也很简单,设置好所需参数就可以了:
PyTorch深度学习入门笔记(六)torchvision 中的数据集使用_第10张图片
以COCO为例,一样的,设置存储路径,json文件存储路径,transform,target_transform,transforms 就可以了。
注:如果下载速度比较慢,可以通过将下载路径复制到迅雷中进行数据集下载,下载好后设置好存放目录的路径就行。

你可能感兴趣的:(深度学习,pytorch,深度学习,python)