ArgumentParser
object, 传递命令行参数
对象:parser = argparse.ArgumentParser(description='ImgCap2Embedding')
This description will be displayed when the user requests help (e.g., by running python step02_get_training_data.py --help
).
parser.add_argument(...)
方法1:parser.add_argument('--subj', dest='subj', help='', type=str, default='subj01')
The --
indicates it's an optional argument.
value will be stored in the object returned by parse_args()
. So, after parsing, you'll access this value as args.subj
.
方法2:parser.add_argument('--model_type', dest='model_type', help='', type=str, default="ViT-B/32", choices=["ResNet", "ViT-B/32", "RN50x4"])
choices=["ResNet", "ViT-B/32", "RN50x4"]
: This is very useful. It restricts the acceptable values for --model_type
to only those in the list. If the user provides a value not in this list, argparse
will automatically show an error message.方法3:parser.add_argument('--mode', type=str, default='img_extractor', choices=["img_extractor", "text_extractor", "debug"])
dest
is not explicitly specified. argparse
will automatically set dest='mode'
based on the option name '--mode'
.defining various configurable parameters for your script like the subject, model, data directory, processing level, etc.
# 保存模型检查点
def save_ckpt(tag):
"""
保存模型检查点到指定路径。
参数:
tag (str): 检查点标签(如 'best' 或 'epoch-X')。
"""
if not os.path.exists(f'{args.checkpoint}/{args.subj}/{args.level}'):
os.makedirs(f'{args.checkpoint}/{args.subj}/{args.level}') # 创建目录
ckpt_path = f'{args.checkpoint}/{args.subj}/{args.level}/{tag}.pth'
print(f'saving {ckpt_path}', flush=True)
#"./my_checkpoints/subj01/al/best.pth"
try:
torch.save({
'epoch': epoch,
'model_state_dict': BrainMLP.state_dict(),
'''
BrainMLP.state_dict() 返回一个字典,其中包含了模型所有可学习的参数(即权重 weights 和偏置 biases)。
用途:这是保存模型最核心的部分,它存储了模型通过训练学到的所有知识。加载这个状态字典可以将模型的参数恢复到保存时的状态。
'''
'optimizer_state_dict': optimizer.state_dict(),
'''
Adam优化器,这可能包括每个参数的一阶矩估计(exp_avg)和二阶矩估计(exp_avg_sq),以及迭代步数等。
用途:在恢复训练时,加载优化器的状态可以确保训练从完全相同的优化状态继续,而不是重新初始化优化器。
这对于一些自适应学习率的优化器(如Adam)尤为重要
'''
'lr_scheduler': lr_scheduler.state_dict(),
'''
含义:这是学习率调度器 (lr_scheduler,例如 torch.optim.lr_scheduler.StepLR,
torch.optim.lr_scheduler.ReduceLROnPlateau 等) 的状态字典。
torch.optim.lr_scheduler.state_dict() 保存了调度器的内部状态,
比如上一次调整学习率的轮次、当前的计数器值(对于某些调度器)、最佳度量值(如ReduceLROnPlateau)等。
用途:确保在恢复训练时,学习率能够按照预定的策略继续调整。
'''
'train_losses': losses,#训练损失值的列表或历史记录每个epoch的平均训练损失,或者是每个batch的训练损失。
'val_losses': val_losses,#评估模型在未见过数据上的泛化能力,监控是否发生过拟合,并且常常用于选择“最佳”模型检查点(例如,当验证损失达到最小时)。
'lrs': lrs,#训练过程中每个轮次或每个优化步骤所使用的学习率的历史记录。
}, ckpt_path)
except Exception as e:
print(f"Couldn't save: {e}... moving on to prevent crashing.")
# --- 这是一个 .pth 文件被 torch.load() 加载后,你得到的 Python 字典内容的示例 ---
# 文件路径示例 (基于默认 args.checkpoint='./ckpt'): ./ckpt/subj07/low/epoch-1.pth
saved_checkpoint_data = {
'epoch': 1, # 当前是第 1 个 epoch 结束时 (即已完成2个epoch的训练)
'model_state_dict': {
# BrainMLP 模型的参数 (基于默认配置 in_dim=3491, hidden=2048, out_dim=640, n_blocks=4)
# self.lin0 (nn.Sequential)
'lin0.0.weight': "torch.Tensor object, shape=(2048, 3491)", # nn.Linear(3491, 2048) 的权重
'lin0.0.bias': "torch.Tensor object, shape=(2048,)", # nn.Linear(3491, 2048) 的偏置
'lin0.1.weight': "torch.Tensor object, shape=(2048,)", # nn.LayerNorm(2048) 的权重 (gamma)
'lin0.1.bias': "torch.Tensor object, shape=(2048,)", # nn.LayerNorm(2048) 的偏置 (beta)
# lin0.2 (nn.GELU) 和 lin0.3 (nn.Dropout) 没有可学习参数
# self.mlp (nn.ModuleList with 4 blocks)
# Block 0:
'mlp.0.0.weight': "torch.Tensor object, shape=(2048, 2048)", # mlp[0].Linear 的权重
'mlp.0.0.bias': "torch.Tensor object, shape=(2048,)", # mlp[0].Linear 的偏置
'mlp.0.1.weight': "torch.Tensor object, shape=(2048,)", # mlp[0].LayerNorm 的权重
'mlp.0.1.bias': "torch.Tensor object, shape=(2048,)", # mlp[0].LayerNorm 的偏置
# Block 1:
'mlp.1.0.weight': "torch.Tensor object, shape=(2048, 2048)",
'mlp.1.0.bias': "torch.Tensor object, shape=(2048,)",
'mlp.1.1.weight': "torch.Tensor object, shape=(2048,)",
'mlp.1.1.bias': "torch.Tensor object, shape=(2048,)",
# Block 2:
'mlp.2.0.weight': "torch.Tensor object, shape=(2048, 2048)",
# ... (以此类推,直到 mlp.3) ...
'mlp.3.1.bias': "torch.Tensor object, shape=(2048,)",
# self.lin1
'lin1.weight': "torch.Tensor object, shape=(640, 2048)", # nn.Linear(2048, 640) 的权重
'lin1.bias': "torch.Tensor object, shape=(640,)" # nn.Linear(2048, 640) 的偏置
# 由于 use_projector=False 和 use_linear=False,projector 和 lin_proj 的参数不会在这里
},
'optimizer_state_dict': {
# AdamW 优化器的状态
'state': {
# 键是参数的 ID (整数)。假设模型有 N 个独立的参数张量,就会有 0 到 N-1 个条目。
# 例如,对应 'lin0.0.weight' 参数的状态:
0: {
'step': "torch.Tensor object (scalar, e.g., tensor(200.))", # 总共的优化步数 (2 epochs * 100 batches/epoch)
'exp_avg': "torch.Tensor object, shape=(2048, 3491)", # AdamW 的一阶矩估计
'exp_avg_sq': "torch.Tensor object, shape=(2048, 3491)" # AdamW 的二阶矩估计
},
# 例如,对应 'lin0.0.bias' 参数的状态:
1: {
'step': "torch.Tensor object (scalar, e.g., tensor(200.))",
'exp_avg': "torch.Tensor object, shape=(2048,)",
'exp_avg_sq': "torch.Tensor object, shape=(2048,)"
},
# ... 其他所有模型参数对应的优化器状态
},
'param_groups': [
{
'lr': 0.000000, # LinearLR 在2个epoch结束时,学习率可能已经降到很低或为0 (取决于 total_iters 设置)
# 假设 total_iters = 2 * 100 = 200。在第200步,线性衰减可能接近结束。
# 如果 total_iters 设置不当或 epochs 很少,这里可能是 0 或一个非常小的值。
# 如果 LinearLR 配置为从 max_lr 衰减到 0, 此处是 0。
'betas': (0.9, 0.999),
'eps': 1e-08,
'weight_decay': 0.01, # AdamW 默认的 weight_decay 是 0.01,但代码中没有显式设置,用的是 lr=args.max_lr 初始化。
# PyTorch AdamW 默认 weight_decay 是 0.01。
# filter(lambda p: p.requires_grad, BrainMLP.parameters())
# torch.optim.AdamW(params, lr=args.max_lr) -> 这里的 lr 是 3e-4。
# 所以,lr 会从 3e-4 开始,根据 LinearLR 衰减。
# 假设 LinearLR 的 start_factor=1.0, end_factor=0.0, total_iters=200
# 那么在第200步 (epoch=1结束时),lr 应该是 0.0 * args.max_lr = 0.0
'amsgrad': False,
'maximize': False,
# ... 其他 AdamW 参数组的默认或设置值
'params': [0, 1, 2, ..., N-1] # 该参数组包含的所有参数的 ID 列表
}
]
},
'lr_scheduler': { # torch.optim.lr_scheduler.LinearLR 的状态
'last_epoch': 200, # LinearLR 的 last_epoch 是按 step 计数的, (2 epochs * 100 batches/epoch)
'total_iters': 200, # (args.epochs * (len(train_ds) // args.batch_size))
'start_factor': 1.0/3.0, # LinearLR 默认 start_factor 是 1/3,但代码中没有设置,
# torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=...)
# 如果没指定 start_factor 和 end_factor,它们有默认值。
# PyTorch 1.13+ LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=...)
# 假设使用默认 start_factor=1.0, end_factor=0.0
'end_factor': 0.0,
'_last_lr': [0.0] # 最后一个计算出的学习率 (0.0 * args.max_lr)
},
'train_losses': [
# 这是一个很长的列表,包含了两个 epoch 中所有训练 batch 的损失值。
# (2 epochs * 100 batches/epoch = 200 个损失值)
# 例如:
0.85, 0.83, ..., 0.60, # 第一个 epoch 的 batch 损失
0.58, 0.57, ..., 0.35 # 第二个 epoch 的 batch 损失
],
'val_losses': [
# 这也是一个列表,包含了两个 epoch 中所有验证 batch 的损失值。
# (2 epochs * 20 batches/epoch = 40 个损失值)
# 例如:
0.75, 0.73, ..., 0.55, # 第一个 epoch 的验证 batch 损失
0.53, 0.52, ..., 0.40 # 第二个 epoch 的验证 batch 损失
],
'lrs': [
# 这是一个很长的列表,包含了两个 epoch 中所有训练 batch 计算前的学习率。
# (200 个学习率值)
# 对于 LinearLR 从 3e-4 衰减到 0,经过 200步:
3.000e-4, 2.985e-4, 2.970e-4, ..., # 第一个 epoch 的学习率
..., 0.0015e-4, 0.0000e-4 # 第二个 epoch 结束时的学习率,最后为0
]
}
# --- 示例结束 ---
这两个字典——'optimizer_state_dict'
和 'lr_scheduler'
(实际上是 lr_scheduler.state_dict()
)——对于实现可中断和可恢复的训练过程至关重要。它们能确保当你从一个保存的检查点(checkpoint)恢复训练时,训练能够从它离开时的确切状态无缝地继续下去,而不仅仅是加载模型的权重。
1.对于 Adam/AdamW (你的代码中使用的是AdamW),这些关键状态包括:
step
(针对每个参数):记录了该参数被更新了多少步。这对于Adam中的偏差修正(bias correction)很重要,确保在训练早期估计的准确性。exp_avg
(一阶矩估计,针对每个参数):梯度的指数移动平均值,可以看作是“动量”项,它帮助加速在一致梯度方向上的学习并抑制震荡。exp_avg_sq
(二阶矩估计,针对每个参数):梯度平方的指数移动平均值,用于自适应地调整每个参数的学习率。梯度变化剧烈的参数学习率会小一些,变化平稳的参数学习率会大一些。'param_groups'
部分:
lr
(学习率):记录了保存时该参数组的当前学习率。这很重要,因为学习率会通过调度器动态变化。betas
, eps
, weight_decay
等: 这些是优化器(如AdamW)的核心超参数。保存它们是为了确保当你加载检查点时,优化器的配置与保存时完全一致。params
: 包含属于该参数组的模型参数的ID列表当你从检查点加载模型以继续训练时,你需要先初始化你的优化器,然后加载这个状态字典:
# 假设 BrainMLP 是你的模型实例,args.max_lr 是初始学习率
optimizer = torch.optim.AdamW(BrainMLP.parameters(), lr=args.max_lr)
# 加载检查点
checkpoint = torch.load(ckpt_path)
BrainMLP.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print("优化器状态已加载!")
2.学习率调度器用于在训练过程中动态地调整学习率(例如,逐渐降低学习率,或者在特定轮次降低)。调度器本身也需要维护状态。
LinearLR
(你的代码中默认使用的):
last_epoch
: 这通常是一个内部计数器,记录了调度器的 step()
方法被调用了多少次。在你的代码中,lr_scheduler.step()
是在每个训练批次后调用的,所以 last_epoch
实际上是已处理的批次数 (steps)。total_iters
: 定义了学习率从开始到结束变化所需的总步数。start_factor
, end_factor
: 定义了初始学习率和最终学习率相对于优化器中基础学习率的比例。_last_lr
: 包含了上一次 step()
调用后计算出的学习率列表(对应每个参数组)。# ... BrainMLP 和 optimizer 已经初始化并加载了状态 ...
# 初始化学习率调度器
if args.lr_scheduler_type == 'linear':
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer, # 关键:使用已经加载了状态的优化器
total_iters=int(args.epochs * (len(train_ds) // args.batch_size)),
# start_factor 和 end_factor 等应与初次创建时一致
)
# elif ... 其他类型的调度器
if 'lr_scheduler' in checkpoint: # 或者你的键名是 'lr_scheduler_state_dict'
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
print("学习率调度器状态已加载!")
注意顺序:通常建议先加载优化器的状态,然后再创建和加载学习率调度器的状态,因为调度器需要关联一个优化器。
保存和加载 optimizer_state_dict
和 lr_scheduler
(的 state_dict
) 的核心价值在于实现真正的“断点续训”。这使得你的训练过程可以被中断(无论是手动停止、意外崩溃还是为了在不同机器上继续),之后能从完全相同的点恢复,保证了训练过程的一致性和可复现性,避免了因状态重置可能引入的训练行为变化。这对于长时间的训练任务尤其重要。