【慢摹】TRL训练器细节详解(SFT/PPO/DPO/GRPO)

序言

本文总结一下目前TRL典型的训练器的实现细节(SFT,PPO,DPO,GRPO),也是对上一文【速写】PPOTrainer样例与错误思考(少量DAPO)的补充

目前DeepSeek关于各个训练器细节的掌握,尤其是PPOTrainer的问题,依然回答得很差,这个在上文中已经详细指出,它写的代码大多数都是跑不通的,而官方给出的PPO示例(ppo.py)似乎也有一些瑕疵,本文将会一一指出与各位探讨。

目前笔者写了一个对于4个典型训练器(SFT,PPO,DPO,GRPO)都适用的一个base_pipeline,以及对应的单元测试模块,其中PPOTrainer是最为繁琐的一个训练器,它需要更多的模型(奖励模型,价值模型,参考模型),并且对应的训练数据集train_dataset的处理方式与其他训练器存在显著区别。根据这份base_pipeline的代码,我将逐一探讨TRL训练器中的细节问题,谨以抛砖引玉,欢迎探讨。


文章目录

  • 序言
  • 1 TRL通用的训练器Pipeline测试及其细节说明
    • 1.1 训练配置参数的设置(Config)
    • 1.2 加载分词器与模型(tokenizer & model)
    • 1.3 数据集分割与字段问题(dataset)
    • 1.4 训练与保存(checkpoint)
    • 1.5 单元测试
  • 2 一些其他的问题
    • 2.1 DataProcessor与DataCollator
      • 2.1.1 `dataset.map(prompt_formatter)`:
      • 2.1.2 `data_collator`:
      • 2.1.3 何时选择哪种方式?
    • 2.2 关于PartialState
      • 2.2.1 核心功能
      • 2.2.2 典型使用场景
      • 2.2.3 参数与底层机制
      • 2.2.4 与类似方法的区别
      • 2.2.5 完整示例(分布式训练中的数据加载)
      • 2.2.6 注意事项
      • 2.2.7 总结
    • 2.3 DPO和PPO的reference_model的区别
      • 2.3.1 DPO的核心思想与Reference Model的作用
      • 2.3.2 **为什么DPO需要Reference Model,尽管它不是RL?**
      • 2.3.3 TRL的DPOTrainer中的`ref_model`
      • 2.3.4 **DPO vs PPO的Reference Model**
      • 2.3.5 **如果没有Reference Model会怎样?**


1 TRL通用的训练器Pipeline测试及其细节说明

完整的项目在GitHub@caoyang-sufe/easyllm,以下我们先来看下面代码中的base_pipeline,其他4个pipeline(sft_pipelineppo_pipelinedpo_pipelinegrpo_pipeline)都是直接调用它使用。

对应的脚本在trainer.py

这个pipeline很重要,接下来所有的讨论都是围绕它逐行展开的。

# -*- coding: utf8 -*-
# @author: caoyang
# @email: [email protected]

import wandb
import logging
from copy import deepcopy
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, HfArgumentParser
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from trl import (
	ScriptArguments, ModelConfig, 
	SFTConfig, SFTTrainer,
	PPOConfig, PPOTrainer,
	DPOConfig, DPOTrainer,
	GRPOConfig, GRPOTrainer,
	get_peft_config, get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
from src.tools.trl import update_trl_config, generate_simple_data_processor

# Trainer Pipeline
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
# @param data_processor: Function object prepared for `dataset.map(data_processor)`
# @param trainer_config: [Dict, peft.XXXConfig] including keyword arguments, e.g. 
# @param model_config: [Dict, peft.ModelConfig] including keyword arguments, e.g. 
# @param script_arguments: [Dict, peft.ScriptArguments] including keyword arguments, e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# @param config_kwargs: [Dict] keyword arguments for updating TRL-Config, `ScriptArguments`, `ModelConfig`
#   - keyword arguments for `TRLConfig`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
#   - keyword arguments for `ScriptArguments`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
#   - keyword arguments for `ModelConfig`: e.g. "model_name_or_path", "torch_dtype", "trust_remote_code", "use_peft", "lora_xxx", "load_in_4bit", "bnb_4bit_compute_dtype", "bnb_4bit_quant_type"
# @param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer
#   - keyword arguments for all Trainers: e.g. "data_collator", "callbacks"
#   - keyword arguments for `SFTTrainer`: e.g. "compute_loss_func", "compute_metrics"
#   - keyword arguments for `PPOTrainer`: e.g. "ref_model[required]", "reward_model[required]", "value_model[required]"
#   - keyword arguments for `DPOTrainer`: e.g. "ref_model"
#   - keyword arguments for `GRPOTrainer`: e.g. "reward_funcs[required]"
def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):
	# 1 Configuration
	TRLConfig, TRLTrainer = eval(f"{name}Config"), eval(f"{name}Trainer")
	parser = HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))
	script_arguments, trainer_config, model_config = parser.parse_args_into_dataclasses()
	script_arguments = update_trl_config(script_arguments, **config_kwargs)
	trainer_config = update_trl_config(trainer_config, **config_kwargs)
	model_config = update_trl_config(model_config, **config_kwargs)
	peft_config = get_peft_config(model_config)
	quantization_config = get_quantization_config(model_config)
	# 2 Load models and tokenizer
	logging.info("Load models and tokenizer ...")
	logging.info(f"  - Model: {model_config.model_name_or_path}")
	tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
	if not "pad_token" in tokenizer.special_tokens_map:
		tokenizer.add_special_tokens({"pad_token": "[PAD]"})
	if tokenizer.chat_template is None:
		tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
	model = AutoModelForCausalLM.from_pretrained(
		model_config.model_name_or_path,
		device_map = "auto",
		trust_remote_code = model_config.trust_remote_code,
		quantization_config = quantization_config,
	)
	if peft_config is not None:
		logging.info("Prepare model for PEFT ...")
		model.config.pretraining_tp = 1
		model.config.use_cache = False
		model.gradient_checkpointing_enable()
		# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)
		# Then you need set `model.enable_input_require_grads()` yourself
		# model = prepare_model_for_kbit_training(model)
		model.enable_input_require_grads()
		model = get_peft_model(model, peft_config)
	if name == "PPO":
		logging.info("PPO load reward value and reference models ...")
		# PPO is special! It needs more components!
		logging.info(f"  - Reward model: {trainer_config.reward_model_path}")
		reward_model = AutoModelForSequenceClassification.from_pretrained(
			trainer_config.reward_model_path,
			trust_remote_code = model_config.trust_remote_code,
			num_labels = 1,
		)
		value_model = AutoModelForSequenceClassification.from_pretrained(
			trainer_config.reward_model_path,
			trust_remote_code = model_config.trust_remote_code,
			num_labels = 1,
		)
		logging.info("  - Copy reference model ...")
		ref_model = deepcopy(model)
		# ref_model = model.__class__(model.config)
		# ref_model.load_state_dict(model.state_dict())
		trainer_kwargs["reward_model"] = reward_model
		trainer_kwargs["value_model"] = value_model
		trainer_kwargs["ref_model"] = ref_model
		logging.info("  - Done!")
		if data_processor is None:
			# The data processor of PPO is also different to others
			def data_processor(_data):
				outputs = tokenizer(_data["prompt"] + _data["completion"], padding = False)
				return {"input_ids": outputs["input_ids"]}
	# 2 Load dataset
	logging.info("Load dataset ...")
	logging.info(f"  - Dataset: {script_arguments.dataset_name}")
	if data_processor is None:
		data_processor = generate_simple_data_processor(name)
	train_dataset = load_dataset(script_arguments.dataset_name, split=script_arguments.dataset_train_split)
	eval_dataset = load_dataset(script_arguments.dataset_name, split=script_arguments.dataset_test_split)
	train_dataset = train_dataset.map(data_processor, remove_columns=train_dataset.column_names)
	eval_dataset = eval_dataset.map(data_processor, remove_columns=eval_dataset.column_names)
	logging.info(f"  - Train dataset: {len(train_dataset)}")
	logging.info(f"  - Eval dataset: {len(eval_dataset)}")
	# 4 Train model
	logging.info("Trainer starts ...")
	trainer = TRLTrainer(
		model = model,
		args = trainer_config,
		train_dataset = train_dataset,
		eval_dataset = eval_dataset,
		processing_class = tokenizer,
		peft_config = peft_config,
		**trainer_kwargs
	)
	trainer.train()
	logging.info("  - Trainer finishes!")
	# 5 Save model
	if trainer_config.push_to_hub:
		logging.info(f"  - Push checkpoints to {trainer_config.organization}/{trainer_config.push_to_hub_model_id}")
		trainer.push_to_hub()
	logging.info(f"Save model to {trainer_config.output_dir}")
	trainer.save_model(trainer_config.output_dir)
