关于模型训练中的参数保存与argparse库

argparse库,用于编写命令行接口

ArgumentParser object, 传递命令行参数

对象:parser = argparse.ArgumentParser(description='ImgCap2Embedding')

This description will be displayed when the user requests help (e.g., by running python step02_get_training_data.py --help).

parser.add_argument(...)

方法1:parser.add_argument('--subj', dest='subj', help='', type=str, default='subj01')

The -- indicates it's an optional argument.

value will be stored in the object returned by parse_args(). So, after parsing, you'll access this value as args.subj.

方法2:parser.add_argument('--model_type', dest='model_type', help='', type=str, default="ViT-B/32", choices=["ResNet", "ViT-B/32", "RN50x4"])

  • choices=["ResNet", "ViT-B/32", "RN50x4"]: This is very useful. It restricts the acceptable values for --model_type to only those in the list. If the user provides a value not in this list, argparse will automatically show an error message.

方法3:parser.add_argument('--mode', type=str, default='img_extractor', choices=["img_extractor", "text_extractor", "debug"])

  • Here, dest is not explicitly specified. argparse will automatically set dest='mode' based on the option name '--mode'.

defining various configurable parameters for your script like the subject, model, data directory, processing level, etc.

模型训练,保存模型配置

# 保存模型检查点
def save_ckpt(tag):
    """
    保存模型检查点到指定路径。

    参数:
        tag (str): 检查点标签(如 'best' 或 'epoch-X')。
    """
    if not os.path.exists(f'{args.checkpoint}/{args.subj}/{args.level}'):
        os.makedirs(f'{args.checkpoint}/{args.subj}/{args.level}')  # 创建目录
    ckpt_path = f'{args.checkpoint}/{args.subj}/{args.level}/{tag}.pth'
    print(f'saving {ckpt_path}', flush=True)
    #"./my_checkpoints/subj01/al/best.pth"
    try:
        torch.save({
            'epoch': epoch,
            'model_state_dict': BrainMLP.state_dict(),
            '''
            BrainMLP.state_dict() 返回一个字典,其中包含了模型所有可学习的参数(即权重 weights 和偏置 biases)。
            用途:这是保存模型最核心的部分,它存储了模型通过训练学到的所有知识。加载这个状态字典可以将模型的参数恢复到保存时的状态。
            '''
            'optimizer_state_dict': optimizer.state_dict(),
            '''
            Adam优化器,这可能包括每个参数的一阶矩估计(exp_avg)和二阶矩估计(exp_avg_sq),以及迭代步数等。
            用途:在恢复训练时,加载优化器的状态可以确保训练从完全相同的优化状态继续,而不是重新初始化优化器。
            这对于一些自适应学习率的优化器(如Adam)尤为重要
            '''
            'lr_scheduler': lr_scheduler.state_dict(),
            '''
            含义:这是学习率调度器 (lr_scheduler,例如 torch.optim.lr_scheduler.StepLR, 
            torch.optim.lr_scheduler.ReduceLROnPlateau 等) 的状态字典。
            torch.optim.lr_scheduler.state_dict() 保存了调度器的内部状态,
            比如上一次调整学习率的轮次、当前的计数器值(对于某些调度器)、最佳度量值(如ReduceLROnPlateau)等。
            用途:确保在恢复训练时,学习率能够按照预定的策略继续调整。
            '''
            'train_losses': losses,#训练损失值的列表或历史记录每个epoch的平均训练损失,或者是每个batch的训练损失。
            'val_losses': val_losses,#评估模型在未见过数据上的泛化能力,监控是否发生过拟合,并且常常用于选择“最佳”模型检查点(例如,当验证损失达到最小时)。
            'lrs': lrs,#训练过程中每个轮次或每个优化步骤所使用的学习率的历史记录。
        }, ckpt_path)
    except Exception as e:
        print(f"Couldn't save: {e}... moving on to prevent crashing.")
# --- 这是一个 .pth 文件被 torch.load() 加载后,你得到的 Python 字典内容的示例 ---
# 文件路径示例 (基于默认 args.checkpoint='./ckpt'): ./ckpt/subj07/low/epoch-1.pth

