"""DAY 42 超大力王爱学Python

  1. 回调函数 (callback functions)
  2. lambda函数 (lambda functions)
  3. hook函数的模块钩子和张量钩子 (module hooks and tensor hooks)
  4. Grad-CAM的示例 (Grad-CAM example)
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from typing import Tuple, Callable, Optional, List, Dict
import numpy as np
import os
from PIL import Image

# 1. 数据预处理与加载函数(增强版)
def get_data_loaders(
    dataset_name: str = 'MNIST',  # one of: 'MNIST', 'CIFAR10'
    batch_size: int = 64,
    data_dir: str = './data',
    num_workers: int = 2,
    use_augmentation: bool = True  # apply augmentation to the training split
) -> Tuple[DataLoader, DataLoader]:
    """Build train/test DataLoaders for MNIST (grayscale) or CIFAR10 (color).

    The test split always uses the plain ToTensor+Normalize pipeline; the
    training split additionally gets dataset-appropriate augmentation when
    ``use_augmentation`` is True.

    Raises:
        ValueError: for any unsupported ``dataset_name``.
    """
    if dataset_name == 'MNIST':
        # MNIST per-channel mean/std.
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        eval_transform = transforms.Compose([transforms.ToTensor(), normalize])
        train_transform = transforms.Compose([
            transforms.RandomRotation(10),  # random rotation within ±10°
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # random shift
            transforms.ToTensor(),
            normalize,
        ]) if use_augmentation else eval_transform
        dataset_cls = datasets.MNIST
    elif dataset_name == 'CIFAR10':
        # Maps pixel values into [-1, 1].
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        eval_transform = transforms.Compose([transforms.ToTensor(), normalize])
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),   # random crop with padding
            transforms.RandomHorizontalFlip(),      # random horizontal flip
            transforms.RandomRotation(15),          # random rotation within ±15°
            transforms.ToTensor(),
            normalize,
        ]) if use_augmentation else eval_transform
        dataset_cls = datasets.CIFAR10
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    train_dataset = dataset_cls(data_dir, train=True, download=True, transform=train_transform)
    test_dataset = dataset_cls(data_dir, train=False, transform=eval_transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    return train_loader, test_loader

# 2. 卷积神经网络定义(替代原有MLP)
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension into one."""
    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)