# SFT Pipeline
def sft_pipeline(data_processor, config_kwargs, trainer_kwargs):
	base_pipeline(
		name = "SFT",
		data_processor = data_processor,
		config_kwargs = config_kwargs,
		trainer_kwargs = trainer_kwargs,
	)
# PPO Pipeline
def ppo_pipeline(data_processor, config_kwargs, trainer_kwargs):
	base_pipeline(
		name = "PPO",
		data_processor = data_processor,
		config_kwargs = config_kwargs,
		trainer_kwargs = trainer_kwargs,
	)
# DPO Pipeline
def dpo_pipeline(data_processor, config_kwargs, trainer_kwargs):
	base_pipeline(
		name = "DPO",
		data_processor = data_processor,
		config_kwargs = config_kwargs,
		trainer_kwargs = trainer_kwargs,
	)
# GRPO Pipeline
def grpo_pipeline(data_processor, config_kwargs, trainer_kwargs):
		base_pipeline(
		name = "GRPO",
		data_processor = data_processor,
		config_kwargs = config_kwargs,
		trainer_kwargs = trainer_kwargs,
	)

对应的单元测试脚本在trainer_pipelines.py,这里面涉及3个模型和4个数据集,都可以在huggingface上直接下载得到,model_homedataset_home可根据本地路径进行修改。

# -*- coding: utf8 -*-
# @author: caoyang
# @email: [email protected]

import os
import logging
from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipeline

model_home = "/nfsshare/home/caoyang/resource/model"
dataset_home = "/nfsshare/home/caoyang/resource/dataset"
model_names = [
	"Qwen/Qwen2.5-0.5B-Instruct",
	"EleutherAI/pythia-1b-deduped",
	"EleutherAI/pythia-160m",
]

dataset_names = [
	"trl-lib/tldr",	# train["prompt", "completion"] + validation["prompt", "completion"] + test["prompt", "completion"]
	"trl-lib/ultrafeedback_binarized",	# train["chosen", "rejected", "score_chosen", "score_rejected"] + test["chosen", "rejected", "score_chosen", "score_rejected"]
	"trl-internal-testing/descriptiveness-sentiment-trl-style", # sentiment["prompt", "chosen", "rejected"] + descriptiveness["prompt", "chosen", "rejected"]
	"YeungNLP/firefly-train-1.1M", # train["input", "target"]
]

def sft_pipeline_test():
	logging.info("SFT unittest ...")
	model_name_or_path = os.path.join(model_home, model_names[0])
	dataset_name = os.path.join(dataset_home, dataset_names[0])
	data_processor = None
	config_kwargs = {
		"output_dir": f"./temp/sft+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
		"model_name_or_path": model_name_or_path,
		"dataset_name": dataset_name,
		"trust_remote_code": True,
		"dataset_train_split": "train[:500]",
		"dataset_test_split": "validation[500:600]",
		"use_peft": True,
		"report_to": "none",
		"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
	}
	trainer_kwargs = {
	}
	sft_pipeline(data_processor, config_kwargs, trainer_kwargs)

def ppo_pipeline_test():
	logging.info("PPO unittest ...")
	model_name_or_path = os.path.join(model_home, model_names[1])
	"""
	EleutherAI/pythia-1b-deduped
	GPTNeoXForCausalLM(
	  (gpt_neox): GPTNeoXModel(
		(embed_in): Embedding(50304, 2048)
		(emb_dropout): Dropout(p=0.0, inplace=False)
		(layers): ModuleList(
		  (0-15): 16 x GPTNeoXLayer(
			(input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
			(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
			(post_attention_dropout): Dropout(p=0.0, inplace=False)
			(post_mlp_dropout): Dropout(p=0.0, inplace=False)
			(attention): GPTNeoXAttention(
			  (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
			  (dense): Linear(in_features=2048, out_features=2048, bias=True)
			)
			(mlp): GPTNeoXMLP(
			  (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
			  (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
			  (act): GELUActivation()
			)
		  )
		)
		(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
		(rotary_emb): GPTNeoXRotaryEmbedding()
	  )
	  (embed_out): Linear(in_features=2048, out_features=50304, bias=False)
	)
	"""
	dataset_name = os.path.join(dataset_home, dataset_names[0])
	reward_model_path = os.path.join(model_home, model_names[2])
	data_processor = None
	config_kwargs = {
		"output_dir": f"./temp/ppo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
		"model_name_or_path": model_name_or_path,
		"dataset_name": dataset_name,
		"reward_model_path": reward_model_path,
		"trust_remote_code": True,
		"dataset_train_split": "train[:500]",
		"dataset_test_split": "validation[:100]",
		"use_peft": True,
		"report_to": "none",
		"lora_target_modules": ["query_key_value"],
	}
	trainer_kwargs = {
	}
	ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)

def dpo_pipeline_test():
	logging.info("DPO unittest ...")
	model_name_or_path = os.path.join(model_home, model_names[0])
	dataset_name = os.path.join(dataset_home, dataset_names[2])
	data_processor = None
	config_kwargs = {
		"output_dir": f"./temp/dpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
		"model_name_or_path": model_name_or_path,
		"dataset_name": dataset_name,
		"trust_remote_code": True,
		"dataset_train_split": "descriptiveness[:500]",
		"dataset_test_split": "descriptiveness[500:600]",
		"use_peft": True,
		"report_to": "none",
		"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
	}
	trainer_kwargs = {
	}
	dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)

def grpo_pipeline_test():
	logging.info("GRPO unittest ...")
	model_name_or_path = os.path.join(model_home, model_names[0])
	dataset_name = os.path.join(dataset_home, dataset_names[0])
	data_processor = None
	def reward_funcs(completions, **kwargs):
		return [float(len(set(completion))) for completion in completions]
	config_kwargs = {
		"output_dir": f"./temp/grpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
		"model_name_or_path": model_name_or_path,
		"dataset_name": dataset_name,
		"trust_remote_code": True,
		"dataset_train_split": "train[:500]",
		"dataset_test_split": "validation[:100]",
		"use_peft": True,
		"report_to": "none",
		"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
	}
	trainer_kwargs = {
		"reward_funcs": reward_funcs,
	}
	grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)

目前单元测试都能通过,关键库版本如下:

accelerate==1.6.0
datasets==3.5.0
peft==0.15.2
torch==2.5.1
transformers==4.51.3
trl==0.17.0

