PyTorch实战:从零构建CNN模型,轻松搞定MNIST手写数字识别

PyTorch实战:从零构建CNN模型,轻松搞定MNIST手写数字识别

大家好!欢迎来到我的深度学习博客!

对于每个踏入计算机视觉领域的人来说,MNIST手写数字识别就像是编程世界的“Hello, World!”。它足够简单,能够让我们快速上手;也足够完整,可以帮我们走通一个深度学习项目的全流程。

之前我们可能用Keras体验过“搭积木”式的快乐,今天,我们将换一个同样强大且灵活的框架——PyTorch,来感受一下它更为“Pythonic”的魅力。本文将手把手带你完成以下任务:

  1. 准备环境并使用torchvision加载和预处理MNIST数据集。
  2. 定义一个经典的卷积神经网络(CNN)结构。
  3. 编写训练循环,从零开始掌控模型的每一步训练。
  4. 评估模型在测试集上的表现。
  5. 实际预测一张图片,看看我们亲手训练的AI“视力”如何。

准备好了吗?让我们开始吧!

步骤一:环境准备与数据加载

首先,确保你已经安装了PyTorch和相关的库。

pip install torch torchvision matplotlib numpy

PyTorch通过torchvision库为我们提供了许多经典的数据集,包括MNIST。加载数据非常方便。同时,我们需要对数据进行一些预处理:

  • transforms.ToTensor(): 将PIL图像或NumPy数组转换为PyTorch的Tensor,并将像素值从[0, 255]范围缩放到[0, 1]范围。
  • transforms.Normalize((0.1307,), (0.3081,)): 用给定的均值(0.1307)和标准差(0.3081)对张量进行标准化。这两个值是MNIST数据集的全局均值和标准差,这样做可以帮助模型更好地收敛。
  • DataLoader: 这是PyTorch强大的数据加载工具,它可以帮我们实现数据的批处理(batching)、打乱(shuffling)和并行加载。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# --- 1. 数据加载与预处理 ---

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(), # 转换为Tensor,并归一化到[0, 1]
    transforms.Normalize((0.1307,), (0.3081,)) # 标准化
])

# 下载/加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建DataLoader
batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 可视化一些数据样本
examples = iter(train_loader)
example_data, example_targets = next(examples)

fig = plt.figure()
for i in range(6):
  plt.subplot(2, 3, i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[i]))
  plt.xticks([])
  plt.yticks([])
plt.show()

(代码运行后会显示6张手写数字图片及其真实标签,让你对数据有个直观认识)

步骤二:构建CNN模型

在PyTorch中,我们通过创建一个继承自nn.Module的类来定义自己的神经网络。

  • __init__(self): 在构造函数中,我们定义模型需要用到的所有“层”,例如卷积层(nn.Conv2d)、池化层(nn.MaxPool2d)和全连接层(nn.Linear)。
  • forward(self, x): 在这里,我们定义数据在网络中的“流动”路径,也就是前向传播的过程。这也是PyTorch与Keras最显著的区别之一,它让你能用任何Python代码来控制数据流,非常灵活。

我们将构建一个经典的LeNet-style结构:Conv -> Pool -> Conv -> Pool -> Flatten -> FC -> FC

# --- 2. 构建CNN模型 ---

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 第一个卷积块
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2) # 输出: (16, 14, 14)
        )
        # 第二个卷积块
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2) # 输出: (32, 7, 7)
        )
        # 全连接层
        self.fc1 = nn.Linear(32 * 7 * 7, 128) # 输入是展平后的特征图
        self.fc2 = nn.Linear(128, 10) # 10个类别

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # 展平操作
        x = x.view(x.size(0), -1) # x.size(0)是batch_size, -1是自动计算维度
        x = self.fc1(x)
        x = nn.ReLU()(x)
        x = self.fc2(x)
        return x

# 检查设备(使用GPU或CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

# 实例化模型并移动到指定设备
model = Net().to(device)
print(model)

print(model)会打印出我们定义的网络结构,非常直观。

