【AI大模型】PyTorch Lightning 简化工具

PyTorch Lightning 是一个轻量级的 PyTorch 封装库,它通过抽象训练循环的工程细节,让研究人员可以专注于模型设计和实验。以下是 PyTorch Lightning 的核心概念和实战指南。

核心优势

【AI大模型】PyTorch Lightning 简化工具_第1张图片

基础使用:三步搭建训练流程

1. 定义 LightningModule

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchmetrics import Accuracy

class SimpleClassifier(pl.LightningModule):
    def __init__(self, input_size=28*28, hidden_size=128, num_classes=10, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()  # 保存超参数
        
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes)
        )
        
        self.loss_fn = nn.CrossEntropyLoss()
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = self.accuracy(logits, y)
        
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = self.accuracy(logits, y)
        
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        self.log("test_acc", acc)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

2. 准备数据

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, random_split

# 数据集
dataset = MNIST(root="data", download=True, transform=ToTensor())
train_ds, val_ds = random_split(dataset, [55000, 5000])
test_ds = MNIST(root="data", train=False, transform=ToTensor())

# DataLoader
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=32, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=32, num_workers=4)

3. 训练模型

# 创建模型
model = SimpleClassifier(lr=1e-3)

# 创建训练器
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",  # 自动选择GPU/CPU
    devices="auto",      # 使用所有可用设备
    logger=True,         # 内置TensorBoard日志
    deterministic=True,  # 确保可复现性
    enable_progress_bar=True,
)

# 训练和验证
trainer.fit(model, train_loader, val_loader)

# 测试
trainer.test(model, dataloaders=test_loader)

核心功能详解

1. 自动设备管理

trainer = pl.Trainer(
    accelerator="gpu",  # 使用GPU
    devices=2,          # 使用2个GPU
    strategy="ddp",     # 分布式数据并行
)

2. 高级训练控制

trainer = pl.Trainer(
    max_epochs=10,
    min_epochs=3,
    max_steps=1000,     # 最大训练步数
    gradient_clip_val=0.5,  # 梯度裁剪
    precision="16-mixed",   # 混合精度训练
    accumulate_grad_batches=4,  # 梯度累积
)

3. 内置回调系统

from pytorch_lightning.callbacks import (
    EarlyStopping, 
    ModelCheckpoint,
    LearningRateMonitor
)

callbacks = [
    EarlyStopping(monitor="val_loss", patience=3),
    ModelCheckpoint(
        dirpath="checkpoints",
        filename="best-model-{epoch:02d}-{val_loss:.2f}",
        monitor="val_loss",
        save_top_k=3,
    ),
    LearningRateMonitor(logging_interval="step"),
]

trainer = pl.Trainer(callbacks=callbacks)

高级功能实战

1. 学习率调度器

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode="min", 
        factor=0.1, 
        patience=2
    )
    
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val_loss",  # 监控验证损失
            "frequency": 1           # 每个epoch检查一次
        }
    }

2. 多优化器支持

def configure_optimizers(self):
    # 生成器优化器
    gen_opt = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
    
    # 判别器优化器
    disc_opt = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
    
    return [gen_opt, disc_opt], []  # 多个优化器,无调度器

3. 自定义训练步骤

def training_step(self, batch, batch_idx, optimizer_idx):
    x, y = batch
    
    if optimizer_idx == 0:  # 生成器训练
        # 生成假样本
        fake = self.generator(x)
        # 判别器输出
        d_out = self.discriminator(fake)
        # 生成器损失
        loss = self.generator_loss(d_out)
        self.log("gen_loss", loss)
        return loss
        
    if optimizer_idx == 1:  # 判别器训练
        # 真实样本判别
        real_out = self.discriminator(y)
        # 假样本判别
        fake = self.generator(x).detach()
        fake_out = self.discriminator(fake)
        # 判别器损失
        loss = self.discriminator_loss(real_out, fake_out)
        self.log("disc_loss", loss)
        return loss

调试与优化技巧

1. 快速开发调试

trainer = pl.Trainer(
    fast_dev_run=True,    # 只运行一个batch
    overfit_batches=10,   # 在小批次上过拟合
    limit_train_batches=0.1,  # 只使用10%的训练数据
)

2. 性能分析

trainer = pl.Trainer(
    profiler="simple",  # 简单性能分析
    # profiler="pytorch",  # 高级PyTorch分析器
)

3. 梯度监控

def on_after_backward(self):
    # 记录梯度范数
    total_norm = 0
    for p in self.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5
    self.log("grad_norm", total_norm)

部署与生产化

1. 模型保存与加载

# 保存完整模型
trainer.save_checkpoint("model.ckpt")

# 加载模型
model = SimpleClassifier.load_from_checkpoint("model.ckpt")

# 导出为TorchScript
script = model.to_torchscript()
torch.jit.save(script, "model.pt")

2. 生产环境推理

class InferenceModel(pl.LightningModule):
    def __init__(self, model_path):
        super().__init__()
        self.model = SimpleClassifier.load_from_checkpoint(model_path)
        self.model.eval()
    
    def predict_step(self, batch, batch_idx):
        x, _ = batch
        return torch.argmax(self.model(x), dim=1)

# 创建预测器
predictor = InferenceModel("model.ckpt")
trainer = pl.Trainer(accelerator="auto")
predictions = trainer.predict(predictor, test_loader)

与原生 PyTorch 对比

功能 原生 PyTorch PyTorch Lightning
训练循环 手动实现 自动处理
设备管理 手动处理设备 自动处理
分布式训练 复杂实现 一行代码
混合精度 手动配置 参数设置
日志记录 手动实现 内置支持
模型检查点 手动保存 自动回调
超参数记录 手动记录 自动保存
代码复用

最佳实践

  1. 模块化设计

    class Encoder(nn.Module):
        ...
    
    class Decoder(nn.Module):
        ...
    
    class LitModel(pl.LightningModule):
        def __init__(self):
            self.encoder = Encoder()
            self.decoder = Decoder()

  2. 参数化配置

    class LitModel(pl.LightningModule):
        def __init__(self, learning_rate=1e-3, hidden_size=128):
            self.save_hyperparameters()
            ...

  3. 使用 LightningDataModule

    class MNISTDataModule(pl.LightningDataModule):
        def __init__(self, batch_size=32):
            super().__init__()
            self.batch_size = batch_size
        
        def setup(self, stage=None):
            self.train_ds = ...
            self.val_ds = ...
        
        def train_dataloader(self):
            return DataLoader(self.train_ds, batch_size=self.batch_size)
        
        def val_dataloader(self):
            return DataLoader(self.val_ds, batch_size=self.batch_size)
    
    # 使用
    dm = MNISTDataModule()
    trainer.fit(model, dm)

  4. 充分利用回调

    trainer = pl.Trainer(callbacks=[
        pl.callbacks.StochasticWeightAveraging(swa_lrs=1e-2),
        pl.callbacks.RichProgressBar(),
        pl.callbacks.DeviceStatsMonitor()
    ])

PyTorch Lightning 通过标准化训练流程,减少了重复代码,同时保持了 PyTorch 的灵活性。它特别适合需要快速迭代实验的研究场景,以及需要可复现、可扩展的生产环境。

你可能感兴趣的:(学习AI记录,人工智能,pytorch,python,ai,AI编程)