import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from typing import Tuple, Callable, Optional, List, Dict
import numpy as np
import os
from PIL import Image
# 1. Data preprocessing and loading (enhanced version)
def get_data_loaders(
    dataset_name: str = 'MNIST',  # one of 'MNIST' or 'CIFAR10'
    batch_size: int = 64,
    data_dir: str = './data',
    num_workers: int = 2,
    use_augmentation: bool = True  # apply train-time augmentation transforms
) -> Tuple[DataLoader, DataLoader]:
    """Build train/test DataLoaders for grayscale (MNIST) or color (CIFAR10) data.

    Training data optionally gets augmentation (rotation/translation for MNIST;
    crop/flip/rotation for CIFAR10); test data always uses the plain
    normalize-only transform.

    Raises:
        ValueError: if ``dataset_name`` is not 'MNIST' or 'CIFAR10'.
    """
    if dataset_name == 'MNIST':
        # Plain transform: tensor conversion + dataset mean/std normalization
        base_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean / std
        ])
        # Augmented transform for training
        augment_transform = transforms.Compose([
            transforms.RandomRotation(10),  # random rotation within ±10 degrees
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # random shift
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_transform = augment_transform if use_augmentation else base_transform
        train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=train_transform)
        # Robustness fix: also allow downloading the test split so this works
        # even if the data directory was not populated by the train download.
        test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=base_transform)
    elif dataset_name == 'CIFAR10':
        # Plain transform: scale to [-1, 1]
        base_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # Augmented transform for training
        augment_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),   # random crop with padding
            transforms.RandomHorizontalFlip(),      # random horizontal flip
            transforms.RandomRotation(15),          # random rotation within ±15 degrees
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        train_transform = augment_transform if use_augmentation else base_transform
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=train_transform)
        # Robustness fix: see MNIST branch above.
        test_dataset = datasets.CIFAR10(data_dir, train=False, download=True, transform=base_transform)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    # Shuffle only the training loader; keep test order deterministic
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    return train_loader, test_loader
# 2. Convolutional neural network definition (replaces the earlier MLP)
class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension.

    Kept for backward compatibility with code importing this class; behaves
    like ``nn.Flatten()``.
    """
    def forward(self, x):
        # torch.flatten also handles non-contiguous inputs, unlike .view(),
        # which raises on tensors such as the result of a transpose.
        return torch.flatten(x, 1)
class CNNClassifier(nn.Module):
    """Convolutional classifier for grayscale (MNIST, 28x28) and color (CIFAR10, 32x32) images.

    Exposed attributes:
        features:     convolutional feature extractor
        classifier:   fully-connected classification head
        target_layer: the third Conv2d layer (Grad-CAM target)
        feature_maps: detached conv outputs of the most recent forward pass
        activations / gradients: captured at ``target_layer`` for Grad-CAM
    """
    def __init__(
        self,
        input_channels: int = 1,   # MNIST: 1, CIFAR10: 3
        num_classes: int = 10,
        use_batchnorm: bool = True  # insert BatchNorm after each conv
    ):
        super().__init__()

        def _bn(channels: int) -> nn.Module:
            # nn.Identity keeps layer indices stable regardless of use_batchnorm
            return nn.BatchNorm2d(channels) if use_batchnorm else nn.Identity()

        if input_channels == 1:  # grayscale (MNIST)
            self.features = nn.Sequential(
                # First conv block
                nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1),
                _bn(32),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # -> 14x14
                # Second conv block
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                _bn(64),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # -> 7x7
                # Third conv block (Grad-CAM target), index 8
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                _bn(128),
                nn.ReLU()  # -> 7x7
            )
            flattened_size = 128 * 7 * 7
        else:  # color (CIFAR10)
            self.features = nn.Sequential(
                # First conv block
                nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1),
                _bn(32),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # -> 16x16
                # Second conv block
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                _bn(64),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # -> 8x8
                # Third conv block (Grad-CAM target), index 8
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                _bn(128),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)  # -> 4x4
            )
            flattened_size = 128 * 4 * 4

        # BUG FIX: the original used self.features[6], which is the second
        # ReLU, not the "third conv layer" the comment promised. The third
        # Conv2d sits at index 8 in both branches.
        self.target_layer = self.features[8]

        # Classification head (stdlib nn.Flatten matches the custom Flatten)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

        # Feature-map capture for visualization.
        # BUG FIX: the original registered the hook on the model itself, so
        # the isinstance(module, nn.Conv2d) guard never matched and
        # feature_maps stayed empty. Register on each Conv2d instead.
        self.feature_maps = []
        for layer in self.features:
            if isinstance(layer, nn.Conv2d):
                layer.register_forward_hook(self._save_feature_maps)

        # Grad-CAM capture at the target layer
        self.gradients = None
        self.activations = None
        self.target_layer.register_forward_hook(self._extract_activations)
        self.target_layer.register_full_backward_hook(self._extract_gradients)

    def forward(self, x):
        # Keep only the current batch's maps so the list doesn't grow
        # without bound across training iterations.
        self.feature_maps = []
        x = self.features(x)
        x = self.classifier(x)
        return x

    def _save_feature_maps(self, module, input, output):
        """Store a detached copy of a conv layer's output for visualization."""
        if isinstance(module, nn.Conv2d):
            self.feature_maps.append(output.detach())

    def _extract_activations(self, module, input, output):
        """Capture target-layer activations for Grad-CAM."""
        self.activations = output.detach()

    def _extract_gradients(self, module, grad_input, grad_output):
        """Capture target-layer gradients for Grad-CAM."""
        self.gradients = grad_output[0].detach()