1.1 训练配置参数的设置(Config)

目前模型训练涉及的参数非常的广泛,以前搭积木时代训练模型时,我们一般会自定义一个Config类用于管理与模型训练数据集等相关的参数,但目前Transformers提供了非常好的工具HfArgumentParser用于管理这些繁杂的参数,根据官方给出的PPO示例(ppo.py)中的写法

parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_into_dataclasses()

将所有参数分为三类:ScriptArguments, PPOConfig, ModelConfig

使用HfArgumentParser解析参数的逻辑是(这个过程没有读源码,可能实际上略有出入,但是结果应该是一致的),首先使用默认值初始化ScriptArguments, PPOConfig, ModelConfig对应的三个对象,然后将执行脚本中传入的参数按名称对应分配给每个对象。

例如,在PPO的例子的执行脚本中:

python -i examples/scripts/ppo/ppo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --learning_rate 3e-6 \
    --output_dir models/minimal/ppo \
    --per_device_train_batch_size 64 \
    --gradient_accumulation_steps 1 \
    --total_episodes 10000 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --missing_eos_penalty 1.0

会将dataset_namedataset_train_split的值更新给script_argslearning_rate output_dir的值更新给model_args

不妨可以看看script_args, training_args, model_args分别包含了哪些参数:

  • script_args
ScriptArguments(
	dataset_name=None, 
	dataset_config=None, 
	dataset_train_split='train', 
	dataset_test_split='test', 
	gradient_checkpointing_use_reentrant=False, 
	ignore_bias_buffers=False)
  • training_args(即ppo_config
PPOConfig(
	output_dir='trainer_output',
	overwrite_output_dir=False,
	do_train=False,
	do_eval=False,
	do_predict=False,
	eval_strategy=<IntervalStrategy.NO: 'no'>,
	prediction_loss_only=False,
	per_device_train_batch_size=8,
	per_device_eval_batch_size=8,
	per_gpu_train_batch_size=None,
	per_gpu_eval_batch_size=None,
	gradient_accumulation_steps=1,
	eval_accumulation_steps=None,
	eval_delay=0,
	torch_empty_cache_steps=None,
	learning_rate=5e-05,
	weight_decay=0.0,
	adam_beta1=0.9,
	adam_beta2=0.999,
	adam_epsilon=1e-08,
	max_grad_norm=1.0,
	num_train_epochs=3.0,
	max_steps=-1,
	lr_scheduler_type=<SchedulerType.LINEAR: 'linear'>,
	lr_scheduler_kwargs={},
	warmup_ratio=0.0,
	warmup_steps=0,
	log_level='passive',
	log_level_replica='warning',
	log_on_each_node=True,
	logging_dir='trainer_output\\runs\\Jun14_11-14-00_LAPTOP-PJP6MGE1',
	logging_strategy=<IntervalStrategy.STEPS: 'steps'>,
	logging_first_step=False,
	logging_steps=500,
	logging_nan_inf_filter=True,
	save_strategy=<SaveStrategy.STEPS: 'steps'>,
	save_steps=500,
	save_total_limit=None,
	save_safetensors=True,
	save_on_each_node=False,
	save_only_model=False,
	restore_callback_states_from_checkpoint=False,
	no_cuda=False,
	use_cpu=False,
	use_mps_device=False,
	seed=42,
	data_seed=None,
	jit_mode_eval=False,
	use_ipex=False,
	bf16=False,
	fp16=False,
	fp16_opt_level='O1',
	half_precision_backend='auto',
	bf16_full_eval=False,
	fp16_full_eval=False,
	tf32=None,
	local_rank=0,
	ddp_backend=None,
	tpu_num_cores=None,
	tpu_metrics_debug=False,
	debug=[],
	dataloader_drop_last=False,
	eval_steps=None,
	dataloader_num_workers=0,
	dataloader_prefetch_factor=None,
	past_index=-1,
	run_name='trainer_output',
	disable_tqdm=False,
	remove_unused_columns=True,
	label_names=None,
	load_best_model_at_end=False,
	metric_for_best_model=None,
	greater_is_better=None,
	ignore_data_skip=False,
	fsdp=[],
	fsdp_min_num_params=0,
	fsdp_config={'min_num_params': 0,
	'xla': False,
	'xla_fsdp_v2': False,
	'xla_fsdp_grad_ckpt': False},
	fsdp_transformer_layer_cls_to_wrap=None,
	accelerator_config=AcceleratorConfig(split_batches=False,
	dispatch_batches=None,
	even_batches=True,
	use_seedable_sampler=True,
	non_blocking=False,
	gradient_accumulation_kwargs=None,
	use_configured_state=False),
	deepspeed=None,
	label_smoothing_factor=0.0,
	optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>,
	optim_args=None,
	adafactor=False,
	group_by_length=False,
	length_column_name='length',
	report_to=['wandb'],
	ddp_find_unused_parameters=None,
	ddp_bucket_cap_mb=None,
	ddp_broadcast_buffers=None,
	dataloader_pin_memory=True,
	dataloader_persistent_workers=False,
	skip_memory_metrics=True,
	use_legacy_prediction_loop=False,
	push_to_hub=False,
	resume_from_checkpoint=None,
	hub_model_id=None,
	hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>,
	hub_token=None,
	hub_private_repo=None,
	hub_always_push=False,
	gradient_checkpointing=False,
	gradient_checkpointing_kwargs=None,
	include_inputs_for_metrics=False,
	include_for_metrics=[],
	eval_do_concat_batches=True,
	fp16_backend='auto',
	push_to_hub_model_id=None,
	push_to_hub_organization=None,
	push_to_hub_token=None,
	mp_parameters='',
	auto_find_batch_size=False,
	full_determinism=False,
	torchdynamo=None,
	ray_scope='last',
	ddp_timeout=1800,
	torch_compile=False,
	torch_compile_backend=None,
	torch_compile_mode=None,
	include_tokens_per_second=False,
	include_num_input_tokens_seen=False,
	neftune_noise_alpha=None,
	optim_target_modules=None,
	batch_eval_metrics=False,
	eval_on_start=False,
	use_liger_kernel=False,
	eval_use_gather_object=False,
	average_tokens_across_devices=False,
	dataset_num_proc=None,
	num_mini_batches=1,
	total_episodes=None,
	local_rollout_forward_batch_size=64,
	num_sample_generations=10,
	response_length=53,
	stop_token=None,
	stop_token_id=None,
	temperature=0.7,
	missing_eos_penalty=None,
	sft_model_path='EleutherAI/pythia-160m',
	world_size=None,
	num_total_batches=None,
	micro_batch_size=None,
	local_batch_size=None,
	batch_size=None,
	local_mini_batch_size=None,
	mini_batch_size=None,
	exp_name='ppo_config',
	reward_model_path='EleutherAI/pythia-160m',
	model_adapter_name=None,
	ref_adapter_name=None,
	num_ppo_epochs=4,
	whiten_rewards=False,
	kl_coef=0.05,
	kl_estimator='k1',
	cliprange=0.2,
	vf_coef=0.1,
	cliprange_value=0.2,
	gamma=1.0,
	lam=0.95,
	ds3_gather_for_generation=True
)
  • model_args
ModelConfig(
	model_name_or_path=None, 
	model_revision='main', 
	torch_dtype=None, 
	trust_remote_code=False, 
	attn_implementation=None, 
	use_peft=False, 
	lora_r=16, 
	lora_alpha=32, 
	lora_dropout=0.05, 
	lora_target_modules=None, 
	lora_modules_to_save=None, 
	lora_task_type='CAUSAL_LM', 
	use_rslora=False, 
	use_dora=False, 
	load_in_8bit=False, 
	load_in_4bit=False, 
	bnb_4bit_quant_type='nf4', 
	use_bnb_nested_quant=False,
)

其中:

  • script_args:主要控制数据集的参数。特别地,dataset_name是required,但这个事情比较奇怪

    • 如果是在Linux上使用脚本启动的话是必须传入--dataset_name或者--dataset-name参数的,否则会显示 error: the following arguments are required: --dataset_name/--dataset-nametransformers==4.51.3, trl==0.17.0
    • 在Windows系统上即使不传入任何参数也是不会报错的(transformers==4.52.4, trl==0.18.1,比Linux上的版本要高一点,但应该不是版本问题)。
  • training_args:即PPOConfig的对象,对应PPOTrainer中的args参数,其中包含了非常多与训练参数,具体每个训练器及其配置的参数列表可以直接到HuggingFace上的TRL文档查看(https://huggingface.co/docs/trl),这里例举常用的参数:

    • output_dir:模型checkpoint导出的路径,
    • adam_xxx:Adam优化器相关的参数,似乎TRL里所有Trainer的默认优化器都是Adam,不过是可以在PPOTrainer的参数中定义优化器的(PPOTrainer的优化器参数optimizers是需要传入两个变量(optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR])),一个是优化器,另一个是学习率的规划器。
    • learning_rate:学习率,
    • kl_coef:KL散度的惩罚系数,越大相当于策略更新幅度越小,越小策略就可能更新幅度越快,
  • model_args对应的是PPOTrainerpeft_config参数,里面也有一些和模型初始化相关的参数,包括量化与PEFT相关的参数也是从中提取得到的,比如在base_pipeline脚本中:

...
# Trainer Pipeline
# @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
# @param data_processor: Function object prepared for `dataset.map(data_processor)`
# @param trainer_config: [Dict, peft.XXXConfig] including keyword arguments, e.g. 
# @param model_config: [Dict, peft.ModelConfig] including keyword arguments, e.g. 
# @param script_arguments: [Dict, peft.ScriptArguments] including keyword arguments, e.g. "dataset_name", "dataset_train_split", "dataset_test_split"
# @param config_kwargs: [Dict] keyword arguments for updating TRL-Config, `ScriptArguments`, `ModelConfig`
#   - keyword arguments for `TRLConfig`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
#   - keyword arguments for `ScriptArguments`: e.g. "output_dir", "adam_xxx", "learning_rate", "kl_coef", "push_to_hub"
#   - keyword arguments for `ModelConfig`: e.g. "model_name_or_path", "torch_dtype", "trust_remote_code", "use_peft", "lora_xxx", "load_in_4bit", "bnb_4bit_compute_dtype", "bnb_4bit_quant_type"
# @param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer
#   - keyword arguments for all Trainers: e.g. "data_collator", "callbacks"
#   - keyword arguments for `SFTTrainer`: e.g. "compute_loss_func", "compute_metrics"
#   - keyword arguments for `PPOTrainer`: e.g. "ref_model[required]", "reward_model[required]", "value_model[required]"
#   - keyword arguments for `DPOTrainer`: e.g. "ref_model"
#   - keyword arguments for `GRPOTrainer`: e.g. "reward_funcs[required]"
def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):
	# 1 Configuration
	TRLConfig, TRLTrainer = eval(f"{name}Config"), eval(f"{name}Trainer")
	parser = HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))
	script_arguments, trainer_config, model_config = parser.parse_args_into_dataclasses()
	script_arguments = update_trl_config(script_arguments, **config_kwargs)
	trainer_config = update_trl_config(trainer_config, **config_kwargs)
	model_config = update_trl_config(model_config, **config_kwargs)
	peft_config = get_peft_config(model_config)
	quantization_config = get_quantization_config(model_config)
	...

