python学习——PyTorch框架的学习(1)

PyTorch框架的学习

简要声明

基于小土堆课程视频
BiliBili课程资源


文章目录

  • PyTorch框架的学习
    • 简要声明
  • 一、概要
  • 二、自定义`Dataset`实现数据加载
    • 1. `MyDataset`类的定义
      • 数据集初始化
      • 获取单个数据样本
      • 获取数据集长度
    • 2. `MyDataset`的实例化与合并
  • 三、使用TensorBoard进行可视化
    • 1. 引入必要的库
    • 2. 初始化`SummaryWriter`
    • 3. 图像数据的处理与可视化
    • 4. 标量数据的记录与可视化
    • 5. 关闭`SummaryWriter`
    • 6. 如何查看TensorBoard
  • 四、总结


一、概要

该篇笔记提供了一个PyTorch入门的数据加载和可视化指南。首先详细讲解了如何通过自定义torch.utils.data.Dataset类来灵活加载图片数据,包括初始化数据集、按索引获取单个样本以及确定数据集大小。其次则介绍了如何使用TensorBoard来可视化深度学习过程,涵盖了SummaryWriter的初始化、图像(HWC格式)的记录以及标量数据的曲线展示,帮助你直观地监控训练进程.

二、自定义Dataset实现数据加载

在PyTorch中,torch.utils.data.Dataset是一个抽象类,用于表示数据集。你需要继承这个类并实现其核心方法,才能让PyTorch知道如何访问你的数据。这份笔记将讲解如何通过自定义Dataset来加载图片数据。

1. MyDataset类的定义

代码定义了一个名为MyDataset的类,它继承自PyTorch的Dataset。这个类负责处理数据集的初始化、单个数据样本的获取以及数据集的长度计算。

import os
from torch.utils.data import Dataset
from PIL import Image

class MyDataset(Dataset):
    # 类定义开始

数据集初始化

这个方法在创建MyDataset实例时被调用。它的主要作用是设置数据集的根目录和当前类别的子目录,并获取所有图片的文件名列表。

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # 数据集的根目录
        self.label_dir = label_dir  # 当前数据集的标签(即子目录名)
        # 拼接出完整的数据路径
        self.path = os.path.join(self.root_dir, self.label_dir)
        # 获取该路径下所有图片的文件名列表
        self.image_path = os.listdir(self.path)
  • root_dir: 指定所有数据所在的顶级文件夹。
  • label_dir: 指定当前类别图片所在的子文件夹名。
  • self.path: 通过os.path.joinroot_dirlabel_dir组合成一个完整的路径,指向特定类别的图片存储位置。
  • self.image_path: 使用os.listdir列出self.path下所有文件和文件夹的名字。在这里,它被用于存储所有图片的文件名。

获取单个数据样本

这个方法是Dataset的核心。当你需要从数据集中获取第index个数据时,PyTorch会自动调用此方法。

    def __getitem__(self, index):
        img_name = self.image_path[index]  # 根据索引获取图片文件名
        img_item_path = os.path.join(self.path, img_name) # 拼接出图片文件的完整路径
        img = Image.open(img_item_path)  # 使用PIL库打开图片
        label = self.label_dir  # 将当前图片的标签设置为其所在的目录名
        return img, label  # 返回加载的图片(PIL Image对象)和其对应的标签(字符串)
  • index: PyTorch提供的一个整数,表示要获取的数据样本的索引。
  • img_name: 根据indexself.image_path列表中获取具体的图片文件名。
  • img_item_path: 组合self.pathimg_name,得到图片的完整文件路径。
  • Image.open(img_item_path): 使用PIL(Python Imaging Library)库的Image.open()函数打开指定路径的图片文件,并将其作为PIL Image对象返回。
  • label = self.label_dir: 直接将当前图片所属的文件夹名称作为其标签。

获取数据集长度

这个方法返回数据集中样本的总数量。

    def __len__(self):
        return len(self.image_path)  # 返回图片文件列表的长度,即数据集中图片的总数
  • len(self.image_path): 返回self.image_path列表中元素的数量,这代表了当前类别图片的总数。

2. MyDataset的实例化与合并

在定义了MyDataset类之后,可以创建它的实例来代表具体的数据集。

root_dir = "Dataset/hymenoptera_data/train"
ants_label_dir = "ants_images"
bees_label_dir = "bees_images"

# 实例化第一个数据集:加载蚂蚁图片
ants_dataset = MyDataset(root_dir, ants_label_dir)

# 实例化第二个数据集:加载蜜蜂图片
bees_dataset = MyDataset(root_dir, bees_label_dir)

# 将两个数据集合并为一个训练数据集
train_dataset = ants_dataset + bees_dataset
  • root_dir: 定义了数据集的根路径。
  • ants_label_dirbees_label_dir: 定义了蚂蚁和蜜蜂图片各自的子目录名。
  • ants_dataset = MyDataset(root_dir, ants_label_dir): 创建了一个MyDataset对象,专门用于加载ants_images目录下的图片。
  • bees_dataset = MyDataset(root_dir, bees_label_dir): 创建了另一个MyDataset对象,用于加载bees_images目录下的图片。
  • train_dataset = ants_dataset + bees_dataset: PyTorch的Dataset类支持通过+运算符进行连接。这行代码将ants_datasetbees_dataset中的所有样本合并到一个新的train_dataset对象中。现在,train_dataset可以被视为一个包含所有蚂蚁和蜜蜂图片及其对应标签的单一数据集。

