This post summarizes the implementation details of the typical TRL trainers (SFT, PPO, DPO, GRPO); it also complements the previous post 【速写】PPOTrainer样例与错误思考(少量DAPO).
DeepSeek still handles questions about the trainers' implementation details poorly, especially anything involving PPOTrainer; as pointed out in detail in the previous post, most of the code it writes simply does not run, and even the official PPO example (ppo.py) seems to have a few flaws, which this post will go through one by one.
I have written a single base_pipeline that works for all four typical trainers (SFT, PPO, DPO, GRPO), together with a matching unit-test module. PPOTrainer is by far the most tedious of the four: it needs extra models (a reward model, a value model and a reference model), and its train_dataset has to be prepared quite differently from the other trainers. Using this base_pipeline as the reference code, I will walk through the finer points of the TRL trainers; treat it as a starting point for discussion, and comments are welcome.
The complete project is at GitHub@caoyang-sufe/easyllm. Let us first look at base_pipeline in the code below; the other four pipelines (sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipeline) all just call it directly.
The corresponding script is trainer.py.
This pipeline matters: everything that follows is a line-by-line discussion built around it.
# -*- coding: utf8 -*-
# @author: caoyang
# @email: [email protected]
import wandb
import logging
from copy import deepcopy
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, HfArgumentParser
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import (
ScriptArguments, ModelConfig,
SFTConfig, SFTTrainer,
PPOConfig, PPOTrainer,
DPOConfig, DPOTrainer,
GRPOConfig, GRPOTrainer,
get_peft_config, get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from src.tools.trl import update_trl_config, generate_simple_data_processor
# Trainer Pipeline
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
# @param data_processor: Function object prepared for `dataset.map(data_processor)`
# @param trainer_config: [Dict, trl.XXXConfig] including keyword arguments, e.g.
# @param model_config: [Dict, trl.ModelConfig] including keyword arguments, e.g.
# @param script_arguments: [Dict, trl.ScriptArguments] including keyword arguments, e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# @param config_kwargs: [Dict] keyword arguments for updating TRL-Config, `ScriptArguments`, `ModelConfig`
# - keyword arguments for `TRLConfig`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
# - keyword arguments for `ScriptArguments`: e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# - keyword arguments for `ModelConfig`: e.g. "model_name_or_path", "torch_dtype", "trust_remote_code", "use_peft", "lora_xxx", "load_in_4bit", "bnb_4bit_compute_dtype", "bnb_4bit_quant_type"
# @param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer
# - keyword arguments for all Trainers: e.g. "data_collator", "callbacks"
# - keyword arguments for `SFTTrainer`: e.g. "compute_loss_func", "compute_metrics"
# - keyword arguments for `PPOTrainer`: e.g. "ref_model[required]", "reward_model[required]", "value_model[required]"
# - keyword arguments for `DPOTrainer`: e.g. "ref_model"
# - keyword arguments for `GRPOTrainer`: e.g. "reward_funcs[required]"
def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):
# 1 Configuration
TRLConfig, TRLTrainer = eval(f"{name}Config"), eval(f"{name}Trainer")
parser = HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))
script_arguments, trainer_config, model_config = parser.parse_args_into_dataclasses()
script_arguments = update_trl_config(script_arguments, **config_kwargs)
trainer_config = update_trl_config(trainer_config, **config_kwargs)
model_config = update_trl_config(model_config, **config_kwargs)
peft_config = get_peft_config(model_config)
quantization_config = get_quantization_config(model_config)
# 2 Load models and tokenizer
logging.info("Load models and tokenizer ...")
logging.info(f" - Model: {model_config.model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if not "pad_token" in tokenizer.special_tokens_map:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path,
device_map = "auto",
trust_remote_code = model_config.trust_remote_code,
quantization_config = quantization_config,
)
if peft_config is not None:
logging.info("Prepare model for PEFT ...")
model.config.pretraining_tp = 1
model.config.use_cache = False
model.gradient_checkpointing_enable()
# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)
# Then you need set `model.enable_input_require_grads()` yourself
# model = prepare_model_for_kbit_training(model)
model.enable_input_require_grads()
model = get_peft_model(model, peft_config)
if name == "PPO":
logging.info("PPO load reward value and reference models ...")
# PPO is special! It needs more components!
logging.info(f" - Reward model: {trainer_config.reward_model_path}")
reward_model = AutoModelForSequenceClassification.from_pretrained(
trainer_config.reward_model_path,
trust_remote_code = model_config.trust_remote_code,
num_labels = 1,
)
value_model = AutoModelForSequenceClassification.from_pretrained(
trainer_config.reward_model_path,
trust_remote_code = model_config.trust_remote_code,
num_labels = 1,
)
logging.info(" - Copy reference model ...")
ref_model = deepcopy(model)
# ref_model = model.__class__(model.config)
# ref_model.load_state_dict(model.state_dict())
trainer_kwargs["reward_model"] = reward_model
trainer_kwargs["value_model"] = value_model
trainer_kwargs["ref_model"] = ref_model
logging.info(" - Done!")
if data_processor is None:
# The data processor of PPO is also different to others
def data_processor(_data):
outputs = tokenizer(_data["prompt"] + _data["completion"], padding = False)
return {"input_ids": outputs["input_ids"]}
    # 3 Load dataset
logging.info("Load dataset ...")
logging.info(f" - Dataset: {script_arguments.dataset_name}")
if data_processor is None:
data_processor = generate_simple_data_processor(name)
train_dataset = load_dataset(script_arguments.dataset_name, split=script_arguments.dataset_train_split)
eval_dataset = load_dataset(script_arguments.dataset_name, split=script_arguments.dataset_test_split)
train_dataset = train_dataset.map(data_processor, remove_columns=train_dataset.column_names)
eval_dataset = eval_dataset.map(data_processor, remove_columns=eval_dataset.column_names)
logging.info(f" - Train dataset: {len(train_dataset)}")
logging.info(f" - Eval dataset: {len(eval_dataset)}")
# 4 Train model
logging.info("Trainer starts ...")
trainer = TRLTrainer(
model = model,
args = trainer_config,
train_dataset = train_dataset,
eval_dataset = eval_dataset,
processing_class = tokenizer,
peft_config = peft_config,
**trainer_kwargs
)
trainer.train()
logging.info(" - Trainer finishes!")
# 5 Save model
if trainer_config.push_to_hub:
        logging.info(f" - Push checkpoints to {trainer_config.push_to_hub_organization}/{trainer_config.push_to_hub_model_id}")
trainer.push_to_hub()
logging.info(f"Save model to {trainer_config.output_dir}")
trainer.save_model(trainer_config.output_dir)
# SFT Pipeline
def sft_pipeline(data_processor, config_kwargs, trainer_kwargs):
base_pipeline(
name = "SFT",
data_processor = data_processor,
config_kwargs = config_kwargs,
trainer_kwargs = trainer_kwargs,
)
# PPO Pipeline
def ppo_pipeline(data_processor, config_kwargs, trainer_kwargs):
base_pipeline(
name = "PPO",
data_processor = data_processor,
config_kwargs = config_kwargs,
trainer_kwargs = trainer_kwargs,
)
# DPO Pipeline
def dpo_pipeline(data_processor, config_kwargs, trainer_kwargs):
base_pipeline(
name = "DPO",
data_processor = data_processor,
config_kwargs = config_kwargs,
trainer_kwargs = trainer_kwargs,
)
# GRPO Pipeline
def grpo_pipeline(data_processor, config_kwargs, trainer_kwargs):
base_pipeline(
name = "GRPO",
data_processor = data_processor,
config_kwargs = config_kwargs,
trainer_kwargs = trainer_kwargs,
)
The corresponding unit-test script is trainer_pipelines.py. It involves 3 models and 4 datasets, all of which can be downloaded directly from HuggingFace; model_home and dataset_home can be adjusted to your local paths.
# -*- coding: utf8 -*-
# @author: caoyang
# @email: [email protected]
import os
import logging
from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipeline
model_home = "/nfsshare/home/caoyang/resource/model"
dataset_home = "/nfsshare/home/caoyang/resource/dataset"
model_names = [
"Qwen/Qwen2.5-0.5B-Instruct",
"EleutherAI/pythia-1b-deduped",
"EleutherAI/pythia-160m",
]
dataset_names = [
"trl-lib/tldr", # train["prompt", "completion"] + validation["prompt", "completion"] + test["prompt", "completion"]
"trl-lib/ultrafeedback_binarized", # train["chosen", "rejected", "score_chosen", "score_rejected"] + test["chosen", "rejected", "score_chosen", "score_rejected"]
"trl-internal-testing/descriptiveness-sentiment-trl-style", # sentiment["prompt", "chosen", "rejected"] + descriptiveness["prompt", "chosen", "rejected"]
"YeungNLP/firefly-train-1.1M", # train["input", "target"]
]
def sft_pipeline_test():
logging.info("SFT unittest ...")
model_name_or_path = os.path.join(model_home, model_names[0])
dataset_name = os.path.join(dataset_home, dataset_names[0])
data_processor = None
config_kwargs = {
"output_dir": f"./temp/sft+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
"model_name_or_path": model_name_or_path,
"dataset_name": dataset_name,
"trust_remote_code": True,
"dataset_train_split": "train[:500]",
"dataset_test_split": "validation[500:600]",
"use_peft": True,
"report_to": "none",
"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
}
trainer_kwargs = {
}
sft_pipeline(data_processor, config_kwargs, trainer_kwargs)
def ppo_pipeline_test():
logging.info("PPO unittest ...")
model_name_or_path = os.path.join(model_home, model_names[1])
"""
EleutherAI/pythia-1b-deduped
GPTNeoXForCausalLM(
(gpt_neox): GPTNeoXModel(
(embed_in): Embedding(50304, 2048)
(emb_dropout): Dropout(p=0.0, inplace=False)
(layers): ModuleList(
(0-15): 16 x GPTNeoXLayer(
(input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(post_attention_dropout): Dropout(p=0.0, inplace=False)
(post_mlp_dropout): Dropout(p=0.0, inplace=False)
(attention): GPTNeoXAttention(
(query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
(dense): Linear(in_features=2048, out_features=2048, bias=True)
)
(mlp): GPTNeoXMLP(
(dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
(dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
(act): GELUActivation()
)
)
)
(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
(rotary_emb): GPTNeoXRotaryEmbedding()
)
(embed_out): Linear(in_features=2048, out_features=50304, bias=False)
)
"""
dataset_name = os.path.join(dataset_home, dataset_names[0])
reward_model_path = os.path.join(model_home, model_names[2])
data_processor = None
config_kwargs = {
"output_dir": f"./temp/ppo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
"model_name_or_path": model_name_or_path,
"dataset_name": dataset_name,
"reward_model_path": reward_model_path,
"trust_remote_code": True,
"dataset_train_split": "train[:500]",
"dataset_test_split": "validation[:100]",
"use_peft": True,
"report_to": "none",
"lora_target_modules": ["query_key_value"],
}
trainer_kwargs = {
}
ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)
def dpo_pipeline_test():
logging.info("DPO unittest ...")
model_name_or_path = os.path.join(model_home, model_names[0])
dataset_name = os.path.join(dataset_home, dataset_names[2])
data_processor = None
config_kwargs = {
"output_dir": f"./temp/dpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
"model_name_or_path": model_name_or_path,
"dataset_name": dataset_name,
"trust_remote_code": True,
"dataset_train_split": "descriptiveness[:500]",
"dataset_test_split": "descriptiveness[500:600]",
"use_peft": True,
"report_to": "none",
"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
}
trainer_kwargs = {
}
dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)
def grpo_pipeline_test():
logging.info("GRPO unittest ...")
model_name_or_path = os.path.join(model_home, model_names[0])
dataset_name = os.path.join(dataset_home, dataset_names[0])
data_processor = None
def reward_funcs(completions, **kwargs):
return [float(len(set(completion))) for completion in completions]
config_kwargs = {
"output_dir": f"./temp/grpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
"model_name_or_path": model_name_or_path,
"dataset_name": dataset_name,
"trust_remote_code": True,
"dataset_train_split": "train[:500]",
"dataset_test_split": "validation[:100]",
"use_peft": True,
"report_to": "none",
"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
}
trainer_kwargs = {
"reward_funcs": reward_funcs,
}
grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)
All the unit tests currently pass; the key library versions are:
accelerate==1.6.0
datasets==3.5.0
peft==0.15.2
torch==2.5.1
transformers==4.51.3
trl==0.17.0
Model training nowadays involves a very large number of parameters. Back in the days of hand-assembled models we would usually define a custom Config class to manage the model, training and dataset parameters; Transformers now ships a very convenient tool, HfArgumentParser, for handling this sprawl. Following the official PPO example (ppo.py):
parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()
All parameters are split into three groups: ScriptArguments, PPOConfig and ModelConfig.
The way HfArgumentParser resolves them (I have not read the source, so the details may differ slightly, but the result should be the same) is roughly: first instantiate the ScriptArguments, PPOConfig and ModelConfig objects with their default values, then assign each argument passed on the command line to whichever object owns a field with that name.
For example, the launch command of the PPO example is:
python -i examples/scripts/ppo/ppo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--output_dir models/minimal/ppo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--missing_eos_penalty 1.0
Here the values of dataset_name and dataset_train_split are routed to script_args; learning_rate, output_dir, per_device_train_batch_size, gradient_accumulation_steps, total_episodes and missing_eos_penalty are routed to training_args; and model_name_or_path is routed to model_args.
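As a quick, hedged sanity check (this snippet is not part of the repository; it simply feeds the same flags as an explicit list instead of reading sys.argv), one can verify how the parser routes flags by field name:
from transformers import HfArgumentParser
from trl import ModelConfig, PPOConfig, ScriptArguments

parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses(args=[
    "--dataset_name", "trl-internal-testing/descriptiveness-sentiment-trl-style",  # ScriptArguments field
    "--learning_rate", "3e-6",                                                      # PPOConfig (TrainingArguments) field
    "--output_dir", "models/minimal/ppo",                                           # PPOConfig (TrainingArguments) field
    "--model_name_or_path", "EleutherAI/pythia-1b-deduped",                         # ModelConfig field
])
print(script_args.dataset_name)       # the dataset flag landed in ScriptArguments
print(training_args.learning_rate)    # 3e-06, landed in PPOConfig
print(model_args.model_name_or_path)  # landed in ModelConfig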
It is worth looking at which parameters script_args, training_args and model_args each contain:
script_args
ScriptArguments(
dataset_name=None,
dataset_config=None,
dataset_train_split='train',
dataset_test_split='test',
gradient_checkpointing_use_reentrant=False,
ignore_bias_buffers=False,
)
training_args (i.e. ppo_config)
PPOConfig(
output_dir='trainer_output',
overwrite_output_dir=False,
do_train=False,
do_eval=False,
do_predict=False,
eval_strategy=<IntervalStrategy.NO: 'no'>,
prediction_loss_only=False,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
per_gpu_train_batch_size=None,
per_gpu_eval_batch_size=None,
gradient_accumulation_steps=1,
eval_accumulation_steps=None,
eval_delay=0,
torch_empty_cache_steps=None,
learning_rate=5e-05,
weight_decay=0.0,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
max_grad_norm=1.0,
num_train_epochs=3.0,
max_steps=-1,
lr_scheduler_type=<SchedulerType.LINEAR: 'linear'>,
lr_scheduler_kwargs={},
warmup_ratio=0.0,
warmup_steps=0,
log_level='passive',
log_level_replica='warning',
log_on_each_node=True,
logging_dir='trainer_output\\runs\\Jun14_11-14-00_LAPTOP-PJP6MGE1',
logging_strategy=<IntervalStrategy.STEPS: 'steps'>,
logging_first_step=False,
logging_steps=500,
logging_nan_inf_filter=True,
save_strategy=<SaveStrategy.STEPS: 'steps'>,
save_steps=500,
save_total_limit=None,
save_safetensors=True,
save_on_each_node=False,
save_only_model=False,
restore_callback_states_from_checkpoint=False,
no_cuda=False,
use_cpu=False,
use_mps_device=False,
seed=42,
data_seed=None,
jit_mode_eval=False,
use_ipex=False,
bf16=False,
fp16=False,
fp16_opt_level='O1',
half_precision_backend='auto',
bf16_full_eval=False,
fp16_full_eval=False,
tf32=None,
local_rank=0,
ddp_backend=None,
tpu_num_cores=None,
tpu_metrics_debug=False,
debug=[],
dataloader_drop_last=False,
eval_steps=None,
dataloader_num_workers=0,
dataloader_prefetch_factor=None,
past_index=-1,
run_name='trainer_output',
disable_tqdm=False,
remove_unused_columns=True,
label_names=None,
load_best_model_at_end=False,
metric_for_best_model=None,
greater_is_better=None,
ignore_data_skip=False,
fsdp=[],
fsdp_min_num_params=0,
fsdp_config={'min_num_params': 0,
'xla': False,
'xla_fsdp_v2': False,
'xla_fsdp_grad_ckpt': False},
fsdp_transformer_layer_cls_to_wrap=None,
accelerator_config=AcceleratorConfig(split_batches=False,
dispatch_batches=None,
even_batches=True,
use_seedable_sampler=True,
non_blocking=False,
gradient_accumulation_kwargs=None,
use_configured_state=False),
deepspeed=None,
label_smoothing_factor=0.0,
optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>,
optim_args=None,
adafactor=False,
group_by_length=False,
length_column_name='length',
report_to=['wandb'],
ddp_find_unused_parameters=None,
ddp_bucket_cap_mb=None,
ddp_broadcast_buffers=None,
dataloader_pin_memory=True,
dataloader_persistent_workers=False,
skip_memory_metrics=True,
use_legacy_prediction_loop=False,
push_to_hub=False,
resume_from_checkpoint=None,
hub_model_id=None,
hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>,
hub_token=None,
hub_private_repo=None,
hub_always_push=False,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
include_inputs_for_metrics=False,
include_for_metrics=[],
eval_do_concat_batches=True,
fp16_backend='auto',
push_to_hub_model_id=None,
push_to_hub_organization=None,
push_to_hub_token=None,
mp_parameters='',
auto_find_batch_size=False,
full_determinism=False,
torchdynamo=None,
ray_scope='last',
ddp_timeout=1800,
torch_compile=False,
torch_compile_backend=None,
torch_compile_mode=None,
include_tokens_per_second=False,
include_num_input_tokens_seen=False,
neftune_noise_alpha=None,
optim_target_modules=None,
batch_eval_metrics=False,
eval_on_start=False,
use_liger_kernel=False,
eval_use_gather_object=False,
average_tokens_across_devices=False,
dataset_num_proc=None,
num_mini_batches=1,
total_episodes=None,
local_rollout_forward_batch_size=64,
num_sample_generations=10,
response_length=53,
stop_token=None,
stop_token_id=None,
temperature=0.7,
missing_eos_penalty=None,
sft_model_path='EleutherAI/pythia-160m',
world_size=None,
num_total_batches=None,
micro_batch_size=None,
local_batch_size=None,
batch_size=None,
local_mini_batch_size=None,
mini_batch_size=None,
exp_name='ppo_config',
reward_model_path='EleutherAI/pythia-160m',
model_adapter_name=None,
ref_adapter_name=None,
num_ppo_epochs=4,
whiten_rewards=False,
kl_coef=0.05,
kl_estimator='k1',
cliprange=0.2,
vf_coef=0.1,
cliprange_value=0.2,
gamma=1.0,
lam=0.95,
ds3_gather_for_generation=True
)
model_args
ModelConfig(
model_name_or_path=None,
model_revision='main',
torch_dtype=None,
trust_remote_code=False,
attn_implementation=None,
use_peft=False,
lora_r=16,
lora_alpha=32,
lora_dropout=0.05,
lora_target_modules=None,
lora_modules_to_save=None,
lora_task_type='CAUSAL_LM',
use_rslora=False,
use_dora=False,
load_in_8bit=False,
load_in_4bit=False,
bnb_4bit_quant_type='nf4',
use_bnb_nested_quant=False,
)
In summary:
- script_args: mainly dataset-related parameters. Notably, dataset_name is required, which is a little odd: the script has to be launched with a --dataset_name (or --dataset-name) flag, otherwise it aborts with error: the following arguments are required: --dataset_name/--dataset-name (observed with transformers==4.51.3, trl==0.17.0; another machine runs transformers==4.52.4, trl==0.18.1, slightly newer than the Linux versions, but this does not appear to be a version issue).
- training_args: the PPOConfig object, i.e. the args argument of PPOTrainer, which holds a very large number of training parameters; the full parameter list of every trainer config can be looked up in the TRL documentation on HuggingFace (https://huggingface.co/docs/trl). Commonly used ones include:
  - output_dir: where the model checkpoints are written;
  - adam_xxx: parameters of the Adam optimizer. All TRL trainers seem to default to Adam, although PPOTrainer also lets you supply your own optimizer: its optimizers argument expects a pair (optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR])), i.e. an optimizer plus a learning-rate scheduler;
  - learning_rate: the learning rate;
  - kl_coef: the KL-divergence penalty coefficient; the larger it is, the smaller the policy updates, and the smaller it is, the faster the policy is allowed to move.
- model_args: this is what PPOTrainer's peft_config argument is ultimately derived from; it also carries model-initialization options, and both the quantization and the PEFT configuration are extracted from it, as in the base_pipeline script:
...
# Trainer Pipeline
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
# @param data_processor: Function object prepared for `dataset.map(data_processor)`
# @param trainer_config: [Dict, trl.XXXConfig] including keyword arguments, e.g.
# @param model_config: [Dict, trl.ModelConfig] including keyword arguments, e.g.
# @param script_arguments: [Dict, trl.ScriptArguments] including keyword arguments, e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# @param config_kwargs: [Dict] keyword arguments for updating TRL-Config, `ScriptArguments`, `ModelConfig`
# - keyword arguments for `TRLConfig`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
# - keyword arguments for `ScriptArguments`: e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# - keyword arguments for `ModelConfig`: e.g. "model_name_or_path", "torch_dtype", "trust_remote_code", "use_peft", "lora_xxx", "load_in_4bit", "bnb_4bit_compute_dtype", "bnb_4bit_quant_type"
# @param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer
# - keyword arguments for all Trainers: e.g. "data_collator", "callbacks"
# - keyword arguments for `SFTTrainer`: e.g. "compute_loss_func", "compute_metrics"
# - keyword arguments for `PPOTrainer`: e.g. "ref_model[required]", "reward_model[required]", "value_model[required]"
# - keyword arguments for `DPOTrainer`: e.g. "ref_model"
# - keyword arguments for `GRPOTrainer`: e.g. "reward_funcs[required]"
def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):
# 1 Configuration
TRLConfig, TRLTrainer = eval(f"{name}Config"), eval(f"{name}Trainer")
parser = HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))
script_arguments, trainer_config, model_config = parser.parse_args_into_dataclasses()
script_arguments = update_trl_config(script_arguments, **config_kwargs)
trainer_config = update_trl_config(trainer_config, **config_kwargs)
model_config = update_trl_config(model_config, **config_kwargs)
peft_config = get_peft_config(model_config)
quantization_config = get_quantization_config(model_config)
...
Both peft_config and quantization_config can be produced by ready-made helpers, which are defined in trainer/utils.py under the trl package:
def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype`
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.torch_dtype,
)
elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
else:
quantization_config = None
return quantization_config
def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":
if model_args.use_peft is False:
return None
if not is_peft_available():
raise ValueError(
"You need to have PEFT library installed in your environment, make sure to install `peft`. "
"Make sure to run `pip install -U peft`."
)
peft_config = LoraConfig(
task_type=model_args.lora_task_type,
r=model_args.lora_r,
target_modules=model_args.lora_target_modules,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
bias="none",
use_rslora=model_args.use_rslora,
use_dora=model_args.use_dora,
modules_to_save=model_args.lora_modules_to_save,
)
return peft_config
As the source above shows, if model_config.use_peft is False, PEFT is simply not enabled and the resulting peft_config is None; likewise, quantization is used only when at least one of load_in_4bit and load_in_8bit is True.
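To make this concrete, here is a small hedged sketch (the model name and LoRA target modules are placeholders only) of how the ModelConfig flags turn into the two config objects:
from trl import ModelConfig, get_peft_config, get_quantization_config

model_config = ModelConfig(
    model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder
    use_peft=True,
    lora_r=16,
    lora_target_modules=["q_proj", "k_proj", "v_proj"],
    load_in_4bit=True,
    torch_dtype="bfloat16",
)
print(get_peft_config(model_config))          # a peft.LoraConfig built from the lora_* fields
print(get_quantization_config(model_config))  # a BitsAndBytesConfig built from the load_in_*/bnb_* fields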
Back to the corresponding part of base_pipeline:
...
# 2 Load models and tokenizer
logging.info("Load models and tokenizer ...")
logging.info(f" - Model: {model_config.model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if not "pad_token" in tokenizer.special_tokens_map:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path,
device_map = "auto",
trust_remote_code = model_config.trust_remote_code,
quantization_config = quantization_config,
)
if peft_config is not None:
logging.info("Prepare model for PEFT ...")
model.config.pretraining_tp = 1
model.config.use_cache = False
model.gradient_checkpointing_enable()
# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)
# Then you need set `model.enable_input_require_grads()` yourself
# model = prepare_model_for_kbit_training(model)
model.enable_input_require_grads()
model = get_peft_model(model, peft_config)
...
A few scattered points to note here:
- pad_token: some tokenizers do not define one (check tokenizer.special_tokens_map); if it is missing, trainer.train() will later fail and ask you to set a pad_token. A hedged sketch of the two usual fixes follows the code excerpt below.
- tokenizer.chat_template: this becomes relevant in the data-processing part discussed later. Conversational tasks usually push the dataset through a chat template, so when the tokenizer ships without one, SIMPLE_CHAT_TEMPLATE from trl.trainer.utils can serve as the default.
- Quantized loading (i.e. when quantization_config is not None) saves GPU memory, and the following settings usually accompany it (if the model is not loaded quantized, these settings should do no harm, which is why they are applied unconditionally here):
  - pretraining_tp is the tensor-parallel degree used during pretraining (a config field of LLaMA-style models); values above 1 switch on a slower, "more exact" parallel computation path, so it is normally forced back to 1 for fine-tuning (and a single GPU cannot exploit it anyway);
  - use_cache controls whether model.generate() caches key/value states (the kv_cache); it has to be turned off once gradient checkpointing is enabled;
  - model.gradient_checkpointing_enable() trades extra compute for lower memory by recomputing activations during the backward pass;
  - model.enable_input_require_grads() is needed when gradient checkpointing is combined with a (partially) frozen base model, e.g. under PEFT or k-bit loading: it makes the embedding outputs require gradients so that backpropagation can flow through the checkpointed blocks (prepare_model_for_kbit_training would, among other things, do this for you).
...
model.config.pretraining_tp = 1
model.config.use_cache = False
model.gradient_checkpointing_enable()
# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)
# Then you need set `model.enable_input_require_grads()` yourself
# model = prepare_model_for_kbit_training(model)
model.enable_input_require_grads()
...
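As promised in the list above, here is a hedged sketch of the two usual ways to guarantee a pad token (pythia-160m is used only because its tokenizer ships without one):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
if tokenizer.pad_token is None:
    # Option 1: reuse the EOS token as PAD; the vocabulary is unchanged, so no resize is needed
    tokenizer.pad_token = tokenizer.eos_token
    # Option 2: add a dedicated [PAD] token (what base_pipeline does) and, if its id falls
    # outside the model's embedding table, resize the embeddings:
    # tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # model.resize_token_embeddings(len(tokenizer))
print(tokenizer.pad_token, tokenizer.pad_token_id)
If you go the [PAD] route and the new token id exceeds the model's embedding size, skipping resize_token_embeddings leads to an index error at training time.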
PPOTrainer, on the other hand, is special: its constructor additionally requires reward_model, value_model and ref_model, so base_pipeline gives it extra treatment:
if name == "PPO":
logging.info("PPO load reward value and reference models ...")
# PPO is special! It needs more components!
logging.info(f" - Reward model: {trainer_config.reward_model_path}")
reward_model = AutoModelForSequenceClassification.from_pretrained(
trainer_config.reward_model_path,
trust_remote_code = model_config.trust_remote_code,
num_labels = 1,
)
value_model = AutoModelForSequenceClassification.from_pretrained(
trainer_config.reward_model_path,
trust_remote_code = model_config.trust_remote_code,
num_labels = 1,
)
logging.info(" - Copy reference model ...")
ref_model = deepcopy(model)
# ref_model = model.__class__(model.config)
# ref_model.load_state_dict(model.state_dict())
trainer_kwargs["reward_model"] = reward_model
trainer_kwargs["value_model"] = value_model
trainer_kwargs["ref_model"] = ref_model
logging.info(" - Done!")
A few things to note here:
- ref_model starts out identical to the target model. I tested two ways of building it. The first is simply ref_model = deepcopy(model), which does work (no unpickling-style problems). The second is to construct an empty model with the same architecture and then load the state dict:
# ref_model = model.__class__(model.config)
# ref_model.load_state_dict(model.state_dict())
This works when use_peft=False, but here model may already be a PEFT model, whose class also expects a peft_config argument, so calling it like this raises an error.
- When HfArgumentParser builds the PPOConfig, ppo_config.reward_model_path has a default value (see the training_args dump above), namely 'EleutherAI/pythia-160m'; whatever reward model you use must be loadable via AutoModelForSequenceClassification.
- value_model is usually initialized from the same checkpoint as reward_model, at least in the official PPO example. This can be rationalized: my understanding is that the value model estimates the value function in $Q(s,a) = r(s,a) + \gamma V(s')$, where $r(s,a)$ is the reward function; in the CAUSAL_LM setting this is just the reward for appending a token $a$ to the sentence $s$, so both models are essentially scoring a piece of text. It makes sense, even if it still feels a bit odd.
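For intuition, a hedged, self-contained sketch of what a num_labels=1 sequence-classification head produces: one scalar score per input, which is how the reward model scores whole responses (inside the trainer the value model's score head is applied per token, but the architecture is the same; pythia-160m is only a placeholder here, and its score head is freshly initialized, so the numbers themselves are meaningless):
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

scorer_path = "EleutherAI/pythia-160m"  # same as the PPOConfig.reward_model_path default
tokenizer = AutoTokenizer.from_pretrained(scorer_path)
scorer = AutoModelForSequenceClassification.from_pretrained(scorer_path, num_labels=1)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    scorer.config.pad_token_id = tokenizer.pad_token_id  # required for batched (padded) inputs

batch = tokenizer(
    ["a short answer", "a much longer and more detailed answer"],
    return_tensors="pt",
    padding=True,
)
with torch.no_grad():
    scores = scorer(**batch).logits.squeeze(-1)  # one scalar score per sequence
print(scores.shape)  # torch.Size([2])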
These days datasets are mostly loaded directly with HuggingFace's datasets.load_dataset; the standard format is one record per sample, roughly {"column_1": value_1, "column_2": value_2, ...}, so the whole dataset looks much like a jsonl file.
For the trainers covered here, the dataset format requirements are roughly as follows:
- SFTTrainer: the fields should include prompt and completion; note that the loss is computed only on the completion part, never on the prompt. (Datasets exposing a single text field, or input + target style fields, can apparently also work, but you have to read the source carefully to see how they are handled.)
- DPOTrainer: straightforward, normally prompt + chosen + rejected. In the official example, however, the dataset trl-lib/ultrafeedback_binarized has the fields chosen, rejected, score_chosen, score_rejected, so prompt is apparently not strictly required (an empty string would do anyway), and if the dataset carries no scores for the answers there is presumably a default scoring mechanism inside the trainer.
- GRPOTrainer: again going by the official example, which uses trl-lib/tldr, a typical dataset with only the prompt and completion fields.
- PPOTrainer: the most special one, and the reason DeepSeek kept getting it wrong.
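To make the differences concrete, a hedged illustration of what a single record looks like for each trainer (all field values are made up; the PPO token ids are arbitrary):
sft_example  = {"prompt": "What is RL?", "completion": "Reinforcement learning is ..."}
dpo_example  = {"prompt": "What is RL?", "chosen": "a careful answer", "rejected": "a sloppy answer"}
grpo_example = {"prompt": "What is RL?"}              # completions are sampled by the trainer; extra columns are handed to the reward functions
ppo_example  = {"input_ids": [3923, 374, 10726, 30]}  # pre-tokenized, as in the official ppo.py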
In the official PPO script (ppo.py):
def prepare_dataset(dataset, tokenizer):
"""pre-tokenize the dataset before training; only collate during training"""
def tokenize(element):
outputs = tokenizer(
element[dataset_text_field],
padding=False,
)
return {"input_ids": outputs["input_ids"]}
return dataset.map(
tokenize,
batched=True,
remove_columns=dataset.column_names,
num_proc=training_args.dataset_num_proc,
)
Note that here the data is tokenized up front into input_ids. I had assumed PPO and GRPO were similar enough that their dataset requirements would not differ much, so prompt + completion ought to work for PPO as well; in practice it raises an error insisting that an input_ids field is present. That is why PPO gets its own data_processor in base_pipeline, while the processors for the other trainers are trivial:
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
def generate_simple_data_processor(name, **kwargs):
if name in ["SFT", "GRPO"]:
def _data_processor(_data):
return {"prompt": _data["prompt"], "completion": _data["completion"]}
elif name == "PPO":
tokenizer = kwargs.get("tokenizer")
def _data_processor(_data):
outputs = tokenizer(_data["prompt"] + _data["completion"], padding = False)
return {"input_ids": outputs["input_ids"]}
elif name == "DPO":
def _data_processor(_data):
return {"prompt": _data["prompt"], "chosen": _data["chosen"], "rejected": _data["rejected"]}
else:
raise NotImplementedError(name)
return _data_processor
Some implementations use apply_chat_template instead, which is the usual choice for interactive dialogue tasks; that works just as well:
def _data_processor(_data):
_message = [
{"role": "system", "content": "You are an AI assistant developped by CY"},
{"role": "user", "content": _data[dataset_input_column]},
{"role": "assistant", "content": _data[dataset_target_column]},
]
_prompt = tokenizer.apply_chat_template(_message, tokenize=False)
return {"text": _prompt}
One small detail: when processing data with dataset.map, it is common to pass remove_columns=dataset.column_names so that the original fields are dropped and the Trainer cannot accidentally train on leftover columns.
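A tiny hedged demo (toy in-memory dataset) of what remove_columns actually does:
from datasets import Dataset

raw = Dataset.from_list([{"prompt": "Hi", "completion": "Hello!", "source": "toy"}])
processed = raw.map(
    lambda example: {"text": example["prompt"] + example["completion"]},
    remove_columns=raw.column_names,  # drop prompt/completion/source so only the new field survives
)
print(processed.column_names)  # ['text']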
The last part of base_pipeline trains the model and then saves it; if everything goes smoothly, the whole pipeline runs end to end:
# 4 Train model
logging.info("Trainer starts ...")
trainer = TRLTrainer(
model = model,
args = trainer_config,
train_dataset = train_dataset,
eval_dataset = eval_dataset,
processing_class = tokenizer,
peft_config = peft_config,
**trainer_kwargs
)
trainer.train()
logging.info(" - Trainer finishes!")
# 5 Save model
if trainer_config.push_to_hub:
        logging.info(f" - Push checkpoints to {trainer_config.push_to_hub_organization}/{trainer_config.push_to_hub_model_id}")
trainer.push_to_hub()
logging.info(f"Save model to {trainer_config.output_dir}")
trainer.save_model(trainer_config.output_dir)
A few points worth noting here:
- trainer_config (i.e. the SFTConfig / PPOConfig / DPOConfig / GRPOConfig object) has a report_to parameter; left unset, it defaults to reporting to W&B, and if the network cannot reach it you get a network error as soon as a checkpoint is hit, so it is usually set to "none" (or alternatively to "tensorboard").
- You can also simply look under trainer_config.output_dir for the checkpoint-xxx folders, each of which contains a trainer_state.json, and plot the training dynamics from that file yourself, for example:
# Plot dynamics of TRL trainer state
import json
from matplotlib import pyplot as plt

def plot_trl_dynamics(trainer_state_path):
with open(trainer_state_path, 'r', encoding="utf8") as f:
data = json.load(f)
log_history = data["log_history"]
steps = [entry["step"] for entry in log_history]
episodes = [entry["episode"] for entry in log_history]
epochs = [entry["epoch"] for entry in log_history]
policy_loss = [entry["loss/policy_avg"] for entry in log_history]
value_loss = [entry["loss/value_avg"] for entry in log_history]
lrs = [entry["lr"] for entry in log_history]
entropys = [entry["objective/entropy"] for entry in log_history]
kls = [entry["objective/kl"] for entry in log_history]
non_score_rewards = [entry["objective/non_score_reward"] for entry in log_history]
rlhf_rewards = [entry["objective/rlhf_reward"] for entry in log_history]
scores = [entry["objective/scores"] for entry in log_history]
plt.figure(figsize=(8, 8))
ax_1 = plt.subplot(2, 2, 1)
ax_2 = plt.subplot(4, 2, 2)
ax_3 = plt.subplot(4, 2, 4)
ax_4 = plt.subplot(2, 2, 3)
ax_5 = plt.subplot(2, 2, 4)
ax_1.plot(steps, policy_loss, label="Policy Loss")
ax_1.plot(steps, value_loss, label="Value Loss", linestyle="--")
ax_1.set_xlabel("Step"), ax_1.set_ylabel("Loss"), ax_1.legend()
ax_1.set_title("Policy and Value Loss")
# ------------------------------------------------------------------
ax_2.plot(steps, kls, label="objective/kl")
ax_2.set_xlabel("Step"), ax_2.set_ylabel("KL"), ax_2.legend()
ax_2.set_title("KL Curve")
# ------------------------------------------------------------------
ax_3.plot(steps, entropys, label="objective/entropy")
ax_3.set_xlabel("Step"), ax_3.set_ylabel("Entropy"), ax_3.legend()
ax_3.set_title("Entropy Curve")
# ------------------------------------------------------------------
ax_4.plot(steps, lrs, label="Learning Rate")
ax_4.set_xlabel("Step"), ax_4.set_ylabel("Learning Rate"), ax_4.legend()
ax_4.set_title("Learning Rate Curve")
# ------------------------------------------------------------------
ax_5.plot(steps, non_score_rewards, label="objective/non_score_reward", linestyle="--")
ax_5.plot(steps, rlhf_rewards, label="objective/rlhf_reward", linestyle="--")
ax_5.plot(steps, scores, label="objective/scores")
ax_5.set_xlabel("Step"), ax_5.set_ylabel("Score/Reward"), ax_5.legend()
ax_5.set_title("Reward and Score")
plt.show()
print(steps)
print('-' * 64)
print(episodes)
print('-' * 64)
print(epochs)
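A typical call then looks like this (the checkpoint path is hypothetical; point it at any checkpoint-xxx folder under your output_dir):
plot_trl_dynamics("./temp/ppo+pythia-1b-deduped+tldr/checkpoint-500/trainer_state.json")  # hypothetical path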
On W&B: if your network allows it, wandb is indeed a good option.
Logging in to wandb
Method 1: interactive login (recommended)
Run the following; it opens a browser or asks for an API key:
import wandb
wandb.login()
Then just follow the prompts.
Method 2: pass the API key directly
import wandb
wandb.login(key="your_api_key_here")  # replace with your own wandb API key
⚠️ Note: do not hard-code the API key in your source (especially if the code is public); prefer an environment variable or a config file.
Method 3: set an environment variable (suited to scripts/servers)
export WANDB_API_KEY=your_api_key_here
Then simply call wandb.init() in Python.
Step 3: enabling wandb in the Trainer
Option 1: automatic logging (the default)
The HuggingFace Trainer uses wandb by default (provided it is installed and you are logged in):
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
output_dir="./results",
    report_to="wandb",  # explicitly select wandb (this is already the default)
logging_dir="./logs",
    # other training arguments...
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()  # automatically logs to wandb
Option 2: initialize wandb manually (finer control)
import wandb
# Initialize wandb and set the project name
wandb.init(project="my-ppo-project", name="experiment-1")
# Then train as usual
trainer.train()
# After training (optional)
wandb.finish()
This way you can customize the wandb configuration (project, name, config, and so on).
Extra configuration (optional)
1. Logging hyperparameters
config = {
"learning_rate": 2e-5,
"batch_size": 32,
"epochs": 3,
}
wandb.init(project="my-project", config=config)
During training these parameters show up in the wandb dashboard.
2. Logging custom metrics
# Log custom metrics inside the training loop
wandb.log({"reward": current_reward, "loss": current_loss})
3. Resuming an interrupted run
wandb.init(id="your_previous_run_id", resume="must")
Tracking run with wandb, view at https://wandb.ai/your-username/your-project/runs/xxx
Back in the unit-test script trainer_pipelines.py (listed in full above), the only test function that deserves extra comment is grpo_pipeline_test:
def grpo_pipeline_test():
logging.info("GRPO unittest ...")
model_name_or_path = os.path.join(model_home, model_names[0])
dataset_name = os.path.join(dataset_home, dataset_names[0])
data_processor = None
def reward_funcs(completions, **kwargs):
return [float(len(set(completion))) for completion in completions]
config_kwargs = {
"output_dir": f"./temp/grpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
"model_name_or_path": model_name_or_path,
"dataset_name": dataset_name,
"trust_remote_code": True,
"dataset_train_split": "train[:500]",
"dataset_test_split": "validation[:100]",
"use_peft": True,
"report_to": "none",
"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
}
trainer_kwargs = {
"reward_funcs": reward_funcs,
}
grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)
The main point here is the reward_funcs in grpo_pipeline_test; the other test functions are self-explanatory. GRPOTrainer has a required argument reward_funcs, and the example given in the TRL documentation is:
def reward_funcs(completions, **kwargs):
return [float(len(set(completion))) for completion in completions]
This reward gives higher scores to completions containing more distinct characters, which roughly favors longer and more varied answers.
Note that the **kwargs cannot be dropped: when GRPOTrainer calls a reward function it passes additional keyword arguments, such as prompts or ground_truths; the TRL GRPO documentation has plenty of examples.
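For instance, here is a hedged sketch of a reward function that actually uses one of those extra arguments (the scoring rule itself is arbitrary and assumes plain-text, non-conversational completions):
def prompt_aware_reward(completions, prompts, **kwargs):
    # Purely illustrative rule: penalize completions that are much longer than their prompt
    return [
        float(len(prompt)) / (1.0 + abs(len(completion) - len(prompt)))
        for prompt, completion in zip(prompts, completions)
    ]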
This last part collects a few questions discussed with DeepSeek.
First, a question about TRL's Trainer. The Trainer accepts a data_collator argument, yet we normally preprocess the dataset before handing it over, e.g. train_dataset = train_dataset.map(prompt_formatter). My understanding is that prompt_formatter and data_collator play the same role (and that their processing logic is the same: take a sample and transform it), so once train_dataset.map(prompt_formatter) has been applied outside the Trainer there is no need to also pass a data_collator, and conversely, I could pass a data_collator and skip the map call.
The following answer was generated by AI:
Your understanding is mostly right, but a few details matter. prompt_formatter (applied via dataset.map) and data_collator both process data, yet they run at different times and serve different purposes:
- dataset.map(prompt_formatter) runs once, offline, before training; its result is cached on disk (pass load_from_cache_file=False to force a re-run after changing the logic), and you typically pass remove_columns to drop fields that are no longer needed.
- data_collator runs online, per batch, inside the DataLoader, and is the place for anything that depends on the batch context, most commonly dynamic padding.
When to use which:
- Only dataset.map: fine if every sample can be fully prepared ahead of time, but you still need a data_collator to handle padding unless everything has already been padded to a fixed length.
- Only data_collator: also fine if all the processing can be done per batch, at the cost of repeating that work every epoch.
- Both together: use dataset.map for static logic (e.g. tokenization, adding a fixed prompt) and data_collator for dynamic logic (e.g. padding, random augmentation):
def prompt_formatter(example):
example["input_ids"] = tokenizer(example["text"]).input_ids
    return example  # no padding; keep the original lengths
train_dataset = raw_dataset.map(prompt_formatter, batched=False)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True)  # dynamic padding and masking
If prompt_formatter already does all the necessary work (including padding to a fixed length) and no dynamic logic is needed, you can indeed skip data_collator (or pass default_data_collator, which merely stacks samples into a batch). But if any dynamic operation such as padding is needed, a data_collator is still required even with dataset.map in place.
The two are therefore not mutually exclusive but complementary: dataset.map is for offline preprocessing, data_collator is for online batching, and the choice depends on whether your processing needs batch context or randomness.
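A hedged end-to-end sketch of that division of labor (bert-base-uncased is just a convenient tokenizer, and DataCollatorWithPadding replaces the masking collator from the answer above to keep things minimal):
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
dataset = Dataset.from_list([{"text": "short"}, {"text": "a noticeably longer sentence"}])
# Static step, done once: tokenize without padding
dataset = dataset.map(lambda example: tokenizer(example["text"]), remove_columns=["text"])
# Dynamic step, done per batch: pad to the longest sample in the batch
collator = DataCollatorWithPadding(tokenizer)
batch = collator([dataset[0], dataset[1]])
print(batch["input_ids"].shape)  # padded to the longest sequence in this particular batch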
Next, the accelerate library has a method, PartialState().local_main_process_first(), which I have seen wrapped around dataset loading; it looks like it is about loading the dataset in parallel. Can you explain what this function does? (Note: the method is used in the official PPO example.)
PartialState().local_main_process_first() is a key method of Hugging Face's accelerate library for coordinating multi-process work; in distributed training or inference it controls the order in which operations such as dataset loading or file preprocessing are executed. Its core purpose is to make sure such operations run first on the main process, avoiding races between processes or duplicated work.
local_main_process_first() returns a context manager that behaves as follows:
(1) Avoiding duplicate caching when loading datasets
In distributed training, if several processes try to cache a dataset at the same time (e.g. through the datasets library), the writes can conflict. With local_main_process_first(), only the main process builds the cache and the other processes simply reuse it:
from accelerate import PartialState
state = PartialState()  # initialize the distributed state
# Only the main process caches the dataset first; the other processes wait and then load the cache directly
with state.local_main_process_first():
    dataset = load_dataset("imdb")  # assume this is the slow operation
(2) File preprocessing or downloads
If some preprocessing (tokenization, data conversion, ...) only needs to happen once:
with state.local_main_process_first():
    tokenized_data = dataset.map(tokenize_function, batched=True)  # the main process does the work first
# The other processes automatically wait for it to finish, then use the processed data directly
A few notes on the pieces involved:
- PartialState is the accelerate class that manages distributed state; it automatically detects the current process's distributed environment (multi-GPU, TPU, and so on).
- local_main_process is the main process of the current node (machine); in multi-node training every machine has its own local_main_process, one of which is also the global main process.
- Synchronization between processes is implemented on top of torch.distributed (or similar), ensuring that non-main processes wait for the main process to finish.

| Method / attribute | Scope | Purpose |
|---|---|---|
| local_main_process_first() | main process of the current node goes first | avoids races on file/cache operations |
| main_process_first() | global main process goes first (across all nodes) | for operations that must run exactly once globally |
| on_main_process() | runs only on the main process | skips the other processes entirely (e.g. logging) |
from accelerate import PartialState
from datasets import load_dataset
# Initialize the distributed state (environment detected automatically)
state = PartialState()
# Only the main process loads and caches the data first; the other processes wait
with state.local_main_process_first():
dataset = load_dataset("glue", "mrpc", split="train")
tokenized_data = dataset.map(lambda x: tokenizer(x["sentence"]), batched=True)
# All processes then continue training in sync
for batch in tokenized_data:
outputs = model(**batch)
...
That, in short, is what local_main_process_first() is for. If you want to dig deeper into accelerate's distributed strategies (e.g. multi-node training), we can discuss further!
The last question came from reading the TRL documentation: DPOTrainer's ref_model parameter has a default value of None, whereas PPOTrainer's ref_model parameter has no default. In other words, PPOTrainer must be given a ref_model explicitly, even though it is normally just a copy of the target model.
You raise a good question! Although DPO (Direct Preference Optimization) indeed falls outside traditional reinforcement learning, it still keeps a reference model, and that is a deliberate design choice. In detail:
DPO sidesteps reinforcement learning by optimizing preference data directly, but its objective (derived from the Bradley-Terry model) implicitly contains a KL-divergence constraint:
$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$$
where $\pi_\theta$ is the policy being trained, $\pi_{\text{ref}}$ is the frozen reference model, and $y_w$, $y_l$ are the chosen and rejected responses. The reference model anchors the optimization: the loss only rewards widening the preference margin relative to $\pi_{\text{ref}}$, which implicitly keeps $\pi_\theta$ close to it and prevents the policy from drifting arbitrarily far just to fit the preference pairs.
In the HuggingFace TRL implementation, DPOTrainer does take a ref_model argument, and:
- ref_model stays frozen during training (it is never updated) and serves purely as the baseline;
- ref_model may share part of its structure with the current model (e.g. the underlying Transformer), with gradients cut off so that it is not updated.
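A hedged numeric illustration of the loss above, with made-up per-response log-probabilities, just to show where the reference model enters the computation:
import math

beta = 0.1
policy_logp_chosen, policy_logp_rejected = -5.0, -8.0  # log pi_theta(y_w|x), log pi_theta(y_l|x)
ref_logp_chosen, ref_logp_rejected = -6.0, -7.0        # log pi_ref(y_w|x),   log pi_ref(y_l|x)

margin = beta * ((policy_logp_chosen - ref_logp_chosen) - (policy_logp_rejected - ref_logp_rejected))
loss = -math.log(1.0 / (1.0 + math.exp(-margin)))      # -log(sigmoid(margin))
print(loss)  # shrinks as the policy widens the chosen-vs-rejected margin relative to the reference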
| Aspect | PPO (RL) | DPO (non-RL) |
|---|---|---|
| Reference model | the old policy $\pi_{\text{old}}$ | the pretrained model $\pi_{\text{ref}}$ |
| How it is updated | periodically synchronized (on-policy requirement) | completely frozen |
| Role | importance-sampling correction + KL constraint | implicit KL constraint + protection against overfitting |
If $\pi_{\text{ref}}$ were replaced by a uniform distribution, DPO would degenerate into directly maximizing the likelihood of the preference pairs, losing the implicit constraint and inviting exactly the kind of overfitting the reference model is there to prevent.
So although DPO skips the explicit reinforcement-learning framework, by keeping a reference model it inherits the core RL idea of constraining policy updates; this is key to its stability and performance, and it is what distinguishes it from pure supervised learning (plain fine-tuning).
PS: I feel this answer did not address my question at all; it answered something else entirely.