peft_configquantization_config都是可以现成的函数生成的,它们被定义在trl项目根目录下的trainer/utils.py中:

def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:
	if model_args.load_in_4bit:
		quantization_config = BitsAndBytesConfig(
			load_in_4bit=True,
			bnb_4bit_compute_dtype=model_args.torch_dtype,  # For consistency with model weights, we use the same value as `torch_dtype`
			bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
			bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
			bnb_4bit_quant_storage=model_args.torch_dtype,
		)
	elif model_args.load_in_8bit:
		quantization_config = BitsAndBytesConfig(
			load_in_8bit=True,
		)
	else:
		quantization_config = None
	return quantization_config

def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":
    if model_args.use_peft is False:
        return None

    if not is_peft_available():
        raise ValueError(
            "You need to have PEFT library installed in your environment, make sure to install `peft`. "
            "Make sure to run `pip install -U peft`."
        )
    peft_config = LoraConfig(
        task_type=model_args.lora_task_type,
        r=model_args.lora_r,
        target_modules=model_args.lora_target_modules,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        use_rslora=model_args.use_rslora,
        use_dora=model_args.use_dora,
        modules_to_save=model_args.lora_modules_to_save,
    )
    return peft_config

从上面的源码可以看出,如果model_config.use_peft=False,则默认不启用PEFT,得到的peft_config也就是None,同理是否采用量化取决于load_in_4bitload_in_8bit是否至少有一个是True


1.2 加载分词器与模型(tokenizer & model)

回到base_pipeline中对应的部分:

	...
	# 2 Load models and tokenizer
	logging.info("Load models and tokenizer ...")
	logging.info(f"  - Model: {model_config.model_name_or_path}")
	tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
	if not "pad_token" in tokenizer.special_tokens_map:
		tokenizer.add_special_tokens({"pad_token": "[PAD]"})
	if tokenizer.chat_template is None:
		tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
	model = AutoModelForCausalLM.from_pretrained(
		model_config.model_name_or_path,
		device_map = "auto",
		trust_remote_code = model_config.trust_remote_code,
		quantization_config = quantization_config,
	)
	if peft_config is not None:
		logging.info("Prepare model for PEFT ...")
		model.config.pretraining_tp = 1
		model.config.use_cache = False
		model.gradient_checkpointing_enable()
		# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)
		# Then you need set `model.enable_input_require_grads()` yourself
		# model = prepare_model_for_kbit_training(model)
		model.enable_input_require_grads()
		model = get_peft_model(model, peft_config)
	...

这里主要是几个零碎的注意点:

  1. 分词器的词汇表中一定要包含pad_token,具体可以通过tokenizer.special_tokens_map,否则在后面trainer.train()时会报错,提示你设置pad_token
  2. 关于tokenizer.chat_template的问题,这个到接下来的数据处理部分会详细说明,针对对话类的任务,一般是会采用chat_template来处理数据集的,因此可以用trl.trainer.utils中提供的SIMPLE_CHAT_TEMPLATE作为缺省
  3. 模型可以通过量化加载(即quantization_config不为None时),这样会节约内存,此时一般都会做如下的设置(如果不是量化加载的话,可能加上这些配置也没什么关系,所以我就这么写了)
    • 这里的pretraining_tp是张量并行的意思,数值越大,就会并行的数量越多,当然只有多卡才能并行,单卡只能设为1
    • use_cache应该是指(大概)做model.generate()生成时是否进行kv_cache缓存
    • model.enable_input_require_grads()也是节约内存的方法,我记得是缓存梯度以达到更快的反向传播运算的作用。
		...
		model.config.pretraining_tp = 1
		model.config.use_cache = False
		model.gradient_checkpointing_enable()
		# If `prepare_model_for_kbit_training` is ignored, and `gradient_checkpointing = True` (for GPU memory saving)
		# Then you need set `model.enable_input_require_grads()` yourself
		# model = prepare_model_for_kbit_training(model)
		model.enable_input_require_grads()
		...