# 3. Grad-CAM implementation
class GradCAM:
    """Grad-CAM: highlight which input regions drive a class score.

    Forward/backward hooks on ``target_layer`` capture its activations and
    gradients; the heatmap is the positive part of the gradient-weighted
    activation sum, normalized to [0, 1] and upsampled to the input size.
    """
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.feature_maps = []
        self.gradients = []
        self.hook_handles = []

        def capture_activations(module, input, output):
            self.feature_maps.append(output.detach())

        def capture_gradients(module, grad_in, grad_out):
            self.gradients.append(grad_out[0].detach())

        # Locate the target layer among the model's modules and hook it.
        for _, module in self.model.named_modules():
            if module == target_layer:
                self.hook_handles.append(module.register_forward_hook(capture_activations))
                self.hook_handles.append(module.register_full_backward_hook(capture_gradients))
                break
        else:
            raise ValueError(f"Target layer {target_layer} not found in model")

    def __call__(self, x, class_idx=None):
        """Return an (H, W) heatmap in [0, 1] for ``x`` of shape (1, C, H, W).

        ``class_idx`` selects the class whose score is explained; it defaults
        to the model's top prediction.
        """
        self.model.eval()
        # Start from a clean slate for this invocation
        self.feature_maps, self.gradients = [], []

        x.requires_grad_()
        scores = self.model(x)
        if class_idx is None:
            class_idx = torch.argmax(scores, dim=1)

        # Backpropagate only the selected class score
        self.model.zero_grad()
        one_hot = torch.zeros_like(scores)
        one_hot[0, class_idx] = 1
        scores.backward(gradient=one_hot, retain_graph=True)

        activations = self.feature_maps[0]   # [1, C, H', W']
        grads = self.gradients[0]            # [1, C, H', W']
        # Channel weights = global average pooling of the gradients
        channel_weights = grads.mean(dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
        heatmap = (channel_weights * activations).sum(dim=1, keepdim=True)  # [1, 1, H', W']
        # Keep only positive influence
        heatmap = torch.relu(heatmap)

        # Normalize to [0, 1] when the map is non-trivial
        peak = torch.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak

        # Upsample to the input resolution
        heatmap = torch.nn.functional.interpolate(
            heatmap, size=(x.size(2), x.size(3)), mode='bilinear', align_corners=False
        )
        return heatmap.squeeze().detach().cpu().numpy()

    def remove_hooks(self):
        """Detach every hook registered by this object."""
        for handle in self.hook_handles:
            handle.remove()
# 4. Enhanced feature-map visualization
def visualize_feature_maps(model, sample_image, device, num_channels=5):
    """Plot the first channels of up to three conv feature maps to 'feature_maps.png'.

    ``model`` must expose a ``feature_maps`` list that its forward pass fills
    (see CNNClassifier). ``sample_image`` is a single unbatched image tensor.
    """
    model.eval()
    sample_image = sample_image.to(device).unsqueeze(0)  # add batch dimension
    # Run a forward pass to (re)populate the captured maps
    model.feature_maps = []
    with torch.no_grad():
        model(sample_image)
    # BUG FIX: the original called plt.subplot(0, ...) and crashed when no
    # maps were captured; bail out gracefully instead.
    if not model.feature_maps:
        print("No feature maps were captured; nothing to visualize.")
        return
    total_plots = min(3, len(model.feature_maps))  # show at most 3 conv layers
    plt.figure(figsize=(5 * (num_channels + 1), 5 * total_plots))
    # Original image in the first cell
    plt.subplot(total_plots, num_channels + 1, 1)
    img = sample_image.cpu().squeeze(0).permute(1, 2, 0)
    if img.shape[2] == 1:  # grayscale
        img = img.squeeze(-1)
        plt.imshow(img, cmap='gray')
    else:  # color: undo the (0.5, 0.5, 0.5) normalization for display
        img = img * torch.tensor([0.5, 0.5, 0.5]) + torch.tensor([0.5, 0.5, 0.5])
        img = np.clip(img.numpy(), 0, 1)
        plt.imshow(img)
    plt.title('Original Image')
    plt.axis('off')
    # One row per conv layer: a text label, then the first channels
    plot_idx = 2
    for i, feature_map in enumerate(model.feature_maps[:total_plots]):
        plt.subplot(total_plots, num_channels + 1, plot_idx)
        plt.text(0.5, 0.5, f'Layer {i+1} Feature Maps', horizontalalignment='center',
                 verticalalignment='center', fontsize=12)
        plt.axis('off')
        plot_idx += 1
        for j in range(min(num_channels, feature_map.shape[1])):
            plt.subplot(total_plots, num_channels + 1, plot_idx)
            plt.imshow(feature_map[0, j].cpu(), cmap='viridis')
            plt.title(f'Channel {j+1}')
            plt.axis('off')
            plot_idx += 1
    plt.tight_layout()
    plt.savefig('feature_maps.png')
    plt.close()
    print("特征图已保存为 'feature_maps.png'")
# 5. Grad-CAM visualization
def visualize_grad_cam(model, sample_image, device, class_names=None):
    """Save original image, Grad-CAM heatmap, and their overlay to 'grad_cam.png'.

    ``model`` must expose ``target_layer`` (see CNNClassifier);
    ``sample_image`` is a single unbatched image tensor; ``class_names``
    optionally maps class indices to readable labels for the title.
    """
    model.eval()
    sample_image = sample_image.to(device).unsqueeze(0)  # add batch dimension
    grad_cam = GradCAM(model, model.target_layer)
    # Predict the class to explain
    with torch.no_grad():
        output = model(sample_image)
        pred_class = torch.argmax(output, dim=1).item()
    # Heatmap for the predicted class (runs its own forward/backward)
    cam = grad_cam(sample_image, pred_class)
    grad_cam.remove_hooks()
    # Prepare the input image for display
    img = sample_image.cpu().squeeze(0).permute(1, 2, 0)
    if img.shape[2] == 1:  # grayscale
        # BUG FIX: undo the normalization before display; the original
        # clipped still-normalized values and rendered a washed-out image.
        # Assumes the MNIST stats used by get_data_loaders for 1-channel
        # input — confirm if fed other grayscale data.
        img = img * 0.3081 + 0.1307
        img = img.repeat(1, 1, 3)  # expand to three channels for imshow
    else:  # color: undo the (0.5, 0.5, 0.5) normalization
        img = img * torch.tensor([0.5, 0.5, 0.5]) + torch.tensor([0.5, 0.5, 0.5])
    img = np.clip(img.numpy(), 0, 1)
    plt.figure(figsize=(12, 5))
    # Original image
    plt.subplot(1, 3, 1)
    plt.imshow(img)
    plt.title('Original Image')
    plt.axis('off')
    # Grad-CAM heatmap
    plt.subplot(1, 3, 2)
    plt.imshow(cam, cmap='jet')
    plt.title('Grad-CAM Heatmap')
    plt.axis('off')
    # Overlay of heatmap on the image
    plt.subplot(1, 3, 3)
    plt.imshow(img)
    plt.imshow(cam, cmap='jet', alpha=0.5)
    pred_label = class_names[pred_class] if class_names else f'Class {pred_class}'
    plt.title(f'Overlay - Predicted: {pred_label}')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('grad_cam.png')
    plt.close()
    print("Grad-CAM可视化已保存为 'grad_cam.png'")
# 6. Callback examples
class Callback:
    """Base class for training-loop callbacks.

    Every hook is a no-op by default, so subclasses only override the
    events they care about.
    """

    def on_epoch_start(self, epoch, model, optimizer):
        """Called before an epoch's first batch."""

    def on_epoch_end(self, epoch, model, optimizer, train_loss, val_loss, val_acc):
        """Called after the epoch's validation pass completes."""

    def on_batch_start(self, batch_idx, data, target):
        """Called before each training batch is processed."""

    def on_batch_end(self, batch_idx, loss):
        """Called after each training batch's optimizer step."""
class SaveBestModelCallback(Callback):
    """Save the model state whenever the monitored validation metric improves.

    Args:
        save_path: destination for ``torch.save(model.state_dict(), ...)``.
        metric: 'accuracy' to watch val_acc, anything else to watch val_loss.
        mode: 'max' if larger values are better, 'min' if smaller are better.
    """
    def __init__(self, save_path, metric='accuracy', mode='max'):
        self.save_path = save_path
        self.metric = metric
        self.mode = mode
        self.best_value = float('-inf') if mode == 'max' else float('inf')

    def on_epoch_end(self, epoch, model, optimizer, train_loss, val_loss, val_acc):
        # BUG FIX: the original negated the loss AND compared with '<' under
        # mode='min', which tracked the *worst* epoch. Use the raw loss and
        # let ``mode`` alone define the comparison direction.
        current_value = val_acc if self.metric == 'accuracy' else val_loss
        improved = (current_value > self.best_value) if self.mode == 'max' \
            else (current_value < self.best_value)
        if improved:
            self.best_value = current_value
            torch.save(model.state_dict(), self.save_path)
            print(f"Epoch {epoch}: Best model saved with {self.metric}: {current_value:.4f}")
class LearningRateMonitorCallback(Callback):
    """Print the optimizer's current learning rate when each epoch begins."""

    def on_epoch_start(self, epoch, model, optimizer):
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch} starting with learning rate: {current_lr:.6f}")
class EarlyStoppingCallback(Callback):
    """Stop training once the monitored validation metric stops improving.

    Args:
        patience: epochs without improvement tolerated before stopping.
        metric: 'accuracy' to watch val_acc, anything else to watch val_loss.
        mode: 'max' if larger values are better, 'min' if smaller are better.
    """
    def __init__(self, patience=3, metric='accuracy', mode='max'):
        self.patience = patience
        self.metric = metric
        self.mode = mode
        self.best_value = float('-inf') if mode == 'max' else float('inf')
        self.counter = 0
        self.should_stop = False

    def on_epoch_end(self, epoch, model, optimizer, train_loss, val_loss, val_acc):
        # BUG FIX: the original negated the loss AND compared with '<' under
        # mode='min', which treated a *rising* loss as improvement. Use the
        # raw loss and let ``mode`` alone define the comparison direction.
        current_value = val_acc if self.metric == 'accuracy' else val_loss
        improved = (current_value > self.best_value) if self.mode == 'max' \
            else (current_value < self.best_value)
        if improved:
            self.best_value = current_value
            self.counter = 0
        else:
            self.counter += 1
            print(f"Early stopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.should_stop = True
                print(f"Early stopping triggered after {epoch} epochs")

    def check_early_stop(self):
        """Return True once patience has been exhausted."""
        return self.should_stop
# 7. Training function (with callback support)
def train(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: Optional[StepLR],  # per-epoch LR scheduler; None disables stepping
    device: torch.device,
    epoch: int,
    log_interval: int = 100,
    callbacks: "Optional[List[Callback]]" = None  # FIX: annotation now admits the None default
) -> float:
    """Train ``model`` for one epoch and return the average per-batch loss.

    Moves each batch to ``device``, runs forward/backward/step, logs progress
    every ``log_interval`` batches, steps ``scheduler`` once at epoch end, and
    fires callback hooks around the epoch and every batch.
    """
    model.train()  # enable training-mode behavior (dropout, batchnorm, ...)
    running_loss = 0.0
    callbacks = callbacks or []  # normalize once instead of re-checking per batch
    # Notify callbacks: epoch begins
    for callback in callbacks:
        callback.on_epoch_start(epoch, model, optimizer)
    for batch_idx, (data, target) in enumerate(train_loader):
        # Notify callbacks: batch begins
        for callback in callbacks:
            callback.on_batch_start(batch_idx, data, target)
        data, target = data.to(device), target.to(device)
        # Forward pass
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Periodic progress logging
        if batch_idx % log_interval == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\t'
                  f'Loss: {loss.item():.6f}\tLR: {current_lr:.6f}')
        # Notify callbacks: batch finished
        for callback in callbacks:
            callback.on_batch_end(batch_idx, loss.item())
    avg_loss = running_loss / len(train_loader)
    print(f'Epoch {epoch} average loss: {avg_loss:.4f}')
    # Step the LR schedule once per epoch
    if scheduler:
        scheduler.step()
        print(f'Epoch {epoch} learning rate updated to: {optimizer.param_groups[0]["lr"]:.6f}')
    return avg_loss
