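# 作业：day43 时我们安排大家用简单 CNN 训练自己找的数据集，现在可以尝试借助这几天学到的知识进一步提高精度。
# (Assignment: on day 43 you trained a simple CNN on your own dataset; now use what you have learned since then to push the accuracy further.)
# 使用 CIFAR-10 图像数据。
# (This script uses the CIFAR-10 image dataset.)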
import torch
from torchvision import datasets, transforms
# Data preprocessing.
# Evaluation transform: ToTensor scales pixels to [0, 1]; Normalize with
# mean 0.5 / std 0.5 per channel then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize to [-1, 1]
])

# Training transform: the same normalization preceded by light augmentation
# (a random 32x32 crop of the image zero-padded by 4 pixels, and a random
# horizontal flip). The network sees a slightly different view of each image
# every epoch, which typically reduces overfitting and raises CIFAR-10 test
# accuracy. The test set keeps the deterministic `transform` above so that
# evaluation is repeatable.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Download (if not already present under ./data) and load both splits.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# DataLoaders: reshuffle training batches every epoch; keep test order fixed.
batch_size = 128
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)
import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Baseline two-block CNN for 3x32x32 images producing 10 class logits.

    conv(3->16) -> ReLU -> 2x2 max-pool -> conv(16->32) -> ReLU -> 2x2 max-pool
    -> FC(32*8*8 -> 256) -> ReLU -> FC(256 -> 10).
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # Two 2x2 pools shrink 32x32 to 8x8, hence 32 * 8 * 8 input features.
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of shape (N, 3, 32, 32) to logits of shape (N, 10)."""
        x = self.pool(self.relu(self.conv1(x)))  # (N, 16, 16, 16)
        x = self.pool(self.relu(self.conv2(x)))  # (N, 32, 8, 8)
        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, 32 * 8 * 8), this cannot silently re-split the batch when
        # the spatial size is not 32x32; fc1 then raises a shape error instead.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
class ChannelAttention(nn.Module):
    """Channel-attention branch of CBAM.

    A global average-pooled and a global max-pooled channel descriptor are each
    passed through a shared two-layer MLP. Their sum is squashed with a sigmoid,
    and the result rescales every channel of the input.

    Args:
        in_channels: number of channels C of the incoming feature map.
        ratio: reduction ratio of the MLP bottleneck. The hidden width is
            ``max(1, in_channels // ratio)``, so it never collapses to zero
            units when ``in_channels < ratio``.
    """

    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        # Clamp to at least one hidden unit; a zero-width bottleneck would make
        # the MLP output just the second layer's bias, independent of x.
        hidden = max(1, in_channels // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_channels),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` with each channel scaled by a weight in [0, 1].

        Written for (N, C, H, W) feature maps: pooling to 1x1 and squeezing the
        two trailing dims yields (N, C) descriptors for the Linear layers.
        """
        avg_out = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
        max_out = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))
        # Restore the two spatial dims so the weights broadcast over H and W.
        channel_weights = self.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1)
        return x * channel_weights
class SpatialAttention(nn.Module):
    """Spatial-attention branch of CBAM.

    The channel dimension is collapsed into two maps: the per-pixel channel
    mean and the per-pixel channel max. They are convolved into a single map,
    whose sigmoid gates every pixel of the input.

    Args:
        kernel_size: size of the 2->1 convolution. Padding is
            ``kernel_size // 2``, which keeps H and W unchanged for odd sizes.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=pad)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        gate = self.sigmoid(self.conv(stacked))
        return x * gate
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Args:
        in_channels: number of channels of the incoming feature map.
        ratio: MLP reduction ratio passed to ``ChannelAttention`` (default 16,
            the value previously hard-coded there).
        kernel_size: convolution size passed to ``SpatialAttention``
            (default 7, the value previously hard-coded there).
    """

    def __init__(self, in_channels, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_channels, ratio=ratio)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        """Apply channel attention, then spatial attention to its output."""
        x = self.ca(x)
        return self.sa(x)
class CNNWithCBAM(nn.Module):
    """SimpleCNN variant with a CBAM block after each conv + ReLU.

    The conv/FC layer sizes match SimpleCNN (3x32x32 input, 10 logits), so the
    two models differ only by the two attention modules.
    """

    def __init__(self):
        super(CNNWithCBAM, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.cbam1 = CBAM(16)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.cbam2 = CBAM(32)
        # Two 2x2 pools shrink 32x32 to 8x8, hence 32 * 8 * 8 input features.
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of shape (N, 3, 32, 32) to logits of shape (N, 10)."""
        x = self.relu(self.conv1(x))  # (N, 16, 32, 32)
        x = self.cbam1(x)             # attention re-weighting; shape unchanged
        x = self.pool(x)              # (N, 16, 16, 16)
        x = self.relu(self.conv2(x))  # (N, 32, 16, 16)
        x = self.cbam2(x)             # attention re-weighting; shape unchanged
        x = self.pool(x)              # (N, 32, 8, 8)
        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, 32 * 8 * 8), this cannot silently re-split the batch when
        # the spatial size is not 32x32; fc1 then raises a shape error instead.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm # 可视化训练进度
# Device used for training and evaluation: CUDA when a GPU is available,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Model training.
def _train_one_epoch(model, loader, criterion, optimizer, desc):
    """Run one optimization pass over ``loader``; return the mean per-sample loss.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    seen = 0
    for images, labels in tqdm(loader, desc=desc):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # The loss is a batch mean; weight it by batch size so the epoch value
        # is a true per-sample mean even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        seen += images.size(0)
    if seen == 0:
        raise ValueError("train_loader yielded no samples")
    return running_loss / seen


def train_model(model, train_loader, test_loader, epochs=50, lr=1e-3):
    """Train ``model`` with Adam + cross-entropy, evaluating on ``test_loader`` each epoch.

    After every epoch the model's ``state_dict`` is saved to
    ``"<ClassName>_best.pth"`` whenever test accuracy improves. The first epoch
    always saves, so the checkpoint exists after any run with ``epochs >= 1``,
    even if accuracy never rises above 0%.

    NOTE(review): the "best" checkpoint is chosen on the test set, so the
    reported best accuracy is an optimistic estimate. A separate validation
    split would give an unbiased number.

    Args:
        model: the ``nn.Module`` to train. It is moved to the global ``device``
            (``Module.to`` works in place, so the caller's object moves too).
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` evaluation batches.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        dict with per-epoch lists ``'train_loss'`` (mean per-sample training
        loss) and ``'test_acc'`` (test accuracy in percent).
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    checkpoint_path = f"{model.__class__.__name__}_best.pth"
    best_acc = 0.0
    history = {'train_loss': [], 'test_acc': []}
    for epoch in range(epochs):
        # Training phase.
        epoch_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer,
            desc=f"Epoch {epoch+1}/{epochs}",
        )
        history['train_loss'].append(epoch_loss)
        # Evaluation phase; reuses the shared accuracy helper (sets eval mode).
        epoch_acc = evaluate_model(model, test_loader)
        history['test_acc'].append(epoch_acc)
        # Save the best model so far (always on the first epoch).
        if epoch == 0 or epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), checkpoint_path)
        print(f"Epoch {epoch+1}: Loss={epoch_loss:.4f}, Test Acc={epoch_acc:.2f}%")
    print(f"Best Test Accuracy: {best_acc:.2f}%")
    return history
def evaluate_model(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    Puts the model in eval mode and does not restore the previous mode.
    Batches are moved to the device that holds the model's parameters, so the
    function works whether or not the model was moved to the global
    ``device``. Parameter-free models fall back to that global.

    Args:
        model: classifier returning logits of shape (N, num_classes).
        test_loader: iterable of ``(images, labels)`` batches.

    Raises:
        ValueError: if ``test_loader`` yields no samples (previously a bare
            ZeroDivisionError).
    """
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else device
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(eval_device), labels.to(eval_device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    return 100 * correct / total
# Build both models before any training, in this order: their random weight
# initialisation draws from the global RNG in construction order.
model_baseline = SimpleCNN()
model_cbam = CNNWithCBAM()

# Baseline first, then the CBAM-enhanced model, on identical loaders.
print("Training Baseline Model (SimpleCNN)...")
history_baseline = train_model(model_baseline, train_loader, test_loader)
print("\nTraining CBAM-enhanced Model (CNNWithCBAM)...")
history_cbam = train_model(model_cbam, train_loader, test_loader)

# Reload each model's best checkpoint, using the same "<ClassName>_best.pth"
# naming that train_model writes, then re-evaluate on the test set.
for _trained in (model_baseline, model_cbam):
    _trained.load_state_dict(torch.load(f"{type(_trained).__name__}_best.pth"))
final_acc_baseline = evaluate_model(model_baseline, test_loader)
final_acc_cbam = evaluate_model(model_cbam, test_loader)

print("\nFinal Results:")
print(f"Baseline Model Test Accuracy: {final_acc_baseline:.2f}%")
print(f"CBAM-enhanced Model Test Accuracy: {final_acc_cbam:.2f}%")
import matplotlib.pyplot as plt
# Side-by-side curves: training loss (left) and test accuracy (right) per epoch.
plt.figure(figsize=(12, 5))
_panels = (
    ('train_loss', 'Training Loss'),
    ('test_acc', 'Test Accuracy (%)'),
)
for _position, (_metric, _ylabel) in enumerate(_panels, start=1):
    plt.subplot(1, 2, _position)
    plt.plot(history_baseline[_metric], label='Baseline')
    plt.plot(history_cbam[_metric], label='CBAM')
    plt.xlabel('Epoch')
    plt.ylabel(_ylabel)
    plt.legend()
plt.show()