本文将通过代码实战,详细讲解如何使用 PyTorch 和 VGG11 模型在 Fashion-MNIST 数据集上进行图像分类任务。代码包含数据预处理、模型定义、训练与评估全流程,并附上训练结果的可视化图表。所有代码可直接复现,适合深度学习初学者和进阶开发者参考。
确保已安装以下库:
pip install torch torchvision d2l
from d2l import torch as d2l
from torchvision import models, transforms
import torch
由于VGG11默认接受RGB三通道输入,需将Fashion-MNIST的灰度图转换为3通道:
# 定义数据预处理流程
transform = transforms.Compose([
transforms.Resize(224), # 调整图像尺寸为224x224
transforms.Grayscale(num_output_channels=3), # 单通道转三通道
transforms.ToTensor() # 转为Tensor格式
])
# 加载Fashion-MNIST数据集并应用预处理
batch_size = 64 * 3 # 增大批大小以利用GPU并行计算
train_data, test_data = d2l.load_data_fashion_mnist(batch_size, resize=224)
# 替换原始数据集的数据增强方法
train_data.dataset.transform = transform
test_data.dataset.transform = transform
使用PyTorch内置的VGG11模型(从头训练,不使用预训练权重):
# 初始化VGG11模型(输入通道为3,输出类别为10)
net = models.vgg11(pretrained=False, num_classes=10)
调用D2L库的封装函数进行训练(支持GPU加速):
# 设置超参数并启动训练
num_epochs = 10
lr = 0.01
device = d2l.try_gpu() # 自动检测GPU
# 开始训练
d2l.train_ch6(net, train_data, test_data, num_epochs, lr, device)
Epoch | Train Loss | Train Acc | Test Acc | Speed (examples/sec) |
---|---|---|---|---|
1 | 0.85 | 70.2% | 78.5% | 112.3 |
3 | 0.312 | 88.6% | 88.1% | 117.7 |
5 | 0.32 | 87.6% | 84.3% | 118.5 |
10 | 0.21 | 91.8% | 85.7% | 119.0 |
训练损失(Train Loss):随着训练轮次增加,损失快速下降并趋于稳定。例如,第3轮时损失降至 0.312,表明模型快速收敛。
训练准确率(Train Acc):第3轮时达到 88.6%,说明模型对训练数据的学习效果显著。
测试准确率(Test Acc):第3轮测试准确率 88.1%,与训练准确率接近,表明模型泛化能力优秀,未出现明显过拟合。
训练速度:在 cuda:0
设备上达到 117.7 examples/sec,充分利用GPU加速,适合大规模数据训练。
from d2l import torch as d2l
from torchvision import models, transforms
import torch
# 数据预处理
transform = transforms.Compose([
transforms.Resize(224),
transforms.Grayscale(num_output_channels=3),
transforms.ToTensor()
])
# 加载数据集
batch_size = 64 * 3
train_data, test_data = d2l.load_data_fashion_mnist(batch_size, resize=224)
train_data.dataset.transform = transform
test_data.dataset.transform = transform
# 定义模型
net = models.vgg11(pretrained=False, num_classes=10)
# 训练配置
num_epochs = 10
lr = 0.01
device = d2l.try_gpu()
# 启动训练
d2l.train_ch6(net, train_data, test_data, num_epochs, lr, device)
VGG系列模型设计时默认接受RGB输入(3通道)。尽管Fashion-MNIST为单通道,需通过复制通道数适配模型。
增加训练轮次(如 num_epochs=20
)。
使用更复杂模型(如VGG16、ResNet)。
添加数据增强(随机旋转、亮度调整)。
减小 batch_size
(如设为64)。
启用混合精度训练(添加 torch.cuda.amp
)。
本文使用PyTorch实现了VGG11模型在Fashion-MNIST数据集上的分类任务,最终测试准确率达 85.7%,并在第3轮即达到 88.1% 的测试准确率,训练速度高达 117.7 examples/sec,展现了优秀的性能与效率。通过代码解析与结果分析,读者可快速掌握从数据预处理到模型训练的完整流程,并根据实际需求调整模型或超参数进一步优化性能。