# 8. Evaluation function (unchanged from the previous version)
def test(
model: nn.Module,
test_loader: DataLoader,
criterion: nn.Module,
device: torch.device
) -> Tuple[float, float]:
"""评估模型在测试集上的性能"""
model.eval() # 启用评估模式(关闭dropout等)
test_loss = 0
correct = 0
with torch.no_grad(): # 不计算梯度,节省内存和计算资源
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item() # 累加批次损失
pred = output.argmax(dim=1, keepdim=True) # 获取最大概率的类别
correct += pred.eq(target.view_as(pred)).sum().item() # 统计正确预测数
# 计算平均损失和准确率
test_loss /= len(test_loader)
accuracy = 100. * correct / len(test_loader.dataset)
print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
f'({accuracy:.2f}%)\n')
return test_loss, accuracy
# 9. Main entry point: training and evaluation pipeline (enhanced version)
def main(
    dataset_name: str = 'MNIST',
    batch_size: int = 64,
    epochs: int = 5,
    lr: float = 0.001,
    gamma: float = 0.7,  # multiplicative LR decay factor for StepLR
    step_size: int = 1,  # number of epochs between LR decay steps
    use_augmentation: bool = True,
    use_batchnorm: bool = True,
    use_cuda: bool = True
) -> None:
    """End-to-end pipeline: load data, build the CNN, train and evaluate per
    epoch with callbacks (best-model checkpointing, LR monitoring, early
    stopping), then save feature-map and Grad-CAM visualizations for one
    test image.
    """
    # Select GPU when requested and available, otherwise fall back to CPU
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
    print(f"Using device: {device}")
    # Build train/test loaders for the chosen dataset
    train_loader, test_loader = get_data_loaders(
        dataset_name=dataset_name,
        batch_size=batch_size,
        use_augmentation=use_augmentation
    )
    # Dataset-specific model inputs and human-readable class labels
    if dataset_name == 'MNIST':
        input_channels = 1
        num_classes = 10
        class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    elif dataset_name == 'CIFAR10':
        input_channels = 3
        num_classes = 10
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck']
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    # Build the model on the selected device
    model = CNNClassifier(
        input_channels=input_channels,
        num_classes=num_classes,
        use_batchnorm=use_batchnorm
    ).to(device)
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Decay the learning rate by `gamma` every `step_size` epochs
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    # Training callbacks: checkpoint the best model, log LR, early-stop
    callbacks = [
        SaveBestModelCallback(f"{dataset_name}_best_cnn_model.pth"),
        LearningRateMonitorCallback(),
        EarlyStoppingCallback(patience=3)
    ]
    # Train/evaluate loop
    best_accuracy = 0.0
    for epoch in range(1, epochs + 1):
        train_loss = train(model, train_loader, criterion, optimizer, scheduler, device, epoch, callbacks=callbacks)
        test_loss, accuracy = test(model, test_loader, criterion, device)
        # Notify callbacks that the epoch finished
        for callback in callbacks:
            callback.on_epoch_end(epoch, model, optimizer, train_loss, test_loss, accuracy)
        # Honor any callback requesting early termination
        if any(callback.check_early_stop() for callback in callbacks if hasattr(callback, 'check_early_stop')):
            break
        # Track the best test accuracy seen so far
        if accuracy > best_accuracy:
            best_accuracy = accuracy
    # Visualize conv feature maps for one test image (uses the final model)
    sample_image, _ = next(iter(test_loader))
    visualize_feature_maps(model, sample_image[0], device)
    # Visualize Grad-CAM for the same image
    visualize_grad_cam(model, sample_image[0], device, class_names)
    print(f"Training completed. Best accuracy: {best_accuracy:.2f}%")
if __name__ == "__main__":
    # Train an MNIST model with the default hyperparameters
    main(
        dataset_name='MNIST',
        batch_size=64,
        epochs=5,
        lr=0.001,
        gamma=0.7,
        step_size=1,
        use_augmentation=True,
        use_batchnorm=True
    )
# @浙大疏锦行  (was a bare decorator with no target — a SyntaxError that made the file unimportable)