然后对于PPOTrainer而言,非常特殊,它的构造参数还需要reward_model, value_model, ref_model三个参数,因此在base_pipeline中做了额外处理:

	if name == "PPO":
		logging.info("PPO load reward value and reference models ...")
		# PPO is special! It needs more components!
		logging.info(f"  - Reward model: {trainer_config.reward_model_path}")
		reward_model = AutoModelForSequenceClassification.from_pretrained(
			trainer_config.reward_model_path,
			trust_remote_code = model_config.trust_remote_code,
			num_labels = 1,
		)
		value_model = AutoModelForSequenceClassification.from_pretrained(
			trainer_config.reward_model_path,
			trust_remote_code = model_config.trust_remote_code,
			num_labels = 1,
		)
		logging.info("  - Copy reference model ...")
		ref_model = deepcopy(model)
		# ref_model = model.__class__(model.config)
		# ref_model.load_state_dict(model.state_dict())
		trainer_kwargs["reward_model"] = reward_model
		trainer_kwargs["value_model"] = value_model
		trainer_kwargs["ref_model"] = ref_model
		logging.info("  - Done!")

这里注意一下:

  • 一般初始化ref_model和目标模型是完全一样的,这里测试了两种方法,一种是直接ref_model = deepcopy(model),这确实是可行的,不会发生unpickled之类的问题,另一种方法则是重构一个和目标模型结构相同的空模型,然后加载状态字典:
      # ref_model = model.__class__(model.config)
      # ref_model.load_state_dict(model.state_dict())
    
    这个方法在use_peft=False时是可行的,但是因为这里可能是一个peft_model,因此model.__class__的参数还需要一个peft_config,直接这样写是会报错的
  • 使用HfArgumentParser加载PPOConfig参数时,ppo_config.reward_model_path是有默认值的(可以看上面的training_args),即'EleutherAI/pythia-160m',这个奖励模型必须是可以被AutoModelForSequenceClassification类型加载的
  • value_model通常设置为和reward_model相同,至少在官方的PPO示例中是这样的,其实也可以解释,我的理解是:
    A ( s , a ) = Q ( s , a ) − V ( s ) A(s,a)=Q(s,a)-V(s) A(s,a)=Q(s,a)V(s)
    中,这里 V ( s ) V(s) V(s)value_model Q ( s , a ) = r ( s , a ) + γ V ( s ′ ) Q(s,a)=r(s,a)+\gamma V(s') Q(s,a)=r(s,a)+γV(s)中的 r ( s , a ) r(s,a) r(s,a)即奖励函数,在CAUSAL_LM的语境下,其实就是句子 s s s加了一个单词 a a a后的奖励,本质上都是对一句话进行评分。虽然说得通,但是感觉也挺奇怪的。

1.3 数据集分割与字段问题(dataset)

目前主流的数据集加载都是直接用HuggingFace的datasets.load_dataset方法加载,标准的数据格式就是每个样本的格式为{“column_1”: , "column_2": },整个数据集类似jsonl的格式。

针对目前已知的几个训练器,它们对数据集的格式要求大概是这样的:

  • SFTTrainer:字段要包含promptcompletion,注意,训练时只会在completion部分计算损失,prompt部分是不会计算损失的

    • 但是其实我发现如果数据集中只有text或者input+target这种字段时也是可行的,需要仔细读源码才行。
  • DPOTrainer:这个就很简单,一般就是prompt+chosen+rejected,但是在官方示例使用的数据集trl-lib/ultrafeedback_binarized中,字段却是chosen, rejected, score_chosen, score_rejected,也就是说其实prompt可能并非必须,本来也确实可以直接用空字符串替代,然后数据集中如果没有回答进行评分的话,可能还有一个默认的评分机制在训练器里面。

  • GRPOTrainer:这个我也是根据官方示例来的,我看到它使用的数据集是trl-lib/tldr,这是一个典型的只有promptcompletion两个字段的数据集

  • PPOTrainer这个最为特殊,之前DeepSeek一直写不对的原因也在于此

    • 在官方给出的PPO运行脚本(ppo.py)中

          def prepare_dataset(dataset, tokenizer):
          """pre-tokenize the dataset before training; only collate during training"""
          def tokenize(element):
              outputs = tokenizer(
                  element[dataset_text_field],
                  padding=False,
              )
              return {"input_ids": outputs["input_ids"]}
          return dataset.map(
              tokenize,
              batched=True,
              remove_columns=dataset.column_names,
              num_proc=training_args.dataset_num_proc,
          )
      

      注意到这里是直接处理成分词后的input_ids格式

    • 但是,我本来以为PPO和GRPO是类似的,因此理论上在数据集的格式要求上应该也差不了太多,因此我觉得可能prompt+completion也是可行的,但是测试下来会报错,报错提示是一定要求是带input_ids字段的。

    • 因此我单独给PPO写了一个数据处理的data_processor,而其余的都是很简单的:

      # @param name: [Str] e.g. "SFT", "PPO", "DPO", "GRPO"
      def generate_simple_data_processor(name, **kwargs):
      	if name in ["SFT", "GRPO"]:
      		def _data_processor(_data):
      			return {"prompt": _data["prompt"], "completion": _data["completion"]}
      	elif name == "PPO":
      		tokenizer = kwargs.get("tokenizer")
      		def _data_processor(_data):
      			outputs = tokenizer(_data["prompt"] + _data["completion"], padding = False)
      			return {"input_ids": outputs["input_ids"]}
      	elif name == "DPO":
      		def _data_processor(_data):
      			return {"prompt": _data["prompt"], "chosen": _data["chosen"], "rejected": _data["rejected"]}
      	else:
      		raise NotImplementedError(name)
      	return _data_processor
      

当然一些写法也会使用apply_chat_template,这通常是用于交互式对话任务,都是可行的:

		def _data_processor(_data):
			_message = [
				{"role": "system", "content": "You are an AI assistant developped by CY"},
				{"role": "user", "content": _data[dataset_input_column]},
				{"role": "assistant", "content": _data[dataset_target_column]},
			]
			_prompt = tokenizer.apply_chat_template(_message, tokenize=False)
			return {"text": _prompt}

然后一个小细节是在用dataset.map进行数据处理时一般会设置参数emove_columns=dataset.column_names,以丢弃原先不必要的字段,防止Trainer在运行时错用了其他的字段进行训练。

1.4 训练与保存(checkpoint)

base_pipeline中最后一部分训练完然后保存模型,顺利地话就完整地跑通:

	# 4 Train model
	logging.info("Trainer starts ...")
	trainer = TRLTrainer(
		model = model,
		args = trainer_config,
		train_dataset = train_dataset,
		eval_dataset = eval_dataset,
		processing_class = tokenizer,
		peft_config = peft_config,
		**trainer_kwargs
	)
	trainer.train()
	logging.info("  - Trainer finishes!")
	# 5 Save model
	if trainer_config.push_to_hub:
		logging.info(f"  - Push checkpoints to {trainer_config.organization}/{trainer_config.push_to_hub_model_id}")
		trainer.push_to_hub()
	logging.info(f"Save model to {trainer_config.output_dir}")
	trainer.save_model(trainer_config.output_dir)

这里值得注意的几个问题:

  1. trainer_config(即SFTConfig, PPOConfig, DPOConfig, GRPOConfig之类的对象)中,有一个参数trainer_config.report_to,这个不设置的话默认是会上传到W&B的,网络不支持访问的话是会在训练到checkpoint的时候发生网络错误的,因此一般会设置成"none",或者也可以用"tensorboard"

    • 不过就训练绘图的话,训练结束到trainer_config.output_dir中找到checkpoint-xxx文件夹,里面会有trainer_state.json文件,然后自己根据里面的数据绘图即可,例如:
      # Plot dynamics of TRL trainer state
      def plot_trl_dynamics(trainer_state_path):
      	with open(trainer_state_path, 'r', encoding="utf8") as f:
      		data = json.load(f)
      	log_history = data["log_history"]
      	steps = [entry["step"] for entry in log_history]
      	episodes = [entry["episode"] for entry in log_history]
      	epochs = [entry["epoch"] for entry in log_history]
      	policy_loss = [entry["loss/policy_avg"] for entry in log_history]
      	value_loss = [entry["loss/value_avg"] for entry in log_history]
      	lrs = [entry["lr"] for entry in log_history]
      	entropys = [entry["objective/entropy"] for entry in log_history]
      	kls = [entry["objective/kl"] for entry in log_history]
      	non_score_rewards = [entry["objective/non_score_reward"] for entry in log_history]
      	rlhf_rewards = [entry["objective/rlhf_reward"] for entry in log_history]
      	scores = [entry["objective/scores"] for entry in log_history]
      	plt.figure(figsize=(8, 8))
      	ax_1 = plt.subplot(2, 2, 1)
      	ax_2 = plt.subplot(4, 2, 2)
      	ax_3 = plt.subplot(4, 2, 4)
      	ax_4 = plt.subplot(2, 2, 3)
      	ax_5 = plt.subplot(2, 2, 4)
      	ax_1.plot(steps, policy_loss, label="Policy Loss")
      	ax_1.plot(steps, value_loss, label="Value Loss", linestyle="--")
      	ax_1.set_xlabel("Step"), ax_1.set_ylabel("Loss"), ax_1.legend()
      	ax_1.set_title("Policy and Value Loss")
      	# ------------------------------------------------------------------
      	ax_2.plot(steps, kls, label="objective/kl")
      	ax_2.set_xlabel("Step"), ax_2.set_ylabel("KL"), ax_2.legend()
      	ax_2.set_title("KL Curve")
      	# ------------------------------------------------------------------
      	ax_3.plot(steps, entropys, label="objective/entropy")
      	ax_3.set_xlabel("Step"), ax_3.set_ylabel("Entropy"), ax_3.legend()
      	ax_3.set_title("Entropy Curve")
      	# ------------------------------------------------------------------
      	ax_4.plot(steps, lrs, label="Learning Rate")
      	ax_4.set_xlabel("Step"), ax_4.set_ylabel("Learning Rate"), ax_4.legend()
      	ax_4.set_title("Learning Rate Curve")
      	# ------------------------------------------------------------------
      	ax_5.plot(steps, non_score_rewards, label="objective/non_score_reward", linestyle="--")
      	ax_5.plot(steps, rlhf_rewards, label="objective/rlhf_reward", linestyle="--")
      	ax_5.plot(steps, scores, label="objective/scores")
      	ax_5.set_xlabel("Step"), ax_5.set_ylabel("Score/Reward"), ax_5.legend()
      	ax_5.set_title("Reward and Score")
      	plt.show()
      	print(steps)
      	print('-' * 64)
      	print(episodes)
      	print('-' * 64)
      	print(epochs)
      
  2. 关于W&B的使用:如果网络支持,使用wandb确实也是不错的选择:

    登录 wandb

    方法 1:交互式登录(推荐)

    运行以下代码,会弹出浏览器或要求输入 API Key:

    import wandb
    wandb.login()
    

    然后按提示操作即可。

    方法 2:直接提供 API Key

    import wandb
    wandb.login(key="your_api_key_here")  # 替换成你的 wandb API Key
    

    ⚠️ 注意:不要直接在代码里硬编码 API Key(特别是公开代码时),建议使用环境变量或配置文件。

    方法 3:通过环境变量设置(适合脚本/服务器)

    export WANDB_API_KEY=your_api_key_here
    

    然后在 Python 中直接 wandb.init() 即可。


    步骤 3:在 Trainer 中启用 wandb

    方式 1:自动记录(默认)

    HuggingFace Trainer 默认会使用 wandb(如果已安装并登录):

    from transformers import Trainer, TrainingArguments
    training_args = TrainingArguments(
        output_dir="./results",
        report_to="wandb",  # 明确指定使用 wandb(默认已经是这样)
        logging_dir="./logs",
        # 其他训练参数...
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()  # 自动记录到 wandb
    

    方式 2:手动初始化 wandb(更灵活控制)

    import wandb
    # 初始化 wandb 并设置项目名称
    wandb.init(project="my-ppo-project", name="experiment-1")
    # 然后正常训练
    trainer.train()
    # 训练结束后(可选)
    wandb.finish()
    

    这样你可以自定义 wandb 的配置(如 projectnameconfig 等)。


额外配置(可选)

1. 记录超参数

config = {
    "learning_rate": 2e-5,
    "batch_size": 32,
    "epochs": 3,
}
wandb.init(project="my-project", config=config)

训练时,这些参数会显示在 wandb 面板中。

2. 自定义记录指标

# 在训练循环中记录自定义指标
wandb.log({"reward": current_reward, "loss": current_loss})

3. 恢复中断的训练

wandb.init(id="之前运行的ID", resume="must")
  1. 训练开始后,终端会显示类似:
    Tracking run with wandb, view at https://wandb.ai/your-username/your-project/runs/xxx
    
  2. 打开链接即可查看实时训练数据。

1.5 单元测试

单元测试脚本trainer_pipelines.py:

# -*- coding: utf8 -*-
# @author: caoyang
# @email: [email protected]

import os
import logging
from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipeline

model_home = "/nfsshare/home/caoyang/resource/model"
dataset_home = "/nfsshare/home/caoyang/resource/dataset"
model_names = [
	"Qwen/Qwen2.5-0.5B-Instruct",
	"EleutherAI/pythia-1b-deduped",
	"EleutherAI/pythia-160m",
]

dataset_names = [
	"trl-lib/tldr",	# train["prompt", "completion"] + validation["prompt", "completion"] + test["prompt", "completion"]
	"trl-lib/ultrafeedback_binarized",	# train["chosen", "rejected", "score_chosen", "score_rejected"] + test["chosen", "rejected", "score_chosen", "score_rejected"]
	"trl-internal-testing/descriptiveness-sentiment-trl-style", # sentiment["prompt", "chosen", "rejected"] + descriptiveness["prompt", "chosen", "rejected"]
	"YeungNLP/firefly-train-1.1M", # train["input", "target"]
]

def sft_pipeline_test():
	logging.info("SFT unittest ...")
	model_name_or_path = os.path.join(model_home, model_names[0])
	dataset_name = os.path.join(dataset_home, dataset_names[0])
	data_processor = None
	config_kwargs = {
		"output_dir": f"./temp/sft+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
		"model_name_or_path": model_name_or_path,
		"dataset_name": dataset_name,
		"trust_remote_code": True,
		"dataset_train_split": "train[:500]",
		"dataset_test_split": "validation[500:600]",
		"use_peft": True,
		"report_to": "none",
		"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
	}
	trainer_kwargs = {
	}
	sft_pipeline(data_processor, config_kwargs, trainer_kwargs)

def ppo_pipeline_test():
	logging.info("PPO unittest ...")
	model_name_or_path = os.path.join(model_home, model_names[1])
	"""
	EleutherAI/pythia-1b-deduped
	GPTNeoXForCausalLM(
	  (gpt_neox): GPTNeoXModel(
		(embed_in): Embedding(50304, 2048)
		(emb_dropout): Dropout(p=0.0, inplace=False)
		(layers): ModuleList(
		  (0-15): 16 x GPTNeoXLayer(
			(input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
			(post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
			(post_attention_dropout): Dropout(p=0.0, inplace=False)
			(post_mlp_dropout): Dropout(p=0.0, inplace=False)
			(attention): GPTNeoXAttention(
			  (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
			  (dense): Linear(in_features=2048, out_features=2048, bias=True)
			)
			(mlp): GPTNeoXMLP(
			  (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
			  (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
			  (act): GELUActivation()
			)
		  )
		)
		(final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
		(rotary_emb): GPTNeoXRotaryEmbedding()
	  )
	  (embed_out): Linear(in_features=2048, out_features=50304, bias=False)
	)
	"""
	dataset_name = os.path.join(dataset_home, dataset_names[0])
	reward_model_path = os.path.join(model_home, model_names[2])
	data_processor = None
	config_kwargs = {
		"output_dir": f"./temp/ppo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
		"model_name_or_path": model_name_or_path,
		"dataset_name": dataset_name,
		"reward_model_path": reward_model_path,
		"trust_remote_code": True,
		"dataset_train_split": "train[:500]",
		"dataset_test_split": "validation[:100]",
		"use_peft": True,
		"report_to": "none",
		"lora_target_modules": ["query_key_value"],
	}
	trainer_kwargs = {
	}
	ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)

def dpo_pipeline_test():
	logging.info("DPO unittest ...")
	model_name_or_path = os.path.join(model_home, model_names[0])
	dataset_name = os.path.join(dataset_home, dataset_names[2])
	data_processor = None
	config_kwargs = {
		"output_dir": f"./temp/dpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
		"model_name_or_path": model_name_or_path,
		"dataset_name": dataset_name,
		"trust_remote_code": True,
		"dataset_train_split": "descriptiveness[:500]",
		"dataset_test_split": "descriptiveness[500:600]",
		"use_peft": True,
		"report_to": "none",
		"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
	}
	trainer_kwargs = {
	}
	dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)

