接续训练调优:model.load_state_dict不会自动加载优化器参数,它仅加载模型的权重参数。要完整恢复训练状态,需要单独处理优化器和学习率调度器。

注意:model.load_state_dict(torch.load(self.config.ckpt_path), strict=False) 不会自动加载优化器和调度器参数,所以会导致接续训练时候loss为什么不是接上的,而是再次训练时候会上升一点点。原因是:学习率和优化器重置了,没有保存之前的状态。

1.标准检查点应包含的内容

完整的训练检查点(checkpoint)通常包含以下键值:

{
    'epoch': 100,                      # 训练轮次
    'model_state_dict': model.state_dict(),      # 模型参数
    'optimizer_state_dict': optimizer.state_dict(),  # 优化器状态
    'scheduler_state_dict': scheduler.state_dict(),  # 学习率调度器状态
    'loss': 0.1234                     # 当前损失值
}

2. 正确加载方式

(1) 仅加载模型权重(当前做法)
model.load_state_dict(torch.load(self.config.ckpt_path)['model_state_dict'], strict=False)
  • ✅ 加载模型参数

  • ❌ 不恢复优化器/训练进度

(2) 完整恢复训练状态
checkpoint = torch.load(self.config.ckpt_path)

# 加载模型
model.load_state_dict(checkpoint['model_state_dict'], strict=False)  

# 加载优化器(需提前定义同类型优化器)
optimizer = AdamW(model.parameters(), lr=0.001)  # 参数需与原始训练一致
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# 加载学习率调度器(如果存在)
if 'scheduler_state_dict' in checkpoint:
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

3. 关键注意事项

问题场景 后果 解决方案
优化器未加载 学习率重置,动量缓存丢失 必须显式加载optimizer_state_dict
优化器类型不匹配 加载失败或参数错位 确保优化器超参数(如lr)与保存时一致
GPU/CPU设备切换 张量设备冲突 添加map_location参数:
torch.load(..., map_location='cuda:0')

你可能感兴趣的:(AI相关,学习,深度学习,人工智能)