class CNNClassifier(nn.Module):
    """Convolutional classifier for grayscale or color images.

    The flattened feature sizes assume 28x28 inputs for 1-channel (MNIST)
    and 32x32 inputs for 3-channel (CIFAR10); other resolutions will fail
    at the first Linear layer.

    Side channels used by the visualization utilities:
      - ``feature_maps``: output of every Conv2d layer, appended on each
        forward pass (callers clear the list before a fresh pass).
      - ``activations`` / ``gradients``: target-layer tensors captured by
        forward/backward hooks, consumed by Grad-CAM.
    """
    def __init__(
        self,
        input_channels: int = 1,  # MNIST: 1, CIFAR10: 3
        num_classes: int = 10,
        use_batchnorm: bool = True  # insert BatchNorm2d after each conv
    ):
        super().__init__()

        def maybe_bn(channels):
            # BatchNorm2d or a pass-through, depending on configuration.
            return nn.BatchNorm2d(channels) if use_batchnorm else nn.Identity()

        if input_channels == 1:  # grayscale (MNIST)
            self.features = nn.Sequential(
                nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1),
                maybe_bn(32),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 28x28 -> 14x14

                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                maybe_bn(64),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 14x14 -> 7x7

                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # Grad-CAM target
                maybe_bn(128),
                nn.ReLU()  # output: 7x7
            )
            flattened_size = 128 * 7 * 7
        else:  # color (CIFAR10)
            self.features = nn.Sequential(
                nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1),
                maybe_bn(32),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 32x32 -> 16x16

                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                maybe_bn(64),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 16x16 -> 8x8

                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # Grad-CAM target
                maybe_bn(128),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2)  # 8x8 -> 4x4
            )
            flattened_size = 128 * 4 * 4

        # BUG FIX: was `self.features[6]`, which is the second block's ReLU
        # (indices: conv=0, bn=1, relu=2, pool=3, conv=4, bn=5, relu=6),
        # not the third conv the comments promised. Select the last Conv2d
        # so Grad-CAM targets the deepest convolution.
        self.target_layer = [m for m in self.features if isinstance(m, nn.Conv2d)][-1]

        # Classification head; nn.Flatten() collapses everything after the
        # batch dimension (same behavior as a hand-rolled flatten module).
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

        # Feature-map capture for visualization.
        # BUG FIX: the hook used to be registered on the model itself, so the
        # isinstance(module, nn.Conv2d) check in the hook never matched and
        # feature_maps stayed empty. Register it on each Conv2d instead.
        self.feature_maps = []
        for layer in self.features:
            if isinstance(layer, nn.Conv2d):
                layer.register_forward_hook(self._save_feature_maps)

        # Grad-CAM capture buffers and hooks on the target layer.
        self.gradients = None
        self.activations = None
        self.target_layer.register_forward_hook(self._extract_activations)
        self.target_layer.register_full_backward_hook(self._extract_gradients)

    def forward(self, x):
        """Run the feature extractor then the classifier; returns logits."""
        x = self.features(x)
        x = self.classifier(x)
        return x

    def _save_feature_maps(self, module, input, output):
        """Forward hook on each Conv2d: store its output for visualization."""
        if isinstance(module, nn.Conv2d):
            self.feature_maps.append(output.detach())

    def _extract_activations(self, module, input, output):
        """Forward hook on the target layer: cache activations for Grad-CAM."""
        self.activations = output.detach()

    def _extract_gradients(self, module, grad_input, grad_output):
        """Backward hook on the target layer: cache output grads for Grad-CAM."""
        self.gradients = grad_output[0].detach()