def grpo_pipeline_test():
	logging.info("GRPO unittest ...")
	model_name_or_path = os.path.join(model_home, model_names[0])
	dataset_name = os.path.join(dataset_home, dataset_names[0])
	data_processor = None
	def reward_funcs(completions, **kwargs):
		return [float(len(set(completion))) for completion in completions]
	config_kwargs = {
		"output_dir": f"./temp/grpo+{model_name_or_path.split('/')[-1]}+{dataset_name.split('/')[-1]}",
		"model_name_or_path": model_name_or_path,
		"dataset_name": dataset_name,
		"trust_remote_code": True,
		"dataset_train_split": "train[:500]",
		"dataset_test_split": "validation[:100]",
		"use_peft": True,
		"report_to": "none",
		"lora_target_modules": ["q_proj", "k_proj", "v_proj"]
	}
	trainer_kwargs = {
		"reward_funcs": reward_funcs,
	}
	grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)

这里主要就是grpo_pipeline_test中的reward_funcs,其他几个测试函数都没什么好说的,很好理解

GRPOTrainer中有required参数reward_funcs,这个在TRL文档中给出的例子是:

	def reward_funcs(completions, **kwargs):
		return [float(len(set(completion))) for completion in completions]

这个奖励的意思是会给长回答更高的得分。