三、使用TensorBoard进行可视化

在深度学习项目开发中,可视化工具是不可或缺的,它可以帮助我们更好地理解模型训练过程、调试模型、观察数据分布等。PyTorch通常结合TensorBoard来实现这些功能。TensorBoard是Google开发的神经网络训练可视化工具,它可以记录并可视化训练过程中的各种指标、图像、模型结构等。

1. 引入必要的库

首先,我们需要导入使用TensorBoard和图像处理所需的库:

from  torch.utils.tensorboard import SummaryWriter # TensorBoard的主要接口
from  PIL import Image # 用于处理图像,例如打开图片文件
import numpy as np # 用于将PIL图像转换为NumPy数组
  • SummaryWriter: 这是TensorBoard在PyTorch中的核心接口,通过它可以创建事件文件,并将各种数据写入其中,供TensorBoard读取和展示。
  • PIL.Image: Pillow库的Image模块,用于方便地打开、操作和保存各种图像文件。
  • numpy: Python中用于科学计算的基础库,我们通常会使用它来处理图像数据,例如将PIL Image对象转换为数组格式。

2. 初始化SummaryWriter

在使用TensorBoard记录数据之前,你需要创建一个SummaryWriter实例。

writer = SummaryWriter("logs")
  • SummaryWriter("logs"): 这会在当前工作目录下创建一个名为logs的文件夹(如果不存在)。所有通过writer对象记录的数据都会保存在这个文件夹内,形成TensorBoard可以读取的事件文件。你可以自定义这个文件夹的名称。

3. 图像数据的处理与可视化

TensorBoard可以方便地可视化图像。通常,我们需要将图像从PIL Image对象转换为NumPy数组或PyTorch Tensor。

image_path = "Dataset/data/train/bees_image/39747887_42df2855ee.jpg" # 指定图片文件路径
img_PIL = Image.open(image_path) # 使用PIL打开图片,得到PIL Image对象
img_array = np.array(img_PIL) # 将PIL Image对象转换为NumPy数组

print(type(img_array)) # 打印转换后数组的类型
print(img_array.shape) # 打印数组的形状(高度,宽度,通道数)

writer.add_image("test", img_array, 4,dataformats="HWC")
  • image_path: 定义了你想要加载的图片文件的路径。
  • Image.open(image_path): 读取指定路径的图片,返回一个PIL Image对象。
  • np.array(img_PIL): 将PIL Image对象转换为NumPy数组。对于彩色图片,其形状通常是 (height, width, channels),这被称为HWC(Height, Width, Channel)格式。
  • writer.add_image("test", img_array, 4, dataformats="HWC"):
    • 第一个参数 "test" 是这个图像在TensorBoard中显示的标签名称
    • 第二个参数 img_array 是要添加的图像数据,必须是NumPy数组或PyTorch Tensor。
    • 第三个参数 4 是一个可选的**global_step**(全局步数)。在训练过程中,这通常是当前的训练迭代次数或epoch数,用于区分不同时间点记录的数据。在这里,你简单地将其设置为4
    • dataformats="HWC": 非常重要! 这个参数告诉TensorBoard你的图像数据是以“高度(Height)、宽度(Width)、通道(Channel)”的顺序排列的。TensorBoard默认期望的图像格式是“通道、高度、宽度”(CHW)。如果你的数据是HWC,必须明确指定,否则图像可能会显示不正确。

4. 标量数据的记录与可视化

TensorBoard最常用的功能之一是记录和可视化标量(Scalar)数据,例如损失(loss)、准确率(accuracy)等指标随训练步数的变化。

for i in range(100):
    writer.add_scalar("y=3x",3*i,i)
  • for i in range(100): 这里模拟了100个训练步骤。
  • writer.add_scalar("y=3x", 3*i, i):
    • 第一个参数 "y=3x" 是这个标量值在TensorBoard中显示的标签名称
    • 第二个参数 3*i 是你要记录的
    • 第三个参数 i 是当前的**global_step**。TensorBoard会根据这个步数来绘制曲线图,显示值随步数的变化趋势。

5. 关闭SummaryWriter

完成所有数据记录后,务必关闭SummaryWriter

writer.close()

6. 如何查看TensorBoard

在运行了上述Python脚本后,你可以在命令行中导航到包含logs文件夹的目录(或者你自定义的日志文件夹),然后运行以下命令来启动TensorBoard:

tensorboard --logdir=logs

或者,如果你的日志文件夹就在当前目录:

tensorboard --logdir=.

然后,TensorBoard会给出一个本地地址(通常是http://localhost:6006/),你可以在浏览器中打开这个地址,然后切换到"Images"和"Scalars"等选项卡,就可以看到你记录的数据可视化结果了。


四、总结

本次练习为pytorch的系统学习开了头,初步了解了torch相关的应用,下一周将继续学习。

你可能感兴趣的:(python学习,python,学习,pytorch)