# 3. Grad-CAM实现
class GradCAM:
    """
    Grad-CAM: visualize which image regions the model attends to by weighting
    the target layer's feature maps with the spatial average of their gradients.

    Hooks are attached to ``target_layer`` at construction time; call
    ``remove_hooks()`` when finished so they are not left on the model.
    """
    def __init__(self, model, target_layer):
        # model: network to inspect; target_layer: module inside the model
        # (typically the last convolutional layer).
        self.model = model
        self.target_layer = target_layer
        self.feature_maps = []
        self.gradients = []
        
        # Handles kept so remove_hooks() can detach them later.
        self.hook_handles = []
        
        # Forward hook: capture the target layer's output activations.
        def forward_hook(module, input, output):
            self.feature_maps.append(output.detach())
        
        # Backward hook: capture the gradient flowing out of the target layer.
        def backward_hook(module, grad_in, grad_out):
            self.gradients.append(grad_out[0].detach())
        
        # Locate the target layer inside the model and attach both hooks.
        # (nn.Module defines no __eq__, so == is identity comparison here.)
        target_found = False
        for name, module in self.model.named_modules():
            if module == target_layer:
                self.hook_handles.append(module.register_forward_hook(forward_hook))
                self.hook_handles.append(module.register_full_backward_hook(backward_hook))
                target_found = True
                break
        
        if not target_found:
            raise ValueError(f"Target layer {target_layer} not found in model")
    
    def __call__(self, x, class_idx=None):
        """
        Compute the Grad-CAM heatmap for one input image.
        
        Args:
            x: input image tensor of shape (1, C, H, W).
               NOTE(review): requires_grad_() below mutates the caller's
               tensor in place — confirm callers do not depend on it
               remaining requires_grad=False.
            class_idx: target class index; if None, the top-scoring
                predicted class is used.
        
        Returns:
            numpy array of shape (H, W), normalized to [0, 1] and resized
            to the input's spatial resolution.
        """
        # Evaluation mode (disable dropout, use BN running statistics).
        self.model.eval()
        
        # Discard captures from any previous call.
        self.feature_maps = []
        self.gradients = []
        
        # Forward pass (input marked as requiring grad beforehand).
        x.requires_grad_()
        outputs = self.model(x)
        
        if class_idx is None:
            class_idx = torch.argmax(outputs, dim=1)
        
        # Backward pass on the selected class score only (one-hot gradient).
        self.model.zero_grad()
        one_hot = torch.zeros_like(outputs)
        one_hot[0, class_idx] = 1
        outputs.backward(gradient=one_hot, retain_graph=True)
        
        # First captured tensors belong to this forward/backward pass.
        feature_map = self.feature_maps[0]  # [1, C, H, W]
        gradient = self.gradients[0]  # [1, C, H, W]
        
        # Channel weights: global average pooling over the gradients.
        weights = torch.mean(gradient, dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
        
        # Weighted sum of the feature maps across channels.
        cam = torch.sum(weights * feature_map, dim=1, keepdim=True)  # [1, 1, H, W]
        
        # ReLU: keep only features with a positive influence on the class.
        cam = torch.relu(cam)
        
        # Normalize to [0, 1]; guard against an all-zero map.
        if torch.max(cam) > 0:
            cam = cam / torch.max(cam)
        
        # Upsample the heatmap to the input image's spatial size.
        cam = torch.nn.functional.interpolate(
            cam, size=(x.size(2), x.size(3)), mode='bilinear', align_corners=False
        )
        
        return cam.squeeze().detach().cpu().numpy()
    
    def remove_hooks(self):
        """Detach every hook registered by this object."""
        for handle in self.hook_handles:
            handle.remove()

# 4. 增强的特征图可视化函数
def visualize_feature_maps(model, sample_image, device, num_channels=5):
    """Save a grid of the model's convolutional feature maps to 'feature_maps.png'.

    Args:
        model: network exposing a ``feature_maps`` list filled by forward hooks.
        sample_image: a single image tensor (C, H, W), already normalized.
        device: device to run the forward pass on.
        num_channels: channels to display per convolutional layer.
    """
    model.eval()
    sample_image = sample_image.to(device).unsqueeze(0)  # add batch dimension

    # Run one forward pass; the model's hooks populate model.feature_maps.
    model.feature_maps = []  # drop captures from earlier passes
    with torch.no_grad():
        model(sample_image)

    # Robustness fix: if no hooks fired there is nothing to plot, and
    # plt.subplot(0, ...) below would raise.
    if not model.feature_maps:
        print("No feature maps captured; skipping visualization.")
        return

    total_plots = min(3, len(model.feature_maps))  # show at most 3 conv layers
    # BUG FIX: was `fig, axes = plt.figure(...), []` — a confusing tuple
    # assignment whose `axes` list was never used.
    plt.figure(figsize=(5 * (num_channels + 1), 5 * total_plots))

    # Original image in the first cell.
    plt.subplot(total_plots, num_channels + 1, 1)
    img = sample_image.cpu().squeeze(0).permute(1, 2, 0)
    if img.shape[2] == 1:  # grayscale
        img = img.squeeze(-1)
        plt.imshow(img, cmap='gray')
    else:  # color; un-normalize assuming mean=std=0.5 (the CIFAR10 pipeline)
        img = img * torch.tensor([0.5, 0.5, 0.5]) + torch.tensor([0.5, 0.5, 0.5])
        img = np.clip(img.numpy(), 0, 1)
        plt.imshow(img)
    plt.title('Original Image')
    plt.axis('off')

    # One row per conv layer: a text label followed by the first channels.
    plot_idx = 2
    for i, feature_map in enumerate(model.feature_maps[:total_plots]):
        plt.subplot(total_plots, num_channels + 1, plot_idx)
        plt.text(0.5, 0.5, f'Layer {i+1} Feature Maps', horizontalalignment='center', 
                 verticalalignment='center', fontsize=12)
        plt.axis('off')
        plot_idx += 1

        for j in range(min(num_channels, feature_map.shape[1])):
            plt.subplot(total_plots, num_channels + 1, plot_idx)
            plt.imshow(feature_map[0, j].cpu(), cmap='viridis')
            plt.title(f'Channel {j+1}')
            plt.axis('off')
            plot_idx += 1

    plt.tight_layout()
    plt.savefig('feature_maps.png')
    plt.close()
    print("特征图已保存为 'feature_maps.png'")

# 5. Grad-CAM可视化函数
def visualize_grad_cam(model, sample_image, device, class_names=None):
    """Render the Grad-CAM heatmap for one image and save it as 'grad_cam.png'."""
    model.eval()
    batch = sample_image.to(device).unsqueeze(0)  # add batch dimension

    # Build the Grad-CAM helper on the model's designated target layer.
    cam_engine = GradCAM(model, model.target_layer)

    # Predicted class (no gradients needed for this pass).
    with torch.no_grad():
        logits = model(batch)
        predicted = torch.argmax(logits, dim=1).item()

    # Heatmap for the predicted class, then detach the hooks.
    heatmap = cam_engine(batch, predicted)
    cam_engine.remove_hooks()

    # Prepare the input image for display.
    display = batch.cpu().squeeze(0).permute(1, 2, 0)
    if display.shape[2] == 1:  # grayscale -> replicate into 3 channels
        display = display.repeat(1, 1, 3)
    else:  # color; un-normalize assuming mean=std=0.5
        display = display * torch.tensor([0.5, 0.5, 0.5]) + torch.tensor([0.5, 0.5, 0.5])
    display = np.clip(display.numpy(), 0, 1)

    plt.figure(figsize=(12, 5))

    # Panel 1: original image.
    plt.subplot(1, 3, 1)
    plt.imshow(display)
    plt.title('Original Image')
    plt.axis('off')

    # Panel 2: raw heatmap.
    plt.subplot(1, 3, 2)
    plt.imshow(heatmap, cmap='jet')
    plt.title('Grad-CAM Heatmap')
    plt.axis('off')

    # Panel 3: heatmap blended over the image.
    plt.subplot(1, 3, 3)
    plt.imshow(display)
    plt.imshow(heatmap, cmap='jet', alpha=0.5)
    label = class_names[predicted] if class_names else f'Class {predicted}'
    plt.title(f'Overlay - Predicted: {label}')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig('grad_cam.png')
    plt.close()
    print("Grad-CAM可视化已保存为 'grad_cam.png'")

# 6. 回调函数示例
class Callback:
    """Base class for training callbacks; subclasses override the hooks they need.

    Every hook is a no-op by default and returns None.
    """

    def on_epoch_start(self, epoch, model, optimizer):
        """Called before an epoch begins."""

    def on_epoch_end(self, epoch, model, optimizer, train_loss, val_loss, val_acc):
        """Called after an epoch's training and evaluation complete."""

    def on_batch_start(self, batch_idx, data, target):
        """Called before each training batch."""

    def on_batch_end(self, batch_idx, loss):
        """Called after each training batch."""

class SaveBestModelCallback(Callback):
    """Save the model's state_dict whenever the monitored metric improves.

    Args:
        save_path: checkpoint destination (written via torch.save).
        metric: 'accuracy' monitors val_acc; any other value monitors val_loss.
        mode: 'max' if larger values are better, 'min' if smaller are better
              (use metric='loss', mode='min' to track validation loss).
    """
    def __init__(self, save_path, metric='accuracy', mode='max'):
        self.save_path = save_path
        self.metric = metric
        self.mode = mode
        # Start from the worst possible value for the chosen direction.
        self.best_value = float('-inf') if mode == 'max' else float('inf')
    
    def on_epoch_end(self, epoch, model, optimizer, train_loss, val_loss, val_acc):
        # BUG FIX: the original used `-val_loss` while also honoring `mode`,
        # so metric='loss' with mode='min' rewarded *increasing* loss.
        # Monitor the raw value and let `mode` alone define the direction.
        current_value = val_acc if self.metric == 'accuracy' else val_loss
        
        if (self.mode == 'max' and current_value > self.best_value) or \
           (self.mode == 'min' and current_value < self.best_value):
            self.best_value = current_value
            torch.save(model.state_dict(), self.save_path)
            print(f"Epoch {epoch}: Best model saved with {self.metric}: {current_value:.4f}")

class LearningRateMonitorCallback(Callback):
    """Log the optimizer's current learning rate at the start of each epoch."""
    def on_epoch_start(self, epoch, model, optimizer):
        # Read the LR from the first parameter group (the only one created here).
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch} starting with learning rate: {current_lr:.6f}")

