课程学习笔记,课程链接
目录
一、torchvision
二、CIFAR数据集
下载数据集
数据集的使用
transforms的使用
目的:
如何把数据集和 Transforms 结合在一起;
介绍科研中使用的一些标准数据集和下载、查看、使用方法
PyTorch官网
点进 torchvision
torchvision 文档列出了很多科研或者毕设常用的一些数据集,如入门数据集 MNIST,用于手写文字。这些数据集位于 torchvision.datasets 模块,可以通过该模块对数据集进行下载,转换等操作。
torchvision 还有 io模块,但不常用;torchvision.models 会提供一些训练好的神经网络模型,在之后会用到;torchvision.transforms 之前已经学习过了,主要提供一些数据处理的工具。
接下来主要讲解如何联合使用 torchvision.datasets 和 torchvision.transforms
用到的数据集是 CIFIAR,点击官网文档进行查看。
数据集官网,该数据集包含了 60000 张 32*32 像素的 10 个类别的彩色图片,每个种类 6000 张图片,其中 60000 张中 50000 张是训练图片,10000 张是测试图片。
下载路径,非代码下载
import torchvision
'''
若有 ssl 报错,添加如下代码
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
'''
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
然后运行,torchvision 就会自动进行 CIFAR10 数据集的下载。
这里分别下载训练集和测试集,下载好后会放到所设置的路径下,这里下载的数据集会被放带当前目录的 dataset目录下。
查看下测试数据集中每个数据包含什么:
import torchvision
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
输出:
即一个数据单元里包含输入图片和对应的 tag,这里用数字进行映射,数字 3 也就是表示 cat ,可用 img.show() 查看下图片。因为这个数据集比较小,只有百 MB,图片像素只有 32*32,所以模糊,这里是将其分类为猫。
由于数据集的图片类型是 PIL Image,torch 无法直接使用,所以要先转为 tensor,通过 transforms 实现。这里用 transforms 将图片转为 tensor 类型后,并用 tensorboard 进行查看。
import torchvision
import ssl
from torch.utils.tensorboard import SummaryWriter
ssl._create_default_https_context = ssl._create_unverified_context
# 目的将设置成 ToTensor 的 transforms 应用到数据集的每张图片,都转成 tensor 数据类型
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
print(test_set[0])
writer = SummaryWriter("logs")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
结果,这里查看了索引为 [0]~[9] 的十张图片