【python第三方库】omegaconf库入门教程

转载自: https://zhuanlan.zhihu.com/p/7224017315 ,仅作学习记录。

文章目录

  • 前言
    • 传统参数配置痛点
    • omegaconf的优势
  • 基础用法
    • 创建配置
    • 获取和更新值
      • 获取值
      • 更新值
    • 写入和读取本地文件
  • 高阶用法
    • 变量插值
    • 变量解析器
    • 类型检查
    • 动态更新

前言

omegaConf 是一个 Python 库,用于管理和操作配置数据。它提供了灵活的方式来处理配置文件、命令行参数和环境变量,并支持嵌套结构、类型检查和动态更新。omegaConf 最初是为 Facebook 的 Hydra 项目开发的,但现在可以独立使用。

在训练模型的过程中,一个很令人头大的前置工作是如何对各项配置参数进行管理。模型结构、训练、评测、优化、数据预处理、数据加载、多机环境等,各项参数加起来经常有大好几十项。此时,某一个参数设置错误往往会导致整个实验白跑,于是实验之前不得不仔细检查每一个参数是否设置正确,真的非常麻烦。

为了解决这一问题,本文介绍omegaconf,一个简洁易用的配置管理系统,可以极大地简化和结构化配置信息,让参数配置简单起来!

omegaconf库官方使用指南: Link

传统参数配置痛点

在传统的参数配置方法中,参数往往通过**–key value键值对的方式**传入训练脚本,再通过argparse.ArgumentParser或者transformers.HfArgumentParser等参数解析器进行解析。
如llava仓库中的某个训练脚本finetune,sh:

#!/bin/bash

deepspeed llava/train/train_mem.py \
    --deepspeed ./scripts/zero2.json \
    --model_name_or_path ./checkpoints/$MODEL_VERSION \
    --version $PROMPT_VERSION \
    --data_path ./playground/data/llava_instruct_80k.json \
    --image_folder /path/to/coco/train2017 \
    --vision_tower openai/clip-vit-large-patch14 \
    --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --bf16 True \
    --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb

当参数数目较少时,这种方式简便易用。然而,当参数较多时,这种方式会导致整个训练脚本非常长(实际上上面的脚本已经有三十多行!),而且参数之间缺乏结构化管理,一个一个配置非常容易出错。同时,这种方式将训练设置和shell脚本进行强制绑定,也就是说一组实验就需要保留一个shell脚本,比较冗余

除了上述方式以外,另外一种较为流行的方式是直接通过yaml等格式的配置文件进行管理。相比于–key value键值对,这种方式的优势在于可以把不同参数按照用途进行分类,便于管理,而且所有参数都在一个配置文件中,非常简洁。

但是,这种方式也具有一定局限性。简单举两点:

  • 不够灵活:比如,某两组实验除了模型保存路径以外,其余设置完全相同。此时,由于输出文件夹在配置文件中写死了,所以我们仍然需要两个配置文件;
  • 缺乏值的共享:比如,我们希望不同参数共享同一个值,这时我们仍然需要为这些参数一一设置,没法实现值的共享,设置的时候容易出错。

omegaconf的优势

相比于–key value和yaml配置文件两种方法,omegaconf的主要优势如下:

  • 是yaml的超集:omegaconf支持通过yaml文件创建配置,因此包含yaml的所有优点;
  • 提供yaml高阶语法:omegaconf对yaml文件的解析非常强大,提供变量插值、变量解析器等高阶语法(本文后面会讲);
  • 多源配置合并:支持从 YAML 文件、数据类/对象和命令行参数等多种来源合并配置。
  • 运行时类型安全:通过结构化配置提供运行时类型验证和转换。
  • 一致的 API:无论配置来源如何,API 保持一致。

基础用法

举个例子,我们现在有如下需求:

我们通过yaml文件进行参数配置,但是有些字段需要训练时才能决定,比如输出文件夹、batch size等;
我们希望根据某些值的不同,使得同一个yaml文件具备不同的行为,比如debug模式和非debug模式的切换、本地物理机模式和容器模式的切换。
这些需求非常常见,但是传统的参数配置方式却很难完成,而omegaconf却可以轻松做到。接下来我们来看omegaconf的具体用法。