class EarlyStoppingCallback(Callback):
    """Stop training when the monitored metric stops improving.

    Args:
        patience: consecutive non-improving epochs tolerated before stopping.
        metric: 'accuracy' monitors val_acc; any other value monitors val_loss.
        mode: 'max' if larger values are better, 'min' if smaller are better.
    """
    def __init__(self, patience=3, metric='accuracy', mode='max'):
        self.patience = patience
        self.metric = metric
        self.mode = mode
        # Start from the worst possible value for the chosen direction.
        self.best_value = float('-inf') if mode == 'max' else float('inf')
        self.counter = 0          # consecutive epochs without improvement
        self.should_stop = False  # set once patience is exhausted
    
    def on_epoch_end(self, epoch, model, optimizer, train_loss, val_loss, val_acc):
        # BUG FIX: the original used `-val_loss` while also honoring `mode`,
        # inverting the improvement test for metric='loss' with mode='min'.
        # Monitor the raw value; `mode` alone defines the direction.
        current_value = val_acc if self.metric == 'accuracy' else val_loss
        
        if (self.mode == 'max' and current_value > self.best_value) or \
           (self.mode == 'min' and current_value < self.best_value):
            self.best_value = current_value
            self.counter = 0
        else:
            self.counter += 1
            print(f"Early stopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.should_stop = True
                print(f"Early stopping triggered after {epoch} epochs")
    
    def check_early_stop(self):
        """Return True once training should halt."""
        return self.should_stop

# 7. 训练函数(支持回调函数)
def train(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: Optional[StepLR],  # learning-rate scheduler; may be None
    device: torch.device,
    epoch: int,
    log_interval: int = 100,
    callbacks: Optional[List[Callback]] = None
) -> float:
    """Train the model for one epoch and return the average batch loss.

    Callbacks (if given) are notified at epoch start and around every batch;
    the epoch-end notification is left to the caller. The scheduler, when
    provided, is stepped once at the end of the epoch.
    """
    model.train()  # training mode (enables dropout etc.)
    running_loss = 0.0
    
    # Notify callbacks: epoch is starting.
    if callbacks:
        for callback in callbacks:
            callback.on_epoch_start(epoch, model, optimizer)
    
    for batch_idx, (data, target) in enumerate(train_loader):
        # Notify callbacks: batch is starting.
        if callbacks:
            for callback in callbacks:
                callback.on_batch_start(batch_idx, data, target)
        
        data, target = data.to(device), target.to(device)
        
        # Forward pass.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        
        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # Periodic progress logging.
        if batch_idx % log_interval == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\t'
                  f'Loss: {loss.item():.6f}\tLR: {current_lr:.6f}')
        
        # Notify callbacks: batch finished.
        if callbacks:
            for callback in callbacks:
                callback.on_batch_end(batch_idx, loss.item())
    
    # Average loss over batches (not over samples).
    avg_loss = running_loss / len(train_loader)
    print(f'Epoch {epoch} average loss: {avg_loss:.4f}')
    
    # Step the learning-rate scheduler once per epoch.
    if scheduler:
        scheduler.step()
        print(f'Epoch {epoch} learning rate updated to: {optimizer.param_groups[0]["lr"]:.6f}')
    
    return avg_loss

