转载自: https://zhuanlan.zhihu.com/p/7224017315 ,仅作学习记录。
omegaConf 是一个 Python 库,用于管理和操作配置数据。它提供了灵活的方式来处理配置文件、命令行参数和环境变量,并支持嵌套结构、类型检查和动态更新。omegaConf 最初是为 Facebook 的 Hydra 项目开发的,但现在可以独立使用。
在训练模型的过程中,一个很令人头大的前置工作是如何对各项配置参数进行管理。模型结构、训练、评测、优化、数据预处理、数据加载、多机环境等,各项参数加起来经常有大好几十项。此时,某一个参数设置错误往往会导致整个实验白跑,于是实验之前不得不仔细检查每一个参数是否设置正确,真的非常麻烦。
为了解决这一问题,本文介绍omegaconf,一个简洁易用的配置管理系统,可以极大地简化和结构化配置信息,让参数配置简单起来!
omegaconf库官方使用指南: Link
在传统的参数配置方法中,参数往往通过**–key value键值对的方式**传入训练脚本,再通过argparse.ArgumentParser
或者transformers.HfArgumentParser
等参数解析器进行解析。
如llava仓库中的某个训练脚本finetune,sh:
#!/bin/bash
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path ./checkpoints/$MODEL_VERSION \
--version $PROMPT_VERSION \
--data_path ./playground/data/llava_instruct_80k.json \
--image_folder /path/to/coco/train2017 \
--vision_tower openai/clip-vit-large-patch14 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--bf16 True \
--output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
当参数数目较少时,这种方式简便易用。然而,当参数较多时,这种方式会导致整个训练脚本非常长(实际上上面的脚本已经有三十多行!),而且参数之间缺乏结构化管理,一个一个配置非常容易出错。同时,这种方式将训练设置和shell脚本进行强制绑定,也就是说一组实验就需要保留一个shell脚本,比较冗余。
除了上述方式以外,另外一种较为流行的方式是直接通过yaml等格式的配置文件进行管理。相比于–key value键值对,这种方式的优势在于可以把不同参数按照用途进行分类,便于管理,而且所有参数都在一个配置文件中,非常简洁。
但是,这种方式也具有一定局限性。简单举两点:
相比于–key value和yaml配置文件两种方法,omegaconf的主要优势如下:
举个例子,我们现在有如下需求:
我们通过yaml文件进行参数配置,但是有些字段需要训练时才能决定,比如输出文件夹、batch size等;
我们希望根据某些值的不同,使得同一个yaml文件具备不同的行为,比如debug模式和非debug模式的切换、本地物理机模式和容器模式的切换。
这些需求非常常见,但是传统的参数配置方式却很难完成,而omegaconf却可以轻松做到。接下来我们来看omegaconf的具体用法。
conf = omegaconf.OmegaConf.load('foo.yaml')
。这种方式和yaml.load('foo.yaml')
类似;conf = omegaconf.OmegaConf.from_cli()
。比如运行命令为python ./a.py k1=v1 k2=v2,那么conf的内容为 {‘k1’: ‘v1’, ‘k2’: ‘v2’} ,和字典类似;注意,omegaconf接收的传参格式是
key=value
,而不是–key value
这里重点介绍yaml和CLI变量搭配使用,这一组合是个人觉得最强大实用也最简单方便的。
有时候yaml文件中的某些字段不方便提前确定(比如输出文件夹)。针对这一痛点,omegaconf支持只在yaml中填充确定的参数,而不确定的参数字段的值用???
(三个英文问号)填充。这些不确定的参数字段可以后续通过命令行传入。
举个例子,yaml文件config.yaml中的内容如下:
output_dir: ???
model_args:
num_layers: 2
training_args:
batch_size: ???
arr: [1, 2, 3]
运行命令如下:
python ./train.py config_path=config.yaml output_dir=./.checkpoints training_args.batch_size=4
在train.py中,我们通过如下的代码解析参数:
from omegaconf import OmegaConf
conf_cli = OmegaConf.from_cli()
conf_yaml = OmegaConf.load(conf_cli.config_path)
conf = OmegaConf.merge(conf_cli, conf_yaml)
首先,我们通过conf_cli = OmegaConf.from_cli()
获取命令行传入的参数,
然后,我们再通过conf_yaml = OmegaConf.load(conf_cli.config_path)
从config.yaml中读取配置信息,可以看到,output_dir和training_args.batch_size字段为???。此时如果我们访问这些字段,print(conf_yaml.output_dir)
将会报错。
为了填充这些字段,我们通过conf = omegaconf.OmegaConf.merge(conf_yaml, conf_cli)
将yaml中和命令行中的配置信息进行合并,这样最终得到的配置信息如下:
{
'output_dir': './.checkpoints',
'model_args': {'num_layers': 2},
'config_path': 'config.yaml',
'training_args': {'batch_size': 4, 'arr': [1, 2, 3]}
}
此时可以发现,完整的配置信息中不包含任意缺失的字段,再次获取output_dir等字段的值不会报错。
有两点需要注意:
最终合并之后的配置信息中该字段的值将和命令行传入的值保持一致
,也就是说命令行参数的优先级高于yaml文件
。这一点也为临时修改某些配置提供了便利。比如,我们只希望临时修改某个参数的值看看效果,此时我们只需要在命令行中传入该参数的值即可,不需要额外创建一份yaml文件;omegaconf.dictconfig.DictConfig
;又比如conf.training_args.arr像是列表,但其实是omegaconf定义的一个类omegaconf.listconfig.ListConfig
。不过,omegaconf也提供了将上述conf转换为python内置数据类型的函数omegaconf.OmegaConf.to_object
,用于将DictConfig转为dict、将ListConfig转为list等。conf.training_args.arr[0]
-> 1omegaconf.OmegaConf.select(conf,"training_args.arr[0]")
-> 1conf.training_args.arr = [1, 2, 3]
omegaconf.OmegaConf.update(conf, "training_args.arr", [1, 2, 3])
omegaconf.OmegaConf.save(conf, "config.yaml")
omegaconf.OmegaConf.load("config.yaml")
除了上面的基础用法以外,omegaconf还提供了一些功能非常强大的高阶用法,本节主要介绍变量插值和变量解析器 (resolver) 两种。
变量插值指的是,在yaml文件中,我们可以将每个字段都视为一个类似python变量(list、dict、str等)的东西,然后在某一个字段中通过其余字段的值来对当前字段进行赋值。
变量插值的基础语法为${...}
。也就是说,假设某一个字段key的值为${training_args.value},那么key实际的值是training_args.value字段对应的值。
举个例子,假设config.yaml中的内容如下:
training_args:
batch_size: 10000
arr: [1, 2, 3]
dict:
k: v
key: 'k'
value: "${training_args.dict[${key}]}_${training_args.arr[0]}"
然后加载该文件:
config = omegaconf.OmegaConf.load("./config.yaml")
config变量中的内容为:
{
'training_args': {
'batch_size': 10000, 'arr': [1, 2, 3], 'dict': {'k': 'v'}
},
'key': 'k',
'value': '${training_args.dict[${key}]}_${training_args.arr[0]}'
}
此时,config中的内容和config.yaml文件中的内容完全一致,没有什么特别的。但是,当我们尝试获取config.value的值时,神奇的事情发生了:
print(config.value)
# v_1
输出的值为v_1,而不是’KaTeX parse error: Expected '}', got 'EOF' at end of input: …ning_args.dict[{key}]}_${training_args.arr[0]}'。
这是因为,由于这个字段的值包含变量插值,omegaconf对其进行了解析。即依次执行如下替换:
变量插值的解析只有当获取该值时才会发生。
除了变量插值以外,omegaconf另外一个很有用的特性是变量解析器 (resolver)。变量解析器的用途是自定义变量的解析方式,以实现更高级、灵活的变量解析。
OmegaConf支持用户自定义变量解析器,从而实现根据自定义规则来进行变量解析。同时,OmegaConf还提供了一些内置的解析器。
简单起见,本节只介绍一个内置变量解析器oc.env
,该解析器的作用是获取系统的环境变量。比如,我们希望将模型的输出文件夹指定为当前目录下的.checkpoints文件夹,那么就可以通过如下的yaml实现:
output_dir: ${oc:env:PWD}/.checkpoints
当配置信息和系统的环境变量存在一定绑定关系时,上述方式非常有用。
omegaConf 支持从环境变量中读取值。
import os
from omegaconf import OmegaConf
# 设置环境变量
os.environ["DB_HOST"] = "localhost"
# 使用环境变量插值
config = OmegaConf.create({
"database": {
"host": "${env:DB_HOST}",
"port": 5432
}
})
print(config.database.host) # 输出: localhost
omegaConf 支持类型检查,确保配置值的类型正确。
from omegaconf import OmegaConf, DictConfig
# 创建带类型检查的配置
config = OmegaConf.create({
"database": {
"host": "localhost",
"port": 5432
}
}, flags={"struct": True})
# 尝试修改为错误类型会抛出异常
try:
config.database.port = "invalid" # 应该为整数
except Exception as e:
print(f"Error: {e}")
omegaConf 支持动态更新配置。
from omegaconf import OmegaConf
config = OmegaConf.create({
"database": {
"host": "localhost",
"port": 5432
}
})
# 动态更新
OmegaConf.update(config, "database.host", "127.0.0.1")
print(config.database.host) # 输出: 127.0.0.1