# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import ModelConfig, SFTConfig, SFTTrainer, TrlParser, get_peft_config
def format_user_question(example):
    """Build the prompt: ask the model to pick, from `candidates`, the question that means the same as `query`."""
    candidates = "\n".join(example["candidates"])
    content = (
        f"用户的问题为:{example['query']},请从候选列表Candidates中选择一个和用户问题意思一样的问题,\n"
        f"Candidates: {candidates}\n"
        "和用户意思一致的问题为:"
    )
    return content
def main(training_args, model_args):
    # Load a pretrained model and its tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    # The tokenizer has no dedicated pad token; reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token

    # Load the JSONL splits that sit next to this script
    cur_dir = os.path.dirname(__file__)
    data_files = {
        "train": os.path.join(cur_dir, "data/train.jsonl"),
        "validation": os.path.join(cur_dir, "data/test.jsonl"),
        "test": os.path.join(cur_dir, "data/test.jsonl"),
    }
    dataset = load_dataset("json", data_files=data_files)
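    # For reference, each JSONL line is expected to carry the three fields used
    # below: "query" (str), "candidates" (list of str), "answer" (str).
    # A hypothetical example line:
    # {"query": "怎么重置密码?", "candidates": ["如何修改密码?", "如何注销账号?"], "answer": "如何修改密码?"}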
    # Format into a conversation (the "prompt" form expected by GRPO-style
    # trainers; shown for reference, not used in this SFT run)
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": "you are a helpful assistant."},
                {"role": "user", "content": format_user_question(example)},
            ],
        }

    # SFT expects a single "text" field containing both the prompt and the answer;
    # append EOS so the model learns where the completion ends
    def make_text(example):
        return {"text": "[INST]" + format_user_question(example) + "[/INST] " + example["answer"] + tokenizer.eos_token}

    dataset = dataset.map(make_text, remove_columns=["candidates", "query", "answer"])
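    # After the map, each example is one string of the form
    # "[INST]<formatted question + candidates>[/INST] <gold answer><eos>"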
    # Training hyperparameters; note this replaces whatever SFTConfig was parsed
    # from the command line
    training_args = SFTConfig(
        output_dir="output_sft",
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        num_train_epochs=2,
        eval_steps=100,
        do_eval=True,
        eval_strategy="steps",
        save_steps=100,
        save_total_limit=1,
        max_grad_norm=0.1,
        log_on_each_node=False,
    )
    # Train a LoRA adapter rather than the full model
    model_args.use_peft = True

    # Initialize the SFT trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Train the model
    trainer.train()

    # Save and, if requested, push to the Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()
def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (SFTConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    training_args, model_args = parser.parse_args_and_config()
    main(training_args, model_args)
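A typical launch looks like the following (the script filename is hypothetical, and most hyperparameters are fixed inside main(), so only the model-related flags matter here):

python sft_train.py --model_name_or_path qwen/Qwen2.5-7B-Instruct --output_dir output_sft --trust_remote_code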
Note: the dataset must define a text field in the format TRL expects, folding the prompt and the answer into a single string.
Note how the two custom reward functions below receive their arguments: any extra dataset column (here, answer) is forwarded to them as a keyword argument.
def accuracy_reward(completions, answer, **kwargs):
    def _parse(a, b):
        # Extract the text between <answer> tags and compare with the gold answer
        a = a.split("<answer>")[-1].split("</answer>")[0]
        return 1.0 if a == b else 0.0

    contents = [completion[0]["content"] for completion in completions]
    return [_parse(content, sol) for content, sol in zip(contents, answer)]
def format_reward(completions, **kwargs):
    # Reward completions shaped like <think>...</think> <answer>...</answer>
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
reward_funcs_registry = {
"accuracy": accuracy_reward,
"format": format_reward,
}
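To show where these rewards plug in, here is a minimal, untested sketch of wiring them into TRL's GRPOTrainer. The dataset must keep a "prompt" column (e.g. via make_conversation above) plus the extra "answer" column, which TRL forwards to each reward function as a keyword argument; the output_dir and num_generations values are illustrative assumptions.

from trl import GRPOConfig, GRPOTrainer

# Sketch: the dataset keeps "prompt" and "answer" columns; TRL passes every
# non-prompt column (including "answer") through to the reward functions
grpo_args = GRPOConfig(output_dir="output_grpo", num_generations=4)
grpo_trainer = GRPOTrainer(
    model="qwen/Qwen2.5-7B-Instruct",
    reward_funcs=[reward_funcs_registry["format"], reward_funcs_registry["accuracy"]],
    args=grpo_args,
    train_dataset=dataset["train"].map(make_conversation),  # adds the "prompt" column
)
grpo_trainer.train()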
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model name
model_name = "qwen/Qwen2.5-7B-Instruct"

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Input text
prompt = "Give me a short introduction to large language model."
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]

# Apply the chat template
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# Tokenize the model input
model_inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Generate
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)

# Decode only the newly generated tokens (drop the prompt portion)
response = tokenizer.decode(generated_ids[0][model_inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(response)
Batch prediction code:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model name
model_name = "qwen/Qwen2.5-7B-Instruct"

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
# Left padding so the generated tokens line up at the end of each row
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')

# Batch of input prompts
batch_prompt = [
    "Give me a short introduction to large language model.",
    "请介绍下您自己?"
]
batch_messages = []
for prompt in batch_prompt:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    batch_messages.append(messages)

# Apply the chat template to every conversation (returns a list of strings)
texts = tokenizer.apply_chat_template(
    batch_messages,
    tokenize=False,
    add_generation_prompt=True,
)

# Tokenize with padding so the batch forms one tensor
model_inputs = tokenizer(texts, padding=True, return_tensors="pt").to(model.device)

# Generate
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,
)

# Keep only the newly generated tokens for each row
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Print the results
for resp in response:
    print(resp)
Note: with apply_chat_template there is no longer any need to construct the prompt template by hand.
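To run inference with the adapter produced by the SFT script above, a minimal sketch (assuming the adapter and tokenizer were saved to output_sft and that peft is installed):

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Assumed path: the output_dir used by the SFT script above
adapter_dir = "output_sft"

# Loads the base model recorded in the adapter config, then attaches the LoRA weights
model = AutoPeftModelForCausalLM.from_pretrained(adapter_dir, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)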