# 8. 测试函数(保持不变)
def test(
    model: nn.Module,
    test_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Tuple[float, float]:
    """Evaluate the model on the test set; return (average loss, accuracy %)."""
    model.eval()  # evaluation mode (disables dropout etc.)
    total_loss = 0.0
    num_correct = 0

    # No gradients needed during evaluation — saves memory and compute.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            total_loss += criterion(logits, labels).item()  # accumulate batch loss
            predictions = logits.argmax(dim=1, keepdim=True)  # highest-scoring class
            num_correct += predictions.eq(labels.view_as(predictions)).sum().item()

    # Average over batches; accuracy over samples.
    test_loss = total_loss / len(test_loader)
    accuracy = 100. * num_correct / len(test_loader.dataset)

    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {num_correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')

    return test_loss, accuracy

# 9. 主函数:训练和测试流程(增强版)
def main(
    dataset_name: str = 'MNIST',
    batch_size: int = 64,
    epochs: int = 5,
    lr: float = 0.001,
    gamma: float = 0.7,  # learning-rate decay factor
    step_size: int = 1,  # epochs between learning-rate decays
    use_augmentation: bool = True,
    use_batchnorm: bool = True,
    use_cuda: bool = True
) -> None:
    """End-to-end pipeline: load data, train/evaluate the CNN with callbacks,
    then produce feature-map and Grad-CAM visualizations on a test sample.

    Raises:
        ValueError: if dataset_name is not 'MNIST' or 'CIFAR10'.
    """
    # Pick the compute device.
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
    print(f"Using device: {device}")
    
    # Build the data loaders.
    train_loader, test_loader = get_data_loaders(
        dataset_name=dataset_name,
        batch_size=batch_size,
        use_augmentation=use_augmentation
    )
    
    # Dataset-specific model inputs and human-readable label names.
    if dataset_name == 'MNIST':
        input_channels = 1
        num_classes = 10
        class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    elif dataset_name == 'CIFAR10':
        input_channels = 3
        num_classes = 10
        class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
                       'dog', 'frog', 'horse', 'ship', 'truck']
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    
    # Build the model.
    model = CNNClassifier(
        input_channels=input_channels,
        num_classes=num_classes,
        use_batchnorm=use_batchnorm
    ).to(device)
    
    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Step-decay learning-rate scheduler.
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    
    # Training callbacks: checkpointing, LR logging, early stopping.
    callbacks = [
        SaveBestModelCallback(f"{dataset_name}_best_cnn_model.pth"),
        LearningRateMonitorCallback(),
        EarlyStoppingCallback(patience=3)
    ]
    
    # Train/evaluate loop.
    best_accuracy = 0.0
    for epoch in range(1, epochs + 1):
        train_loss = train(model, train_loader, criterion, optimizer, scheduler, device, epoch, callbacks=callbacks)
        test_loss, accuracy = test(model, test_loader, criterion, device)
        
        # Notify callbacks: epoch finished.
        for callback in callbacks:
            callback.on_epoch_end(epoch, model, optimizer, train_loss, test_loss, accuracy)
        
        # Stop if any callback requests early stopping.
        if any(callback.check_early_stop() for callback in callbacks if hasattr(callback, 'check_early_stop')):
            break
        
        # Track the best test accuracy seen so far.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
    
    # Visualize conv feature maps for one test image.
    sample_image, _ = next(iter(test_loader))
    visualize_feature_maps(model, sample_image[0], device)
    
    # Visualize Grad-CAM for the same image.
    visualize_grad_cam(model, sample_image[0], device, class_names)
    
    print(f"Training completed. Best accuracy: {best_accuracy:.2f}%")

if __name__ == "__main__":
    # Entry point: train the CNN on MNIST with augmentation and batch norm.
    main(
        dataset_name='MNIST',
        batch_size=64,
        epochs=5,
        lr=0.001,
        gamma=0.7,
        step_size=1,
        use_augmentation=True,
        use_batchnorm=True
    )

# @浙大疏锦行
# (原文标签: 超大力王爱学Python, python, 开发语言)