创建配置

  • 从yaml创建conf = omegaconf.OmegaConf.load('foo.yaml')。这种方式和yaml.load('foo.yaml')类似;
  • 从CLI变量创建conf = omegaconf.OmegaConf.from_cli()。比如运行命令为python ./a.py k1=v1 k2=v2,那么conf的内容为 {‘k1’: ‘v1’, ‘k2’: ‘v2’} ,和字典类似;

注意,omegaconf接收的传参格式是key=value,而不是–key value

  • 其他:从dict、list、str等变量创建;从dataclass创建,本文不一一介绍,详情可参官方文档。

这里重点介绍yaml和CLI变量搭配使用,这一组合是个人觉得最强大实用也最简单方便的。

有时候yaml文件中的某些字段不方便提前确定(比如输出文件夹)。针对这一痛点,omegaconf支持只在yaml中填充确定的参数,而不确定的参数字段的值用???(三个英文问号)填充这些不确定的参数字段可以后续通过命令行传入

举个例子,yaml文件config.yaml中的内容如下:

output_dir: ???
model_args:
    num_layers: 2
training_args:
    batch_size: ???
    arr: [1, 2, 3]

运行命令如下:

python ./train.py config_path=config.yaml output_dir=./.checkpoints training_args.batch_size=4

在train.py中,我们通过如下的代码解析参数:

from omegaconf import OmegaConf

conf_cli = OmegaConf.from_cli()
conf_yaml = OmegaConf.load(conf_cli.config_path)
conf = OmegaConf.merge(conf_cli, conf_yaml)

首先,我们通过conf_cli = OmegaConf.from_cli()获取命令行传入的参数,
然后,我们再通过conf_yaml = OmegaConf.load(conf_cli.config_path)从config.yaml中读取配置信息,可以看到,output_dir和training_args.batch_size字段为???。此时如果我们访问这些字段,print(conf_yaml.output_dir)将会报错。
为了填充这些字段,我们通过conf = omegaconf.OmegaConf.merge(conf_yaml, conf_cli)将yaml中和命令行中的配置信息进行合并,这样最终得到的配置信息如下:

{
    'output_dir': './.checkpoints', 
    'model_args': {'num_layers': 2}, 
    'config_path': 'config.yaml', 
    'training_args': {'batch_size': 4, 'arr': [1, 2, 3]}
}

此时可以发现,完整的配置信息中不包含任意缺失的字段,再次获取output_dir等字段的值不会报错。

有两点需要注意:

  • 参数优先级:当命令行参数和yaml文件都指定了同一个参数的值时,那么最终合并之后的配置信息中该字段的值将和命令行传入的值保持一致,也就是说命令行参数的优先级高于yaml文件。这一点也为临时修改某些配置提供了便利。比如,我们只希望临时修改某个参数的值看看效果,此时我们只需要在命令行中传入该参数的值即可,不需要额外创建一份yaml文件;
  • 变量类型:从打印信息看,omegaconf的配置信息和python内置的数据类型非常相似,比如上面的conf变量的打印结果看起来和字典一模一样,但其实conf并不是字典,而是omegaconf定义的一个类omegaconf.dictconfig.DictConfig;又比如conf.training_args.arr像是列表,但其实是omegaconf定义的一个类omegaconf.listconfig.ListConfig。不过,omegaconf也提供了将上述conf转换为python内置数据类型的函数omegaconf.OmegaConf.to_object用于将DictConfig转为dict、将ListConfig转为list等

获取和更新值

获取值

  • 直接点引用:比如conf.training_args.arr[0] -> 1
  • 函数omegaconf.OmegaConf.select:比如omegaconf.OmegaConf.select(conf,"training_args.arr[0]") -> 1

更新值

  • 直接赋值:e.g. conf.training_args.arr = [1, 2, 3]
  • omegaconf.OmegaConf.update:e.g. omegaconf.OmegaConf.update(conf, "training_args.arr", [1, 2, 3])

写入和读取本地文件

  • 写入:e.g. omegaconf.OmegaConf.save(conf, "config.yaml")
  • 读取:e.g. omegaconf.OmegaConf.load("config.yaml")

