# Python check-in training camp, day 43 (Python打卡训练营day43) — 2025.06.02

import torch
import torch.nn as nn
import torch.optim as optim
from data_loader import create_data_loaders
from model import CNNModel
from train import train_model
from visualize import visualize_grad_cam
import os
 
def _pick_sample_images(test_dir, class_names, extensions=('.jpg', '.jpeg', '.png')):
    """Return one image path per class found under ``test_dir/<class_name>/``.

    Classes whose directory is missing, or which contain no file with a
    matching suffix, are skipped. The result may therefore be shorter than
    ``class_names``.

    Args:
        test_dir: Directory expected to contain one sub-directory per class.
        class_names: Class names, used as the sub-directory names.
        extensions: Accepted file suffixes, compared case-insensitively.

    Returns:
        A list of file paths, in ``class_names`` order.
    """
    samples = []
    for class_name in class_names:
        class_dir = os.path.join(test_dir, class_name)
        # isdir (not exists): os.listdir below would raise on a plain file.
        if not os.path.isdir(class_dir):
            continue
        # os.listdir order is arbitrary; sort so the chosen sample is
        # reproducible across runs and filesystems. lower() also accepts
        # upper-case suffixes such as ".JPG".
        candidates = sorted(
            name for name in os.listdir(class_dir)
            if name.lower().endswith(extensions)
        )
        if candidates:
            samples.append(os.path.join(class_dir, candidates[0]))
    return samples


def main(
    *,
    data_dir='/kaggle/input/dogs-vs-cats',
    batch_size=32,
    num_epochs=5,
    learning_rate=0.001,
    img_size=224,
    checkpoint_path='best_model.pth',
):
    """Train a CNN classifier, then render Grad-CAM for one test image per class.

    Steps:
    1. Build the data loaders with ``create_data_loaders``.
    2. Create ``CNNModel(pretrained=True)``.
    3. Train it with Adam and cross-entropy via ``train_model``.
    4. Reload ``checkpoint_path`` if that file exists.
    5. Call ``visualize_grad_cam`` on the first image (by sorted file name)
       of each class directory under ``<data_dir>/test``.

    All arguments are keyword-only and default to the original hard-coded
    settings, so a bare ``main()`` call keeps working.

    Args:
        data_dir: Dataset root (defaults to the Kaggle dogs-vs-cats input
            path). Grad-CAM samples are read from ``<data_dir>/test/<class>/``.
        batch_size: Mini-batch size passed to ``create_data_loaders``.
        num_epochs: Number of training epochs passed to ``train_model``.
        learning_rate: Adam learning rate.
        img_size: Image size passed to ``create_data_loaders``.
        checkpoint_path: State-dict file reloaded after training, if present.
            Presumably written by ``train_model`` — TODO confirm the paths match.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load data.
    print("加载数据...")
    train_loader, test_loader, class_names = create_data_loaders(
        data_dir=data_dir,
        batch_size=batch_size,
        img_size=img_size
    )

    # Build the model.
    print("初始化模型...")
    model = CNNModel(num_classes=len(class_names), pretrained=True).to(device)

    # Loss function and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train.
    print("开始训练...")
    trained_model = train_model(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
        device=device
    )

    # Reload the best checkpoint. map_location remaps saved tensors onto the
    # current device, so a checkpoint written on a GPU still loads on a
    # CPU-only machine. torch.load unpickles the file: only load trusted files.
    if os.path.exists(checkpoint_path):
        trained_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print("已加载最佳模型")
    else:
        print("未找到最佳模型文件，使用训练结束时的模型")

    # Grad-CAM visualization.
    print("生成Grad-CAM可视化...")
    # Inference mode: dropout / batch-norm layers (if any) behave
    # deterministically. Autograd is still active, which Grad-CAM needs.
    trained_model.eval()

    # Prefer the evaluation-time preprocessing. The training transform may
    # include random augmentation (not visible here — NOTE(review): confirm
    # in data_loader). Fall back to the original choice when the test
    # dataset exposes no transform.
    transform = getattr(test_loader.dataset, 'transform', None)
    if transform is None:
        transform = train_loader.dataset.transform

    test_dir = os.path.join(data_dir, 'test')
    sample_images = _pick_sample_images(test_dir, class_names)
    if not sample_images:
        print(f"未在 {test_dir} 下找到可用于Grad-CAM的图像")

    for img_path in sample_images:
        visualize_grad_cam(
            img_path=img_path,
            model=trained_model,
            class_names=class_names,
            transform=transform,
            device=device
        )
 
if __name__ == '__main__':
    # Run the full train + Grad-CAM pipeline only when executed as a script,
    # not when this module is imported.
    main()

# Author: @浙大疏锦行

# Related tags (你可能感兴趣的): python, 深度学习 (deep learning), 人工智能 (artificial intelligence)