使用 PyTorch 和 SwanLab 实时可视化模型训练

以下是一个使用 PyTorch 和 SwanLab 实现训练可视化监控的完整示例,以 MNIST 手写数字识别为例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import swanlab

# 初始化 SwanLab 实验 (自动生成仪表盘)
swanlab.init(
    experiment_name="MNIST_CNN",
    description="Simple CNN on MNIST with SwanLab monitoring",
    config={
   
        "batch_size": 64,
        "epochs": 10,
        "learning_rate": 0.01,
        "model": "CNN"
    }
)

# 1. 数据准备
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=swanlab.config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False

你可能感兴趣的:(pytorch,深度学习,人工智能)