基于小土堆课程视频
BiliBili课程资源
该篇笔记提供了一个PyTorch入门的数据加载和可视化指南。首先详细讲解了如何通过自定义torch.utils.data.Dataset
类来灵活加载图片数据,包括初始化数据集、按索引获取单个样本以及确定数据集大小。其次则介绍了如何使用TensorBoard来可视化深度学习过程,涵盖了SummaryWriter
的初始化、图像(HWC格式)的记录以及标量数据的曲线展示,帮助你直观地监控训练进程.
Dataset
实现数据加载在PyTorch中,torch.utils.data.Dataset
是一个抽象类,用于表示数据集。你需要继承这个类并实现其核心方法,才能让PyTorch知道如何访问你的数据。这份笔记将讲解如何通过自定义Dataset
来加载图片数据。
MyDataset
类的定义代码定义了一个名为MyDataset
的类,它继承自PyTorch的Dataset
。这个类负责处理数据集的初始化、单个数据样本的获取以及数据集的长度计算。
import os
from torch.utils.data import Dataset
from PIL import Image
class MyDataset(Dataset):
# 类定义开始
这个方法在创建MyDataset
实例时被调用。它的主要作用是设置数据集的根目录和当前类别的子目录,并获取所有图片的文件名列表。
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 数据集的根目录
self.label_dir = label_dir # 当前数据集的标签(即子目录名)
# 拼接出完整的数据路径
self.path = os.path.join(self.root_dir, self.label_dir)
# 获取该路径下所有图片的文件名列表
self.image_path = os.listdir(self.path)
root_dir
: 指定所有数据所在的顶级文件夹。label_dir
: 指定当前类别图片所在的子文件夹名。self.path
: 通过os.path.join
将root_dir
和label_dir
组合成一个完整的路径,指向特定类别的图片存储位置。self.image_path
: 使用os.listdir
列出self.path
下所有文件和文件夹的名字。在这里,它被用于存储所有图片的文件名。这个方法是Dataset
的核心。当你需要从数据集中获取第index
个数据时,PyTorch会自动调用此方法。
def __getitem__(self, index):
img_name = self.image_path[index] # 根据索引获取图片文件名
img_item_path = os.path.join(self.path, img_name) # 拼接出图片文件的完整路径
img = Image.open(img_item_path) # 使用PIL库打开图片
label = self.label_dir # 将当前图片的标签设置为其所在的目录名
return img, label # 返回加载的图片(PIL Image对象)和其对应的标签(字符串)
index
: PyTorch提供的一个整数,表示要获取的数据样本的索引。img_name
: 根据index
从self.image_path
列表中获取具体的图片文件名。img_item_path
: 组合self.path
和img_name
,得到图片的完整文件路径。Image.open(img_item_path)
: 使用PIL(Python Imaging Library)库的Image.open()
函数打开指定路径的图片文件,并将其作为PIL Image对象返回。label = self.label_dir
: 直接将当前图片所属的文件夹名称作为其标签。这个方法返回数据集中样本的总数量。
def __len__(self):
return len(self.image_path) # 返回图片文件列表的长度,即数据集中图片的总数
len(self.image_path)
: 返回self.image_path
列表中元素的数量,这代表了当前类别图片的总数。MyDataset
的实例化与合并在定义了MyDataset
类之后,可以创建它的实例来代表具体的数据集。
root_dir = "Dataset/hymenoptera_data/train"
ants_label_dir = "ants_images"
bees_label_dir = "bees_images"
# 实例化第一个数据集:加载蚂蚁图片
ants_dataset = MyDataset(root_dir, ants_label_dir)
# 实例化第二个数据集:加载蜜蜂图片
bees_dataset = MyDataset(root_dir, bees_label_dir)
# 将两个数据集合并为一个训练数据集
train_dataset = ants_dataset + bees_dataset
root_dir
: 定义了数据集的根路径。ants_label_dir
和 bees_label_dir
: 定义了蚂蚁和蜜蜂图片各自的子目录名。ants_dataset = MyDataset(root_dir, ants_label_dir)
: 创建了一个MyDataset
对象,专门用于加载ants_images
目录下的图片。bees_dataset = MyDataset(root_dir, bees_label_dir)
: 创建了另一个MyDataset
对象,用于加载bees_images
目录下的图片。train_dataset = ants_dataset + bees_dataset
: PyTorch的Dataset
类支持通过+
运算符进行连接。这行代码将ants_dataset
和bees_dataset
中的所有样本合并到一个新的train_dataset
对象中。现在,train_dataset
可以被视为一个包含所有蚂蚁和蜜蜂图片及其对应标签的单一数据集。在深度学习项目开发中,可视化工具是不可或缺的,它可以帮助我们更好地理解模型训练过程、调试模型、观察数据分布等。PyTorch通常结合TensorBoard来实现这些功能。TensorBoard是Google开发的神经网络训练可视化工具,它可以记录并可视化训练过程中的各种指标、图像、模型结构等。
首先,我们需要导入使用TensorBoard和图像处理所需的库:
from torch.utils.tensorboard import SummaryWriter # TensorBoard的主要接口
from PIL import Image # 用于处理图像,例如打开图片文件
import numpy as np # 用于将PIL图像转换为NumPy数组
SummaryWriter
: 这是TensorBoard在PyTorch中的核心接口,通过它可以创建事件文件,并将各种数据写入其中,供TensorBoard读取和展示。PIL.Image
: Pillow
库的Image
模块,用于方便地打开、操作和保存各种图像文件。numpy
: Python中用于科学计算的基础库,我们通常会使用它来处理图像数据,例如将PIL Image对象转换为数组格式。SummaryWriter
在使用TensorBoard记录数据之前,你需要创建一个SummaryWriter
实例。
writer = SummaryWriter("logs")
SummaryWriter("logs")
: 这会在当前工作目录下创建一个名为logs
的文件夹(如果不存在)。所有通过writer
对象记录的数据都会保存在这个文件夹内,形成TensorBoard可以读取的事件文件。你可以自定义这个文件夹的名称。TensorBoard可以方便地可视化图像。通常,我们需要将图像从PIL Image对象转换为NumPy数组或PyTorch Tensor。
image_path = "Dataset/data/train/bees_image/39747887_42df2855ee.jpg" # 指定图片文件路径
img_PIL = Image.open(image_path) # 使用PIL打开图片,得到PIL Image对象
img_array = np.array(img_PIL) # 将PIL Image对象转换为NumPy数组
print(type(img_array)) # 打印转换后数组的类型
print(img_array.shape) # 打印数组的形状(高度,宽度,通道数)
writer.add_image("test", img_array, 4,dataformats="HWC")
image_path
: 定义了你想要加载的图片文件的路径。Image.open(image_path)
: 读取指定路径的图片,返回一个PIL Image对象。np.array(img_PIL)
: 将PIL Image对象转换为NumPy数组。对于彩色图片,其形状通常是 (height, width, channels)
,这被称为HWC(Height, Width, Channel)格式。writer.add_image("test", img_array, 4, dataformats="HWC")
:
"test"
是这个图像在TensorBoard中显示的标签名称。img_array
是要添加的图像数据,必须是NumPy数组或PyTorch Tensor。4
是一个可选的**global_step
**(全局步数)。在训练过程中,这通常是当前的训练迭代次数或epoch数,用于区分不同时间点记录的数据。在这里,你简单地将其设置为4
。dataformats="HWC"
: 非常重要! 这个参数告诉TensorBoard你的图像数据是以“高度(Height)、宽度(Width)、通道(Channel)”的顺序排列的。TensorBoard默认期望的图像格式是“通道、高度、宽度”(CHW)。如果你的数据是HWC,必须明确指定,否则图像可能会显示不正确。TensorBoard最常用的功能之一是记录和可视化标量(Scalar)数据,例如损失(loss)、准确率(accuracy)等指标随训练步数的变化。
for i in range(100):
writer.add_scalar("y=3x",3*i,i)
for i in range(100)
: 这里模拟了100个训练步骤。writer.add_scalar("y=3x", 3*i, i)
:
"y=3x"
是这个标量值在TensorBoard中显示的标签名称。3*i
是你要记录的值。i
是当前的**global_step
**。TensorBoard会根据这个步数来绘制曲线图,显示值随步数的变化趋势。SummaryWriter
完成所有数据记录后,务必关闭SummaryWriter
。
writer.close()
在运行了上述Python脚本后,你可以在命令行中导航到包含logs
文件夹的目录(或者你自定义的日志文件夹),然后运行以下命令来启动TensorBoard:
tensorboard --logdir=logs
或者,如果你的日志文件夹就在当前目录:
tensorboard --logdir=.
然后,TensorBoard会给出一个本地地址(通常是http://localhost:6006/
),你可以在浏览器中打开这个地址,然后切换到"Images"和"Scalars"等选项卡,就可以看到你记录的数据可视化结果了。
本次练习为pytorch的系统学习开了头,初步了解了torch相关的应用,下一周将继续学习。