【python第三方库】Hydra库在AI项目中使用简介

文章目录

  • 一、 前言
    • 1. omegaconf与Hydra库的关系
    • 2. Hydra优势
  • 二、 实际用法展示
    • 1. 项目结构
    • 2. 配置文件
    • 3. Python 代码
    • 4. 运行示例
      • 4.1 默认配置运行
      • 4.2 从命令行覆盖配置
      • 4.3 多运行模式
    • 5. 超参数优化
      • 5.1 安装 Optuna 插件
      • 5.2 修改 config.yaml
      • 5.3 运行超参数优化

一、 前言

Hydra 是一个开源 Python 框架,可简化研究和其他复杂应用程序的开发。关键特性是能够通过 组合动态创建分层配置,并通过配置文件和命令行覆盖它
在机器学习项目中,我们有多个版本的模型参数和架构,并且一些训练参数如文件链接,学习率,版本等信息更希望通过命令行来进行指定,以前的方法是需要自己实现argparse模块的读取命令行参数,自行进行解析和使用。使用Hydra就可以很好的解决上述麻烦。下面将结合官方文档对这个模块的使用进行介绍。

注意:我们导入包的时候使用 import hydra ,但pip下载包时需要使用 pip install hydra-core

  • 官方网址:Getting started | Hydra

1. omegaconf与Hydra库的关系

  • omegaConf 是 Hydra 的核心组件。
    • Hydra 是一个更高级的配置管理框架,而 omegaConf 是 Hydra 用来处理配置数据的底层库。
    • Hydra 依赖于 omegaConf 来解析、合并和操作配置。
  • omegaConf 可以独立使用:
    • 如果你只需要简单的配置管理功能,可以直接使用 omegaConf,而不需要引入 Hydra。
    • 如果你需要更强大的功能(如命令行集成、多运行模式等),可以选择 Hydra。
      【python第三方库】Hydra库在AI项目中使用简介_第1张图片
      关于omegaconf库的用法,见link

2. Hydra优势

  • 集中管理配置:通过 YAML 文件管理超参数、数据集、实验和环境配置。
  • 动态覆盖配置:从命令行动态覆盖配置。
  • 多运行模式:支持一次性运行多个实验。
  • 超参数优化:与 Optuna 集成,自动搜索最佳超参数。
  • 日志与输出管理:自动生成独立的输出目录。

二、 实际用法展示

我们将模拟一个典型的 AI 项目场景,包括超参数管理、数据集配置、实验管理、超参数优化、环境管理和日志管理。

1. 项目结构

假设项目结构如下:

my_ai_project/
  conf/
    config.yaml
    dataset/
      cifar10.yaml
      mnist.yaml
    experiment/
      exp1.yaml
      exp2.yaml
    env/
      local.yaml
      prod.yaml
  train.py

2. 配置文件

  • config.yaml
defaults:
  - dataset: cifar10
  - experiment: exp1
  - env: local
  • dataset/cifar10.yaml
name: CIFAR-10
path: /data/cifar10
num_classes: 10
  • dataset/mnist.yaml
name: MNIST
path: /data/mnist
num_classes: 10
  • experiment/exp1.yaml
name: Experiment 1
learning_rate: 0.001
batch_size: 64
  • experiment/exp2.yaml
name: Experiment 2
learning_rate: 0.01
batch_size: 128
  • env/local.yaml
data_path: /data/local
num_workers: 4
  • env/prod.yaml
data_path: /data/prod
num_workers: 16

3. Python 代码

  • train.py
import hydra
from omegaconf import DictConfig, OmegaConf
import logging

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@hydra.main(version_base=None, config_path="conf", config_name="config")
def train(cfg: DictConfig):
    # 打印完整配置
    logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

    # 访问模型超参数
    learning_rate = cfg.experiment.learning_rate
    batch_size = cfg.experiment.batch_size
    logger.info(f"Learning Rate: {learning_rate}, Batch Size: {batch_size}")

    # 访问数据集配置
    dataset_name = cfg.dataset.name
    dataset_path = cfg.dataset.path
    logger.info(f"Dataset: {dataset_name}, Path: {dataset_path}")

    # 访问环境配置
    data_path = cfg.env.data_path
    num_workers = cfg.env.num_workers
    logger.info(f"Data Path: {data_path}, Num Workers: {num_workers}")

    # 模拟训练过程
    logger.info("Starting training...")
    # 这里可以添加实际的训练代码
    logger.info("Training completed.")

