如何使用 PyTorch Lightning 保存和加载检查点

【PL 基础】如何保存和加载检查点

  • 1. 什么是检查点(checkpoint)?
  • 2. 检查点有哪些内容
  • 3. 如何保存检查点
  • 4. 如何加载检查点
  • 5. 保存超参数
  • 6. nn.Module 的检查点
  • 7. 禁用 checkpointing
  • 8. 恢复训练状态

1. 什么是检查点(checkpoint)?

  当模型进行训练时,性能会随着它继续看到更多数据而发生变化。最佳实践是在整个训练过程中保存模型的状态。这将在模型开发过程中的每个关键点为您提供一个模型版本,即一个检查点。训练完成后,使用与您在训练过程中发现的最佳性能相对应的检查点。

  检查点还使您的训练能够在训练过程中断的情况下从原来的位置恢复。

  PyTorch Lightning 检查点在普通 PyTorch 中完全可用。

2. 检查点有哪些内容

  Lightning 检查点包含模型整个内部状态的转储。与普通的 PyTorch 不同,Lightning 可以保存恢复模型所需的一切,即使在最复杂的分布式训练环境中也是如此。

  • 16-bit scaling factor (if using 16-bit precision training) 16 位缩放因子(如果使用 16 位精度训练)
  • Current epoch 当前代
  • Global step 全局运行步数
  • LightningModule’s state_dict 模型参数
  • State of all optimizers 所有优化器的状态
  • State of all learning rate schedulers 所有学习率调度器的状态
  • State of all callbacks (for stateful callbacks) 所有回调的状态(用于有状态回调)
  • State of datamodule (for stateful datamodules) DataModule 的状态(用于有状态数据模块)
  • The hyperparameters (init arguments) with which the model was created
    用于创建模型的超参数 (init 参数)
  • The hyperparameters (init arguments) with which the datamodule was created
    用于创建 datamodule 的超参数 (init 参数)
  • State of Loops 循环状态

3. 如何保存检查点

  Lightning 会自动在您当前工作目录中为您保存一个检查点,其中包含您上一个训练 epoch 的状态。这可确保在训练中断时可以恢复训练。

# simply by using the Trainer you get automatic checkpointing
trainer = Trainer()

要更改检查点路径,请使用 default_root_dir 参数:

# saves checkpoints to 'some/path/' at every epoch end
trainer = Trainer(default_root_dir="some/path/")

4. 如何加载检查点

  要加载 LightningModule 及其权重和超参数,请使用以下方法:

model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")

# disable randomness, dropout, etc...
model.eval()

# predict with the model
y_hat = model(x)

5. 保存超参数

   LightningModule 允许您通过调用 self.save_hyperparameters() 自动保存传递给 init 的所有超参数。

class MyLightningModule(LightningModule):
    def __init__(self, learning_rate, another_parameter, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

超参数将保存到检查点中的 hyper_parameters 键中

checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}

LightningModule 还可以访问 Hyperparameters

model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.learning_rate)

6. nn.Module 的检查点

   Lightning 检查点与普通torch nn.Modules 完全兼容。

checkpoint = torch.load(CKPT_PATH)
print(checkpoint.keys())

例如,假设创建了一个 LightningModule,如下所示:

class Encoder(nn.Module):
    ...


class Decoder(nn.Module):
    ...


class Autoencoder(L.LightningModule):
    def __init__(self, encoder, decoder, *args, **kwargs):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder


autoencoder = Autoencoder(Encoder(), Decoder())

自动编码器训练完成后,通过以下代码为 torch nn.Modules 提取相关的权重:

checkpoint = torch.load(CKPT_PATH)
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}

7. 禁用 checkpointing

  通过传递以下命令来禁用检查点:

trainer = Trainer(enable_checkpointing=False)

8. 恢复训练状态

  如果不仅想加载权重,还想恢复完整训练,请执行以下命令:

model = LitModel()
trainer = Trainer()

# automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, ckpt_path="path/to/your/checkpoint.ckpt")

你可能感兴趣的:(PyTorch,Lightning,pytorch,人工智能,Lighting)