步骤三:定义损失函数和优化器

  • 损失函数: 对于多分类问题,nn.CrossEntropyLoss是标准选择。注意:它内部已经包含了Softmax操作,所以我们的模型在最后一层不需要添加Softmax激活函数。
  • 优化器: Adam是一个鲁棒且高效的优化算法,是大多数任务的首选。我们需要将模型的参数model.parameters()传递给它,以便它知道要更新哪些权重。
# --- 3. 定义损失函数和优化器 ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

步骤四:训练模型

这是最激动人心的部分!与Keras的.fit()一行代码不同,在PyTorch中我们需要手动编写训练循环。这虽然代码稍多,但给了我们极大的控制权。

一个标准的训练循环包含以下步骤:

  1. 将模型设置为训练模式 (model.train())。
  2. 遍历DataLoader中的每个批次。
  3. 将数据和标签移动到GPU/CPU。
  4. 清零梯度 (optimizer.zero_grad())。
  5. 前向传播 (output = model(data))。
  6. 计算损失 (loss = criterion(...))。
  7. 反向传播 (loss.backward())。
  8. 更新权重 (optimizer.step())。

我们还会编写一个test函数,在每个epoch训练结束后,评估模型在测试集上的性能。

# --- 4. 训练模型 ---

def train(epoch):
    model.train() # 设置为训练模式
    for batch_idx, (data, target) in enumerate(train_loader):
        # 将数据移动到设备
        data, target = data.to(device), target.to(device)
        
        # 1. 清零梯度
        optimizer.zero_grad()
        
        # 2. 前向传播
        output = model(data)
        
        # 3. 计算损失
        loss = criterion(output, target)
        
        # 4. 反向传播
        loss.backward()
        
        # 5. 更新权重
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

def test():
    model.eval() # 设置为评估模式
    test_loss = 0
    correct = 0
    with torch.no_grad(): # 在此模式下,不计算梯度,节省内存和计算
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item() # 累加批次损失
            pred = output.argmax(dim=1, keepdim=True) # 获取最大概率的索引
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.2f}%)\n')


# 开始训练循环
epochs = 5
for epoch in range(1, epochs + 1):
    train(epoch)
    test()

看着每个epoch后准确率不断攀升,是不是很有成就感?只需要几个epoch,我们的模型就能在测试集上达到98%-99%的准确率!

步骤五:进行单张图片预测

模型训练好了,我们来实际用一下。从测试集中随机取一张图片,看看模型的判断是否准确。

# --- 5. 单张图片预测 ---

# 重新加载测试数据,这次不打乱,方便我们查看
test_loader_pred = DataLoader(test_dataset, batch_size=1, shuffle=False)
data_iter = iter(test_loader_pred)
images, labels = next(data_iter)

# 将图片和模型都移动到CPU上进行预测和可视化
images = images.to("cpu")
model.to("cpu")

# 进行预测
model.eval()
with torch.no_grad():
    output = model(images)
    prediction = output.argmax(dim=1, keepdim=True).item()

# 可视化结果
plt.imshow(images[0][0], cmap="gray")
plt.title(f"Prediction: {prediction}, Ground Truth: {labels.item()}")
plt.show()

print(f"模型预测这张图片是数字:{prediction}")
print(f"这张图片的真实标签是:{labels.item()}")
if prediction == labels.item():
    print("预测正确!干得漂亮!")
else:
    print("啊哦,预测错了。")

总结

恭喜你!你已经成功使用PyTorch从零开始构建、训练并评估了一个CNN模型。

通过这个过程,我们不仅学会了PyTorch的基本操作,更重要的是理解了它灵活而强大的设计哲学:

  • 动态计算图forward函数让模型结构像写普通Python代码一样自由。
  • 清晰的模块化:通过nn.Module定义网络,通过DataLoader处理数据,通过手动循环控制训练,每个部分各司其职。

希望这篇教程能为你打开PyTorch世界的大门。接下来,你可以尝试更复杂的网络结构、使用不同的数据集,或者深入研究学习率调整、正则化等高级技巧。

如果你觉得这篇文章对你有帮助,欢迎点赞、收藏和分享!我们下期再见!

你可能感兴趣的:(PyTorch实战:从零构建CNN模型,轻松搞定MNIST手写数字识别)