if __name__ == "__main__":
    train()

4. 运行示例

4.1 默认配置运行

python train.py

输出如下:

INFO:__main__:Configuration:
model:
  name: resnet50
  learning_rate: 0.001
  batch_size: 64
  num_layers: 18
dataset:
  name: CIFAR-10
  path: /data/cifar10
  num_classes: 10
experiment:
  name: Experiment 1
  learning_rate: 0.001
  batch_size: 64
env:
  data_path: /data/local
  num_workers: 4

INFO:__main__:Learning Rate: 0.001, Batch Size: 64
INFO:__main__:Dataset: CIFAR-10, Path: /data/cifar10
INFO:__main__:Data Path: /data/local, Num Workers: 4
INFO:__main__:Starting training...
INFO:__main__:Training completed.

4.2 从命令行覆盖配置

python train.py experiment=exp2 dataset=mnist env=prod

输出如下:

INFO:__main__:Configuration:
model:
  name: resnet50
  learning_rate: 0.01
  batch_size: 128
  num_layers: 18
dataset:
  name: MNIST
  path: /data/mnist
  num_classes: 10
experiment:
  name: Experiment 2
  learning_rate: 0.01
  batch_size: 128
env:
  data_path: /data/prod
  num_workers: 16

INFO:__main__:Learning Rate: 0.01, Batch Size: 128
INFO:__main__:Dataset: MNIST, Path: /data/mnist
INFO:__main__:Data Path: /data/prod, Num Workers: 16
INFO:__main__:Starting training...
INFO:__main__:Training completed.

4.3 多运行模式

python train.py --multirun experiment=exp1,exp2

输出如下:

[2023-10-01 12:34:56,123][__main__][INFO] - Configuration:
model:
  name: resnet50
  learning_rate: 0.001
  batch_size: 64
  num_layers: 18
dataset:
  name: CIFAR-10
  path: /data/cifar10
  num_classes: 10
experiment:
  name: Experiment 1
  learning_rate: 0.001
  batch_size: 64
env:
  data_path: /data/local
  num_workers: 4

[2023-10-01 12:34:56,124][__main__][INFO] - Learning Rate: 0.001, Batch Size: 64
[2023-10-01 12:34:56,124][__main__][INFO] - Dataset: CIFAR-10, Path: /data/cifar10
[2023-10-01 12:34:56,124][__main__][INFO] - Data Path: /data/local, Num Workers: 4
[2023-10-01 12:34:56,124][__main__][INFO] - Starting training...
[2023-10-01 12:34:56,124][__main__][INFO] - Training completed.

[2023-10-01 12:34:56,125][__main__][INFO] - Configuration:
model:
  name: resnet50
  learning_rate: 0.01
  batch_size: 128
  num_layers: 18
dataset:
  name: CIFAR-10
  path: /data/cifar10
  num_classes: 10
experiment:
  name: Experiment 2
  learning_rate: 0.01
  batch_size: 128
env:
  data_path: /data/local
  num_workers: 4

[2023-10-01 12:34:56,125][__main__][INFO] - Learning Rate: 0.01, Batch Size: 128
[2023-10-01 12:34:56,125][__main__][INFO] - Dataset: CIFAR-10, Path: /data/cifar10
[2023-10-01 12:34:56,125][__main__][INFO] - Data Path: /data/local, Num Workers: 4
[2023-10-01 12:34:56,125][__main__][INFO] - Starting training...
[2023-10-01 12:34:56,125][__main__][INFO] - Training completed.

5. 超参数优化

5.1 安装 Optuna 插件

pip install hydra-optuna-sweeper

5.2 修改 config.yaml

defaults:
  - hydra/sweeper: optuna

hydra:
  sweeper:
    sampler:
      type: tpe
    direction: minimize
    n_trials: 10
    params:
      experiment.learning_rate: interval(0.0001, 0.1)
      experiment.batch_size: choice(32, 64, 128)

5.3 运行超参数优化

python train.py --multirun hydra/sweeper=optuna

你可能感兴趣的:(【python第三方库】Hydra库在AI项目中使用简介)