TRL Framework: SFT/GRPO Training and Chat-Template Prediction

1. SFT Training

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import SFTConfig, SFTTrainer, ModelConfig, TrlParser, get_peft_config


def format_user_question(example):
    candidates = "\n".join(example["candidates"])
    content = f"""
    用户的问题为:{example["query"]},请从候选列表Candidates中选择一个和用户问题意思一样的问题,\n
    Candidates: {candidates} \n
    和用户意思一致的问题为:
    """
    return content

def main(training_args, model_args):
    # Load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    tokenizer.pad_token = tokenizer.eos_token

    cur_dir = os.path.dirname(__file__)
    data_files = {
        "train": os.path.join(cur_dir, "data/train.jsonl"),
        "validation": os.path.join(cur_dir, "data/test.jsonl"),
        "test": os.path.join(cur_dir, "data/test.jsonl"),
    }
    dataset = load_dataset("json", data_files=data_files)

    # Conversational "prompt" format (defined here for reference; this SFT script uses
    # the plain-text format produced by make_text below instead)
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": "you are a helpful assistant."},
                {"role": "user", "content": format_user_question(example)},
            ],
        }

    # SFTTrainer expects a single "text" field that already contains prompt and answer;
    # the trailing EOS marker appears to have been stripped from the original post and is re-added here
    def make_text(example):
        return {"text": "[INST]" + format_user_question(example) + "[/INST] " + example["answer"] + tokenizer.eos_token}

    dataset = dataset.map(make_text, remove_columns=["candidates", "query", "answer"])

    training_args = SFTConfig(
        output_dir="output_sft",
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type='cosine',
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        num_train_epochs=2,
        eval_steps=100,
        do_eval=True,
        eval_strategy="steps",
        save_steps=100,
        save_total_limit=1,
        max_grad_norm=0.1,
        log_on_each_node=False,
    )
    model_args.use_peft = True  # force LoRA/PEFT fine-tuning
    # Initialize the SFT trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Train the model
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()


def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (SFTConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    training_args, model_args = parser.parse_args_and_config()
    main(training_args, model_args)

Note: the dataset must define a `text` field in the format TRL's SFTTrainer expects, with the prompt and the answer concatenated into a single string.
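For illustration, a sketch of what one record in data/train.jsonl might look like and the field that make_text produces from it. The field names query/candidates/answer come from the script above; the concrete values are invented:

# Hypothetical record in data/train.jsonl (values invented for illustration)
{"query": "怎么修改登录密码?",
 "candidates": ["如何重置密码?", "如何注销账号?", "怎样修改绑定手机号?"],
 "answer": "如何重置密码?"}

# After dataset.map(make_text, ...), only a single field remains, roughly:
# {"text": "[INST]<prompt built by format_user_question>[/INST] 如何重置密码?<eos>"}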

2. GRPO Training

Note how the two custom reward functions declare their parameters: because accuracy_reward takes an `answer` argument, TRL forwards the dataset's `answer` column into it. A minimal wiring sketch follows the reward functions below.


import re


def accuracy_reward(completions, answer, **kwargs):
    # Reward 1.0 when the answer extracted from the completion exactly matches the reference.
    # The tag names were stripped from the original post; <answer>...</answer>, as used in
    # TRL's GRPO examples, is assumed here.
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, answer):
        def _parse(a, b):
            a = a.split("<answer>")[-1].split("</answer>")[0]
            if a == b:
                return 1.0
            else:
                return 0.0
        rewards.append(_parse(content, sol))
    return rewards


def format_reward(completions, **kwargs):
    # Reward 1.0 when the completion follows the <think>...</think><answer>...</answer> layout
    # (the tags are assumed; they were stripped from the original post).
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}
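The article only shows the reward functions themselves. Below is a minimal sketch of how they could be wired into GRPOTrainer: TRL passes any extra dataset columns (here `answer`) as keyword arguments to reward functions that declare them, which is exactly how accuracy_reward receives the reference answers. The model name, dataset mapping, and hyperparameters are illustrative assumptions, and format_user_question plus the two reward functions above are assumed to be in scope.

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Keep the "answer" column: TRL forwards extra columns as kwargs to the reward functions
dataset = load_dataset("json", data_files={"train": "data/train.jsonl"})
dataset = dataset.map(
    lambda ex: {
        "prompt": [
            {"role": "system", "content": "you are a helpful assistant."},
            {"role": "user", "content": format_user_question(ex)},
        ]
    },
    remove_columns=["candidates", "query"],  # "answer" stays for accuracy_reward
)

# Effective batch size must be divisible by num_generations
training_args = GRPOConfig(
    output_dir="output_grpo",
    num_generations=4,
    per_device_train_batch_size=4,
)

trainer = GRPOTrainer(
    model="qwen/Qwen2.5-7B-Instruct",        # illustrative checkpoint
    reward_funcs=[format_reward, accuracy_reward],
    args=training_args,
    train_dataset=dataset["train"],
)
trainer.train()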

3. Prediction

from transformers import AutoModelForCausalLM, AutoTokenizer

# Model name
model_name = "qwen/Qwen2.5-7B-Instruct"

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Input text
prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]

# Apply the chat template
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# Tokenize the model input
model_inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Generate
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)

# Decode the generated result
response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(response)
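Note that generated_ids[0] still contains the prompt tokens, so the decoded response above includes the echoed prompt. To keep only the newly generated text, slice off the prompt first (the batch version below does the same); a small sketch:

# Decode only the tokens produced after the prompt
new_tokens = generated_ids[0][model_inputs.input_ids.shape[1]:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(response)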

Batch prediction code:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Model name
model_name = "qwen/Qwen2.5-7B-Instruct"

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')  # left padding for batched generation

# Batch of input prompts
batch_prompt = [
    "Give me a short introduction to large language model.",
    "请介绍下您自己?"
]
batch_messages = []
for prompt in batch_prompt:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    batch_messages.append(messages)

# Apply the chat template
text = tokenizer.apply_chat_template(
    batch_messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Tokenize the model inputs (with padding)
model_inputs = tokenizer(text, padding=True, return_tensors="pt").to(model.device)

# Generate
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,
)

# Strip the prompt tokens and decode only the generated part
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Print the results
for resp in response:
    print(resp)

Note: with apply_chat_template you no longer need to construct the prompt template by hand.
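The prediction examples above load the base Qwen2.5 checkpoint. To run inference with the LoRA weights produced by the SFT script in section 1, the adapter saved under output_sft can be attached with peft; a minimal sketch, assuming peft is installed and output_sft contains the adapter (neither is shown in the original article):

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_name = "qwen/Qwen2.5-7B-Instruct"       # illustrative base checkpoint
adapter_dir = "output_sft"                   # output_dir used by the SFT script above

model = AutoModelForCausalLM.from_pretrained(base_name, torch_dtype="auto", device_map="auto")
model = PeftModel.from_pretrained(model, adapter_dir)  # attach the LoRA adapter
model = model.merge_and_unload()                       # optional: fold the adapter into the base weights
tokenizer = AutoTokenizer.from_pretrained(base_name)
# ...then build messages, apply the chat template, and generate as shown above.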
