当模型进行训练时,性能会随着它继续看到更多数据而发生变化。最佳实践是在整个训练过程中保存模型的状态。这将在模型开发过程中的每个关键点为您提供一个模型版本,即一个检查点。训练完成后,使用与您在训练过程中发现的最佳性能相对应的检查点。
检查点还使您的训练能够在训练过程中断的情况下从原来的位置恢复。
PyTorch Lightning 检查点在普通 PyTorch 中完全可用。
Lightning 检查点包含模型整个内部状态的转储。与普通的 PyTorch 不同,Lightning 可以保存恢复模型所需的一切,即使在最复杂的分布式训练环境中也是如此。
Lightning 会自动在您当前工作目录中为您保存一个检查点,其中包含您上一个训练 epoch 的状态。这可确保在训练中断时可以恢复训练。
# simply by using the Trainer you get automatic checkpointing
trainer = Trainer()
要更改检查点路径,请使用 default_root_dir
参数:
# saves checkpoints to 'some/path/' at every epoch end
trainer = Trainer(default_root_dir="some/path/")
要加载 LightningModule
及其权重和超参数,请使用以下方法:
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
# disable randomness, dropout, etc...
model.eval()
# predict with the model
y_hat = model(x)
LightningModule
允许您通过调用 self.save_hyperparameters()
自动保存传递给 init
的所有超参数。
class MyLightningModule(LightningModule):
def __init__(self, learning_rate, another_parameter, *args, **kwargs):
super().__init__()
self.save_hyperparameters()
超参数将保存到检查点中的 hyper_parameters
键中
checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}
LightningModule
还可以访问 Hyperparameters
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.learning_rate)
Lightning
检查点与普通torch nn.Modules
完全兼容。
checkpoint = torch.load(CKPT_PATH)
print(checkpoint.keys())
例如,假设创建了一个 LightningModule
,如下所示:
class Encoder(nn.Module):
...
class Decoder(nn.Module):
...
class Autoencoder(L.LightningModule):
def __init__(self, encoder, decoder, *args, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
autoencoder = Autoencoder(Encoder(), Decoder())
自动编码器训练完成后,通过以下代码为 torch nn.Modules
提取相关的权重:
checkpoint = torch.load(CKPT_PATH)
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}
通过传递以下命令来禁用检查点:
trainer = Trainer(enable_checkpointing=False)
如果不仅想加载权重,还想恢复完整训练,请执行以下命令:
model = LitModel()
trainer = Trainer()
# automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, ckpt_path="path/to/your/checkpoint.ckpt")