saved_checkpoint_data = {

    'epoch': 1,  # 当前是第 1 个 epoch 结束时 (即已完成2个epoch的训练)

    'model_state_dict': {
        # BrainMLP 模型的参数 (基于默认配置 in_dim=3491, hidden=2048, out_dim=640, n_blocks=4)

        # self.lin0 (nn.Sequential)
        'lin0.0.weight': "torch.Tensor object, shape=(2048, 3491)", # nn.Linear(3491, 2048) 的权重
        'lin0.0.bias':   "torch.Tensor object, shape=(2048,)",      # nn.Linear(3491, 2048) 的偏置
        'lin0.1.weight': "torch.Tensor object, shape=(2048,)",      # nn.LayerNorm(2048) 的权重 (gamma)
        'lin0.1.bias':   "torch.Tensor object, shape=(2048,)",      # nn.LayerNorm(2048) 的偏置 (beta)
        # lin0.2 (nn.GELU) 和 lin0.3 (nn.Dropout) 没有可学习参数

        # self.mlp (nn.ModuleList with 4 blocks)
        # Block 0:
        'mlp.0.0.weight': "torch.Tensor object, shape=(2048, 2048)", # mlp[0].Linear 的权重
        'mlp.0.0.bias':   "torch.Tensor object, shape=(2048,)",      # mlp[0].Linear 的偏置
        'mlp.0.1.weight': "torch.Tensor object, shape=(2048,)",      # mlp[0].LayerNorm 的权重
        'mlp.0.1.bias':   "torch.Tensor object, shape=(2048,)",      # mlp[0].LayerNorm 的偏置
        # Block 1:
        'mlp.1.0.weight': "torch.Tensor object, shape=(2048, 2048)",
        'mlp.1.0.bias':   "torch.Tensor object, shape=(2048,)",
        'mlp.1.1.weight': "torch.Tensor object, shape=(2048,)",
        'mlp.1.1.bias':   "torch.Tensor object, shape=(2048,)",
        # Block 2:
        'mlp.2.0.weight': "torch.Tensor object, shape=(2048, 2048)",
        # ... (以此类推,直到 mlp.3) ...
        'mlp.3.1.bias':   "torch.Tensor object, shape=(2048,)",

        # self.lin1
        'lin1.weight': "torch.Tensor object, shape=(640, 2048)", # nn.Linear(2048, 640) 的权重
        'lin1.bias':   "torch.Tensor object, shape=(640,)"       # nn.Linear(2048, 640) 的偏置

        # 由于 use_projector=False 和 use_linear=False,projector 和 lin_proj 的参数不会在这里
    },

    'optimizer_state_dict': {
        # AdamW 优化器的状态
        'state': {
            # 键是参数的 ID (整数)。假设模型有 N 个独立的参数张量,就会有 0 到 N-1 个条目。
            # 例如,对应 'lin0.0.weight' 参数的状态:
            0: {
                'step': "torch.Tensor object (scalar, e.g., tensor(200.))", # 总共的优化步数 (2 epochs * 100 batches/epoch)
                'exp_avg': "torch.Tensor object, shape=(2048, 3491)",     # AdamW 的一阶矩估计
                'exp_avg_sq': "torch.Tensor object, shape=(2048, 3491)"   # AdamW 的二阶矩估计
            },
            # 例如,对应 'lin0.0.bias' 参数的状态:
            1: {
                'step': "torch.Tensor object (scalar, e.g., tensor(200.))",
                'exp_avg': "torch.Tensor object, shape=(2048,)",
                'exp_avg_sq': "torch.Tensor object, shape=(2048,)"
            },
            # ... 其他所有模型参数对应的优化器状态
        },
        'param_groups': [
            {
                'lr': 0.000000, # LinearLR 在2个epoch结束时,学习率可能已经降到很低或为0 (取决于 total_iters 设置)
                                # 假设 total_iters = 2 * 100 = 200。在第200步,线性衰减可能接近结束。
                                # 如果 total_iters 设置不当或 epochs 很少,这里可能是 0 或一个非常小的值。
                                # 如果 LinearLR 配置为从 max_lr 衰减到 0, 此处是 0。
                'betas': (0.9, 0.999),
                'eps': 1e-08,
                'weight_decay': 0.01, # AdamW 默认的 weight_decay 是 0.01,但代码中没有显式设置,用的是 lr=args.max_lr 初始化。
                                      # PyTorch AdamW 默认 weight_decay 是 0.01。
                                      # filter(lambda p: p.requires_grad, BrainMLP.parameters())
                                      # torch.optim.AdamW(params, lr=args.max_lr) -> 这里的 lr 是 3e-4。
                                      # 所以,lr 会从 3e-4 开始,根据 LinearLR 衰减。
                                      # 假设 LinearLR 的 start_factor=1.0, end_factor=0.0, total_iters=200
                                      # 那么在第200步 (epoch=1结束时),lr 应该是 0.0 * args.max_lr = 0.0
                'amsgrad': False,
                'maximize': False,
                # ... 其他 AdamW 参数组的默认或设置值
                'params': [0, 1, 2, ..., N-1] # 该参数组包含的所有参数的 ID 列表
            }
        ]
    },

    'lr_scheduler': { # torch.optim.lr_scheduler.LinearLR 的状态
        'last_epoch': 200, # LinearLR 的 last_epoch 是按 step 计数的, (2 epochs * 100 batches/epoch)
        'total_iters': 200, # (args.epochs * (len(train_ds) // args.batch_size))
        'start_factor': 1.0/3.0, # LinearLR 默认 start_factor 是 1/3,但代码中没有设置,
                                 # torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=...)
                                 # 如果没指定 start_factor 和 end_factor,它们有默认值。
                                 # PyTorch 1.13+ LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=...)
                                 # 假设使用默认 start_factor=1.0, end_factor=0.0
        'end_factor': 0.0,
        '_last_lr': [0.0] # 最后一个计算出的学习率 (0.0 * args.max_lr)
    },

    'train_losses': [
        # 这是一个很长的列表,包含了两个 epoch 中所有训练 batch 的损失值。
        # (2 epochs * 100 batches/epoch = 200 个损失值)
        # 例如:
        0.85, 0.83, ..., 0.60, # 第一个 epoch 的 batch 损失
        0.58, 0.57, ..., 0.35  # 第二个 epoch 的 batch 损失
    ],

    'val_losses': [
        # 这也是一个列表,包含了两个 epoch 中所有验证 batch 的损失值。
        # (2 epochs * 20 batches/epoch = 40 个损失值)
        # 例如:
        0.75, 0.73, ..., 0.55, # 第一个 epoch 的验证 batch 损失
        0.53, 0.52, ..., 0.40  # 第二个 epoch 的验证 batch 损失
    ],

    'lrs': [
        # 这是一个很长的列表,包含了两个 epoch 中所有训练 batch 计算前的学习率。
        # (200 个学习率值)
        # 对于 LinearLR 从 3e-4 衰减到 0,经过 200步:
        3.000e-4, 2.985e-4, 2.970e-4, ..., # 第一个 epoch 的学习率
        ..., 0.0015e-4, 0.0000e-4          # 第二个 epoch 结束时的学习率,最后为0
    ]
}
# --- 示例结束 ---