高阶用法

除了上面的基础用法以外,omegaconf还提供了一些功能非常强大的高阶用法,本节主要介绍变量插值变量解析器 (resolver) 两种。

变量插值

变量插值指的是,在yaml文件中,我们可以将每个字段都视为一个类似python变量(list、dict、str等)的东西,然后在某一个字段中通过其余字段的值来对当前字段进行赋值

变量插值的基础语法为${...}。也就是说,假设某一个字段key的值为${training_args.value},那么key实际的值是training_args.value字段对应的值

举个例子,假设config.yaml中的内容如下:

training_args:
    batch_size: 10000
    arr: [1, 2, 3]
    dict:
      k: v
key: 'k'
value: "${training_args.dict[${key}]}_${training_args.arr[0]}"

然后加载该文件:

config = omegaconf.OmegaConf.load("./config.yaml")

config变量中的内容为:

{
    'training_args': {
        'batch_size': 10000, 'arr': [1, 2, 3], 'dict': {'k': 'v'}
    }, 
    'key': 'k', 
    'value': '${training_args.dict[${key}]}_${training_args.arr[0]}'
}

此时,config中的内容和config.yaml文件中的内容完全一致,没有什么特别的。但是,当我们尝试获取config.value的值时,神奇的事情发生了:

print(config.value)
# v_1

输出的值为v_1,而不是’KaTeX parse error: Expected '}', got 'EOF' at end of input: …ning_args.dict[{key}]}_${training_args.arr[0]}'。

这是因为,由于这个字段的值包含变量插值,omegaconf对其进行了解析。即依次执行如下替换:

  • 将${key}替换为k;
  • 将${training_args.dict[k]}替换为v;
  • 将training_args.arr[0]替换为1;
  • 将’KaTeX parse error: Expected '}', got 'EOF' at end of input: …ning_args.dict[{key}]}_${training_args.arr[0]}'整个替换为v_1。
    这一特性非常有用,假如多个字段的值相互之间存在一定关联,那么就可以使用变量插值极大地简化配置参数。

变量插值的解析只有当获取该值时才会发生。

变量解析器

除了变量插值以外,omegaconf另外一个很有用的特性是变量解析器 (resolver)。变量解析器的用途是自定义变量的解析方式,以实现更高级、灵活的变量解析

OmegaConf支持用户自定义变量解析器,从而实现根据自定义规则来进行变量解析。同时,OmegaConf还提供了一些内置的解析器。

简单起见,本节只介绍一个内置变量解析器oc.env,该解析器的作用是获取系统的环境变量。比如,我们希望将模型的输出文件夹指定为当前目录下的.checkpoints文件夹,那么就可以通过如下的yaml实现:

output_dir: ${oc:env:PWD}/.checkpoints

当配置信息和系统的环境变量存在一定绑定关系时,上述方式非常有用。

omegaConf 支持从环境变量中读取值。

import os
from omegaconf import OmegaConf

# 设置环境变量
os.environ["DB_HOST"] = "localhost"

# 使用环境变量插值
config = OmegaConf.create({
    "database": {
        "host": "${env:DB_HOST}",
        "port": 5432
    }
})

print(config.database.host)  # 输出: localhost

类型检查

omegaConf 支持类型检查,确保配置值的类型正确。

from omegaconf import OmegaConf, DictConfig

# 创建带类型检查的配置
config = OmegaConf.create({
    "database": {
        "host": "localhost",
        "port": 5432
    }
}, flags={"struct": True})

# 尝试修改为错误类型会抛出异常
try:
    config.database.port = "invalid"  # 应该为整数
except Exception as e:
    print(f"Error: {e}")

动态更新

omegaConf 支持动态更新配置。

from omegaconf import OmegaConf

config = OmegaConf.create({
    "database": {
        "host": "localhost",
        "port": 5432
    }
})

# 动态更新
OmegaConf.update(config, "database.host", "127.0.0.1")
print(config.database.host)  # 输出: 127.0.0.1

你可能感兴趣的:(【python】,【AI模型训练与部署】,python,开发语言)