本文将使用PyTorch从零实现经典的LeNet-5模型,并在交通标志识别数据集上进行训练和部署。完整代码可直接运行。
环境准备(创建并激活名为 lenet_env 的虚拟环境):
# 创建虚拟环境
python -m venv lenet_env
# Activate the environment (Windows path; on Linux/macOS use: source lenet_env/bin/activate)
.\lenet_env\Scripts\activate
# Recommended: GPU build, cu118 (CUDA 11.8) wheels fetched from the Aliyun mirror via -f/--find-links.
# Install EITHER this GPU line OR the CPU line below, not both.
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 -f https://mirrors.aliyun.com/pytorch-wheels/cu118/
# Alternative: CPU-only build
pip install torch torchvision torchaudio
# Extra libraries: flask serves the model and onnx supports the ONNX export step;
# matplotlib/requests/onnxruntime appear unused by the code in this article.
pip install matplotlib numpy flask requests onnx onnxruntime
创建 check_env.py 文件并运行:
import torch
# Report the installed PyTorch build and whether it can see a CUDA GPU.
for label, value in (
    ("PyTorch版本:", torch.__version__),
    ("CUDA可用:", torch.cuda.is_available()),
    ("设备数量:", torch.cuda.device_count()),
):
    print(label, value)
预期输出示例:
PyTorch版本: 2.3.1+cu118
CUDA可用: True
设备数量: 1
C:/
└─your_project/
├─data/
│ ├─train/
│ │ └─GTSRB/Final_Training/Images/...
│ └─test/
│ └─GTSRB/Final_Test/Images/...
└─code/
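在命令行运行 nvidia-smi 可验证显卡是否被正确识别。
本文使用德国交通标志识别基准(GTSRB)数据集,训练数据存放于 C:/your_project/data/train(测试数据对应 data/test)。注意:下面的代码使用 ImageFolder 加载数据,它把根目录下的每个子文件夹视为一个类别,因此 train/ 与 test/ 下应直接是各类别文件夹(GTSRB 中为 00000~00042),而不是上面目录树中 GTSRB/Final_Training/Images 这样的多层结构;官方测试集图片没有按类别分文件夹,需要先根据其标注 CSV 整理成相同结构。数据加载代码如下: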
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Preprocessing shared by both splits: resize to LeNet-5's 32x32 input, convert
# to a float tensor in [0, 1], then map each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder treats every immediate subdirectory of root as one class and
# numbers the classes in sorted folder-name order.
train_set = datasets.ImageFolder(root='./data/train', transform=transform)
test_set = datasets.ImageFolder(root='./data/test', transform=transform)

# Fail fast on layouts that ImageFolder would silently mislabel: e.g. a
# data/train/GTSRB/Final_Training/Images/... tree yields one class, "GTSRB".
if len(train_set.classes) < 2:
    raise RuntimeError(
        f"Found {len(train_set.classes)} class folder(s) under ./data/train "
        f"({train_set.classes}); ImageFolder expects root/<class>/<image>, "
        "e.g. ./data/train/00000/xxx.ppm"
    )
# Labels are indices into the sorted folder list, so a test split with a
# different set of folders would give the same index a different meaning.
if test_set.classes != train_set.classes:
    raise RuntimeError(
        "Class folders differ between ./data/train and ./data/test; "
        "label indices would not line up"
    )

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

print(f"训练集大小: {len(train_set)}")
print(f"测试集大小: {len(test_set)}")
print(f"类别数量: {len(train_set.classes)}")
LeNet-5的PyTorch实现:
import torch.nn as nn
class LeNet5(nn.Module):
    """LeNet-5 style CNN for 3-channel (RGB) 32x32 images.

    Two conv -> ReLU -> max-pool stages (3 -> 6 -> 16 channels, 5x5 kernels)
    reduce a 32x32 image to a 16x5x5 feature map; three fully connected
    layers (400 -> 120 -> 84 -> num_classes) turn it into class logits.

    Args:
        num_classes: number of output logits (default 43).

    Note: the attribute names ``features``/``classifier`` and the layer order
    inside each ``nn.Sequential`` define the state_dict keys, so saved
    checkpoints only load while they stay unchanged.
    """

    def __init__(self, num_classes=43):
        super().__init__()
        conv_stages = [
            nn.Conv2d(3, 6, kernel_size=5),  # 3 input channels for RGB images
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*conv_stages)
        # Spatial size: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5
        flat_dim = 16 * 5 * 5
        self.classifier = nn.Sequential(
            nn.Linear(flat_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Map an (N, 3, 32, 32) batch to (N, num_classes) logits."""
        feature_map = self.features(x)
        return self.classifier(feature_map.flatten(start_dim=1))
# Size the output layer from the class folders ImageFolder discovered, so the
# logits line up with train_set.classes (the source of the training labels and
# of the class names returned by /predict) instead of assuming 43 folders.
model = LeNet5(num_classes=len(train_set.classes))
print(model)
训练配置与执行:
import torch.optim as optim

# Train on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Single source of truth for the epoch count (used by the loop and the log line).
num_epochs = 20


def _train_one_epoch(net, loader, loss_fn, opt, dev):
    """Run one optimisation pass over *loader*; return the mean per-batch loss."""
    net.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(dev), labels.to(dev)
        opt.zero_grad()
        loss = loss_fn(net(images), labels)
        loss.backward()
        opt.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(net, loader, dev):
    """Return the top-1 accuracy of *net* on *loader*, in percent."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(dev), labels.to(dev)
            # argmax over the class dimension; no need to go through .data
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


# Accuracy is measured on test_loader after every epoch and only reported.
for epoch in range(num_epochs):
    avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    acc = _evaluate(model, test_loader, device)
    print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f} | Acc: {acc:.2f}%")

# Persist only the learned parameters (state_dict), not the module object.
torch.save(model.state_dict(), "lenet5_traffic_sign.pth")
# Export in inference mode so every layer behaves as it does at prediction time.
model.eval()
# Tracing input: one 3x32x32 image, created on the same device as the model.
dummy_input = torch.randn(1, 3, 32, 32, device=device)
# dynamic_axes marks dim 0 as variable, so the ONNX graph accepts any batch
# size rather than only the batch of 1 used for tracing.
torch.onnx.export(model, dummy_input, "lenet5.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
from flask import Flask, request, jsonify
from PIL import Image
import numpy as np

app = Flask(__name__)

# map_location lets a checkpoint saved on the GPU load on a CPU-only server.
# weights_only=True restricts unpickling to tensors and plain containers
# (torch.load otherwise executes arbitrary pickle code from the file).
model.load_state_dict(
    torch.load("lenet5_traffic_sign.pth", map_location=device, weights_only=True)
)
model.eval()

# Must match the training-time preprocessing; built once, not per request.
_PREDICT_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def preprocess_image(image):
    """Apply training-time preprocessing to an RGB PIL image and add a batch dim."""
    return _PREDICT_TRANSFORM(image).unsqueeze(0)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded as multipart form field 'file'.

    Returns JSON {'class_id', 'class_name'} on success, or {'error': ...}
    with HTTP 400 when no file was sent or it cannot be decoded as an image.
    """
    file = request.files.get('file')
    if file is None:
        return jsonify({'error': 'No file uploaded'}), 400
    try:
        image = Image.open(file.stream).convert('RGB')
    except OSError:
        # Covers PIL's UnidentifiedImageError (an OSError subclass) for
        # unrecognised data and OSError raised on truncated image files.
        return jsonify({'error': 'Uploaded file is not a valid image'}), 400
    tensor = preprocess_image(image).to(device)
    with torch.no_grad():
        outputs = model(tensor)
    class_id = outputs.argmax(dim=1).item()
    return jsonify({'class_id': class_id,
                    'class_name': train_set.classes[class_id]})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
使用curl测试:
# Upload test_sign.jpg as the multipart form field "file" to the local /predict endpoint.
curl -X POST -F "file=@test_sign.jpg" http://localhost:5000/predict
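通过本文我们实现了:环境配置、GTSRB 数据加载、LeNet-5 模型搭建与训练、ONNX 模型导出,以及基于 Flask 的推理服务部署。
完整代码需配合 GTSRB 数据集使用,数据集可从 GTSRB 官方网站下载(也可以通过 torchvision.datasets.GTSRB 自动下载)。建议使用 GPU 加速训练过程。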