注意这里的**kwargs是不可以省略的,因为GRPOTrainer调用reward_funcs会传入其他参数,比如prompts或者ground_truths,在TRL的GRPO文档中有很多例子。


2 一些其他的问题

这部分摘取一些和DeepSeek讨论的问题

2.1 DataProcessor与DataCollator

关于TRL的Trainer的一个问题,因为Trainer是可以传入data_collator参数的,但是一般来说我们都会在向Trainer传入数据集(如train_dataset)之前,先写好一个函数用于预处理数据集,比如train_dataset = train.dataset.map(prompt_formatter)这样,我的理解其实这里prompt_formatter何data_collator的作用是一样的(我觉得它们处理数据的逻辑也是一样,就是传入一个样本,然后做数据处理),因此如果在Trainer外面已经做了train_dataset.map(prompt_formatter)的操作,就不需要再传入data_collator了。反之也是一样,我可以传入data_collator,而无需进行train_dataset.map(prompt_formatter)的操作

以下回答由AI生成:


你的理解基本正确,但有一些细节需要注意。prompt_formatter(在dataset.map中使用)和data_collator确实都涉及数据处理,但它们的作用时机和目的有所不同。下面我来详细解释它们的区别以及何时使用它们:

2.1.1 dataset.map(prompt_formatter):

  • 作用时机: 在数据加载时,预处理阶段(即在训练开始之前)对数据集中的每个样本进行处理。
  • 典型用途:
    • 将原始文本转换为模型需要的输入格式(例如添加特殊token、构建prompt-completion结构)。
    • 对数据进行过滤(通过remove_columns删除不需要的字段)。
    • 执行一次性的、确定性的转换(例如tokenization,如果你不启用动态padding)。
  • 特点:
    • 处理结果是静态的(处理后的数据会被缓存,除非设置load_from_cache_file=False)。
    • 适合所有样本共享的逻辑(例如统一的prompt模板)。

2.1.2 data_collator:

  • 作用时机: 在训练过程中,动态批处理阶段(即在每个batch生成时)对样本进行实时处理。
  • 典型用途:
    • 动态padding(将同一batch内的样本padding到相同长度)。
    • 随机数据增强(例如随机masking,适用于MLM任务)。
    • 构造特殊输入(例如为RLHF构建奖励模型的输入格式)。
  • 特点:
    • 处理是动态的(每次batch采样时重新执行)。
    • 适合需要随机性依赖batch内其他样本的操作(例如padding长度取决于当前batch)。

2.1.3 何时选择哪种方式?

  • 只用dataset.map:

    • 如果你的数据处理是确定性的、不需要动态调整(例如固定的prompt模板),且不需要batch级别的操作(如padding),可以只使用dataset.map
    • 但注意:如果涉及变长序列,仍需data_collator处理padding(除非你已手动填充到固定长度)。
  • 只用data_collator:

    • 如果数据处理需要动态性(例如随机masking)或依赖batch上下文(如padding),可以只用data_collator
    • 但注意:这样可能导致训练时重复处理数据(例如每次epoch重复tokenization),影响效率。
  • 联合使用:

    • 常见做法是:
      1. dataset.map处理静态逻辑(例如tokenization、添加固定prompt)。
      2. data_collator处理动态逻辑(例如padding、随机增强)。
    • 例如:
      def prompt_formatter(example):
          example["input_ids"] = tokenizer(example["text"]).input_ids
          return example  # 不padding,保留原始长度
      
      train_dataset = raw_dataset.map(prompt_formatter, batched=False)
      
      data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True)  # 动态padding和masking
      
  • 如果prompt_formatter已经完成了所有必要处理(包括padding到固定长度),且不需要任何动态逻辑,确实可以跳过data_collator(或传入default_data_collator,它仅做简单的batch堆叠)。

  • 但如果需要动态操作(如padding),即使有dataset.map,仍需data_collator

  • 二者并非互斥,而是互补的。dataset.map用于离线预处理,data_collator用于在线批处理。

  • 选择取决于你的数据处理需求是否依赖动态上下文或随机性。


2.2 关于PartialState