这两个字典——'optimizer_state_dict''lr_scheduler' (实际上是 lr_scheduler.state_dict())——对于实现可中断和可恢复的训练过程至关重要。它们能确保当你从一个保存的检查点(checkpoint)恢复训练时,训练能够从它离开时的确切状态无缝地继续下去,而不仅仅是加载模型的权重。

1.对于 Adam/AdamW (你的代码中使用的是AdamW),这些关键状态包括:

  • step (针对每个参数):记录了该参数被更新了多少步。这对于Adam中的偏差修正(bias correction)很重要,确保在训练早期估计的准确性。
  • exp_avg (一阶矩估计,针对每个参数):梯度的指数移动平均值,可以看作是“动量”项,它帮助加速在一致梯度方向上的学习并抑制震荡。
  • exp_avg_sq (二阶矩估计,针对每个参数):梯度平方的指数移动平均值,用于自适应地调整每个参数的学习率梯度变化剧烈的参数学习率会小一些,变化平稳的参数学习率会大一些。

'param_groups' 部分

  • lr (学习率):记录了保存时该参数组的当前学习率。这很重要,因为学习率会通过调度器动态变化。
  • betas, eps, weight_decay: 这些是优化器(如AdamW)的核心超参数。保存它们是为了确保当你加载检查点时,优化器的配置与保存时完全一致。
  • params: 包含属于该参数组的模型参数的ID列表

当你从检查点加载模型以继续训练时,你需要先初始化你的优化器,然后加载这个状态字典:

# 假设 BrainMLP 是你的模型实例,args.max_lr 是初始学习率
optimizer = torch.optim.AdamW(BrainMLP.parameters(), lr=args.max_lr)

# 加载检查点
checkpoint = torch.load(ckpt_path)
BrainMLP.load_state_dict(checkpoint['model_state_dict'])

if 'optimizer_state_dict' in checkpoint:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("优化器状态已加载!")

2.学习率调度器用于在训练过程中动态地调整学习率(例如,逐渐降低学习率,或者在特定轮次降低)。调度器本身也需要维护状态。

  • 对于 LinearLR (你的代码中默认使用的):
    • last_epoch: 这通常是一个内部计数器,记录了调度器的 step() 方法被调用了多少次。在你的代码中,lr_scheduler.step() 是在每个训练批次后调用的,所以 last_epoch 实际上是已处理的批次数 (steps)
    • total_iters: 定义了学习率从开始到结束变化所需的总步数。
    • start_factor, end_factor: 定义了初始学习率和最终学习率相对于优化器中基础学习率的比例。
    • _last_lr: 包含了上一次 step() 调用后计算出的学习率列表(对应每个参数组)。
# ... BrainMLP 和 optimizer 已经初始化并加载了状态 ...

# 初始化学习率调度器
if args.lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, # 关键:使用已经加载了状态的优化器
        total_iters=int(args.epochs * (len(train_ds) // args.batch_size)),
        # start_factor 和 end_factor 等应与初次创建时一致
    )
# elif ... 其他类型的调度器

if 'lr_scheduler' in checkpoint: # 或者你的键名是 'lr_scheduler_state_dict'
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    print("学习率调度器状态已加载!")

注意顺序:通常建议先加载优化器的状态,然后再创建和加载学习率调度器的状态,因为调度器需要关联一个优化器。

保存和加载 optimizer_state_dictlr_scheduler (的 state_dict) 的核心价值在于实现真正的“断点续训”。这使得你的训练过程可以被中断(无论是手动停止、意外崩溃还是为了在不同机器上继续),之后能从完全相同的点恢复,保证了训练过程的一致性和可复现性,避免了因状态重置可能引入的训练行为变化。这对于长时间的训练任务尤其重要。

你可能感兴趣的:(人工智能,算法,深度学习,动态规划,python,机器学习,pycharm)