PyTorch 是一个开源的深度学习框架,由 Facebook AI 开发,它的特点是:
torchvision 是 PyTorch 的官方图像处理工具包,它包含:
它们组合起来特别适合做:
ImageFolder
读取图片transforms
做预处理mobilenet_v3
、resnet18
等预训练模型做分类
这张图展示了使用 PyTorch + torchvision 进行图像分类的标准工作流程,简单直观,适合入门理解整个训练流程:
Dataset(数据集)
原始图片数据,分好类放在文件夹里(通常是 train/0
, train/1
这样)
ImageFolder Dataset
使用 torchvision.datasets.ImageFolder
加载数据,它会自动根据子文件夹名来分配类别标签。
⚙️ transforms(图像变换)
使用 transforms
对图像进行预处理,如裁剪、缩放、标准化等,以适配模型输入要求。
Train / Classify(训练 / 分类)
构建分类模型(通常是 CNN),训练并评估分类性能。
Pre-trained Model(预训练模型)
加载如 resnet18
、efficientnet
等模型,并可选择是否使用其预训练权重(如在 ImageNet 上预训练过的)。
DataLoader(数据加载器)
使用 torch.utils.data.DataLoader
来高效地批量加载训练数据,支持多线程、shuffle 等功能。
核心理念:
ImageFolder
管理数据结构transform
对图像进行标准预处理DataLoader
管理训练加载效率预训练模型
加快收敛和提升精度