accelerate库里有一个方法:PartialState().local_main_process_first(),我看到这个方法会用来加载数据集,这个看起来像是并行加载数据集的意思,你可以解释一下这个函数是用来做什么的吗?

注:这个方法本来是在PPO官方示例中使用的


PartialState().local_main_process_first() 是 Hugging Face accelerate 库中用于多进程并行处理的一个关键方法,尤其在分布式训练或推理时控制数据加载、文件处理等操作的执行顺序。它的核心作用是确保某些操作(如数据集加载或文件预处理)只在**主进程(main process)**上优先执行,避免多进程竞争或重复操作。


2.2.1 核心功能

  • local_main_process_first() 返回一个上下文管理器(context manager),其行为如下:
    • 主进程(local main process):会立即执行上下文内的代码块。
    • 其他进程(非主进程):会等待主进程完成后,再跳过或执行相同操作(取决于具体逻辑)。
  • 目的:避免多进程同时操作共享资源(如文件读取、数据集缓存)导致的冲突或冗余计算。

2.2.2 典型使用场景

(1) 加载数据集时避免重复缓存

在分布式训练中,如果多个进程同时尝试缓存数据集(例如通过 datasets 库),可能会导致文件写入冲突。通过 local_main_process_first() 可以确保只有主进程生成缓存,其他进程直接复用:

from accelerate import PartialState

state = PartialState()  # 初始化分布式状态

# 只有主进程先缓存数据集,其他进程等待后直接加载缓存
with state.local_main_process_first():
    dataset = load_dataset("imdb")  # 假设这是耗时操作

(2) 文件预处理或下载

如果预处理(如分词、数据转换)只需要执行一次:

with state.local_main_process_first():
    tokenized_data = dataset.map(tokenize_function, batched=True)  # 主进程先处理
# 其他进程会自动等待主进程完成,然后直接使用处理后的数据

2.2.3 参数与底层机制

  • PartialState:是 accelerate 中管理分布式状态的类,自动检测当前进程的分布式环境(如多GPU、TPU等)。
  • local_main_process:指当前节点(机器)上的主进程(通常是全局主进程的一个子集,在多机训练时每台机器有自己的 local_main_process)。
  • 同步机制:底层通过 torch.distributed 或类似库实现进程间同步,确保非主进程等待主进程完成。

2.2.4 与类似方法的区别

方法/属性 作用范围 用途
local_main_process_first() 当前节点的主进程优先 避免文件/缓存操作的竞争
main_process_first() 全局主进程优先(跨所有节点) 适用于全局只需执行一次的操作
on_main_process() 仅在主进程执行 完全跳过其他进程(如日志记录)

2.2.5 完整示例(分布式训练中的数据加载)

from accelerate import PartialState
from datasets import load_dataset

# 初始化分布式状态(自动检测环境)
state = PartialState()

# 只有主进程优先加载并缓存数据,其他进程等待
with state.local_main_process_first():
    dataset = load_dataset("glue", "mrpc", split="train")
    tokenized_data = dataset.map(lambda x: tokenizer(x["sentence"]), batched=True)

# 所有进程继续同步训练
for batch in tokenized_data:
    outputs = model(**batch)
    ...

2.2.6 注意事项

  1. 缓存路径一致性:确保所有进程能访问相同的缓存路径(如共享存储)。
  2. 性能影响:主进程的预处理时间会被计入总耗时,但避免了重复操作的开销。
  3. 错误处理:如果主进程预处理失败,其他进程会一直等待,需添加超时或错误处理逻辑。

2.2.7 总结

  • 何时使用:在分布式环境中,任何需要单例执行的操作(如数据缓存、文件下载、预处理)都应使用 local_main_process_first()
  • 优势:减少资源竞争,避免冗余计算,确保分布式任务的高效协同。

如果需要进一步了解 accelerate 的分布式策略(如多机训练),可以继续探讨!


2.3 DPO和PPO的reference_model的区别

这个是阅读TRL文档发现的一个问题,DPOTrainerref_model参数是有默认值None的,但是PPOTrainerref_model参数没有默认值。这意味着必须给PPOTrainer传入ref_model参数,尽管一般都是跟目标模型一样。


你提出了一个很好的问题!DPO(Direct Preference Optimization)虽然确实不属于传统强化学习的范畴,但它仍然保留了reference model,这是出于特定的设计考虑。下面我详细解释原因:

2.3.1 DPO的核心思想与Reference Model的作用

DPO通过直接优化偏好数据来规避强化学习,但其目标函数(源自Bradley-Terry模型)隐式地包含了一个KL散度约束:
L DPO = − E [ log ⁡ σ ( β log ⁡ π θ ( y w ∣ x ) π ref ( y w ∣ x ) − β log ⁡ π θ ( y l ∣ x ) π ref ( y l ∣ x ) ) ] \mathcal{L}_{\text{DPO}} = -\mathbb{E} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right] LDPO=E[logσ(βlogπref(ywx)πθ(ywx)βlogπref(ylx)πθ(ylx))]
其中:

  • π θ \pi_\theta πθ 是待训练的策略(当前模型)。
  • π ref \pi_{\text{ref}} πrefreference model(通常初始化为预训练模型,例如SFT模型)。
  • β \beta β 是控制偏离参考模型程度的超参数。

Reference Model的作用

  • 防止过度偏离初始策略:KL散度约束(通过reference model体现)避免模型为迎合偏好数据而退化(例如生成极端或无意义的输出)。
  • 隐式正则化:类似于PPO中的KL惩罚,但DPO通过直接在损失函数中对比当前模型和reference模型的输出概率实现。

2.3.2 为什么DPO需要Reference Model,尽管它不是RL?

  • 数学推导的必然性:DPO的公式是从逆向强化学习的目标推导而来的,其中reference model代表初始策略(即RL中的 π old \pi_{\text{old}} πold)。去掉reference model会导致目标函数失去约束,模型可能过拟合偏好数据。
  • 实践中的稳定性:实验表明,去掉reference model后,模型容易过度优化偏好信号(例如始终输出短答案或重复模式)。

2.3.3 TRL的DPOTrainer中的ref_model

在HuggingFace TRL库中,DPOTrainer确实需要传入ref_model参数,原因如下:

  1. 计算概率比:DPO需要比较当前模型和reference模型对正/负样本的概率(即(\log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}))。
  2. 冻结参考模型ref_model在训练过程中是冻结的(不更新),仅作为基准。
  3. 内存优化:实际实现中,ref_model可能与当前模型共享部分结构(例如共享底层Transformer),但通过梯度截断避免更新。

2.3.4 DPO vs PPO的Reference Model

特性 PPO (RL) DPO (Non-RL)
Reference Model 旧策略 π old \pi_{\text{old}} πold 预训练模型 π ref \pi_{\text{ref}} πref
更新方式 定期同步(on-policy要求) 完全冻结
作用 重要性采样校正 + KL约束 隐式KL约束 + 防止过拟合

2.3.5 如果没有Reference Model会怎样?

假设令 π ref \pi_{\text{ref}} πref为均匀分布,则DPO退化为直接最大化偏好对的似然,可能导致:

  • 模型过度自信(overconfident)。
  • 生成内容多样性下降(模式坍塌)。
  • 对噪声偏好数据过拟合。

DPO虽然跳过了显式的强化学习框架,但通过保留reference model,继承了RL中“约束策略更新”的核心思想。这是其稳定性和性能的关键设计,也是它区别于纯监督学习(如直接微调)的重要特征。

PS:我感觉这个回答没有解释我的问题,答非所问了

你可能感兴趣的:(【慢摹】TRL训练器细节详解(SFT/PPO/DPO/GRPO))