{
"prompt_template": "Question: {question} Answer:",
"samples": [
{"question": "北京的最高温度是多少?", "answer": "42°C"},
{"question": "埃菲尔铁塔高度?", "answer": "324米"}
]
}
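
For concreteness, here is a minimal sketch of how the template and samples above could be rendered into training strings; the file name `ptuning_data.json` and the `render` helper are assumptions for illustration, not part of any library.

```python
import json

# Minimal sketch: fill the prompt template with each sample's fields
# (assumes the JSON above is saved as "ptuning_data.json")
with open("ptuning_data.json", encoding="utf-8") as f:
    cfg = json.load(f)

def render(sample, template=cfg["prompt_template"]):
    # "Question: {question} Answer:" -> full training string with the answer appended
    return template.format(question=sample["question"]) + " " + sample["answer"]

texts = [render(s) for s in cfg["samples"]]
# e.g. "Question: 北京的最高温度是多少? Answer: 42°C"
```
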
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, PromptEncoderConfig, TaskType
import torch

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

# P-Tuning: a small prompt encoder learns 20 virtual-token embeddings;
# the base model's weights stay frozen (get_peft_model handles the freezing)
peft_config = PromptEncoderConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    encoder_hidden_size=model.config.n_embd,
)
model = get_peft_model(model, peft_config)

# Only the prompt-encoder parameters are trainable
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-3)

# Training loop: dataloader yields batches with 'question' and 'answer' lists (see the JSON above)
for batch in dataloader:
    texts = [q + " " + a for q, a in zip(batch['question'], batch['answer'])]
    inputs = tokenizer(texts, return_tensors='pt', padding=True)
    labels = inputs['input_ids'].clone()
    labels[inputs['attention_mask'] == 0] = -100  # ignore padding in the LM loss
    outputs = model(**inputs, labels=labels)
    loss = outputs.loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
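
After training, the PEFT-wrapped model can be used for inference directly; in current `peft` versions, `generate` takes care of prepending the learned virtual tokens. A brief usage sketch (the query text is illustrative):

```python
# Inference sketch: the PEFT wrapper injects the learned soft prompts automatically
model.eval()
query = tokenizer("Question: 埃菲尔铁塔高度? Answer:", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**query, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```
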
[
{"input": "将下面句子翻译成英文:你好,世界。", "output": "Hello, world."},
{"input": "总结这段话要点:...", "output": "..."}
]
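
Before these pairs can be fed to the `Trainer` used below, they need to be tokenized into encoder inputs and decoder labels. A minimal preprocessing sketch, assuming the list above is saved as `instruction_data.json` (the file name and max-length choices are illustrative):

```python
import json
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")

with open("instruction_data.json", encoding="utf-8") as f:
    pairs = json.load(f)

def preprocess(example):
    # Encoder input from the instruction, decoder labels from the expected output
    model_inputs = tokenizer(example["input"], truncation=True, max_length=256)
    labels = tokenizer(text_target=example["output"], truncation=True, max_length=128)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

dataset = Dataset.from_list(pairs).map(preprocess, remove_columns=["input", "output"])
```
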
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer
from peft import get_peft_model, LoraConfig, TaskType

model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')
tokenizer = AutoTokenizer.from_pretrained('t5-base')

config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    r=8,                # rank of the low-rank update matrices
    lora_alpha=16,      # scaling factor applied to the update
    lora_dropout=0.1
)
model = get_peft_model(model, config)

# Of all parameters, only the LoRA matrices are trainable
model.print_trainable_parameters()

trainer = Trainer(model=model, ...)  # plus TrainingArguments, datasets, data collator
trainer.train()
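
After training, only the adapter weights need to be saved, and for deployment the low-rank updates can be merged back into the base weights so inference runs without extra adapter modules. A short sketch using standard `peft` calls (directory names are placeholders):

```python
# Save only the LoRA adapter weights (a few MB) rather than the full model
model.save_pretrained("t5-base-lora-adapter")

# For deployment: fold the low-rank updates into the base weights,
# so inference runs on a plain T5 model with no adapter layers
merged_model = model.merge_and_unload()
merged_model.save_pretrained("t5-base-lora-merged")
```
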
[
{"context": "写一首五言绝句关于春天。", "chosen": "春风拂面柳含烟,...", "rejected": "春天来了,花开了。"},
...
]
# 1) Supervised fine-tuning (SFT)
sft_model = train_sft(base_model, sft_data)            # train_sft is a placeholder
# 2) Reward model trained on human preference pairs
rm = train_reward_model(sft_model, preference_pairs)   # placeholder as well
# 3) PPO fine-tuning against the reward model
from trl import PPOTrainer, PPOConfig
# Note: exact PPOTrainer arguments vary across trl versions; this is schematic
ppo_trainer = PPOTrainer(model=sft_model, ref_model=base_model, reward_model=rm, config=PPOConfig())
ppo_trainer.train(preference_episodes)
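
`train_sft` and `train_reward_model` above are placeholders. The core of the reward-model step is a pairwise (Bradley-Terry) loss that pushes the score of the chosen answer above that of the rejected one. A minimal sketch of that loss, assuming `rm` is a sequence-classification model with a single output logit and `tokenizer` is its tokenizer (both hypothetical here):

```python
import torch
import torch.nn.functional as F

def reward_pair_loss(rm, tokenizer, context, chosen, rejected):
    # Score the chosen and rejected completions for the same context
    chosen_ids = tokenizer(context + chosen, return_tensors="pt")
    rejected_ids = tokenizer(context + rejected, return_tensors="pt")
    r_chosen = rm(**chosen_ids).logits.squeeze(-1)      # scalar reward
    r_rejected = rm(**rejected_ids).logits.squeeze(-1)
    # Bradley-Terry objective: the chosen answer should score higher
    return -F.logsigmoid(r_chosen - r_rejected).mean()
```

Each field of a preference sample (`context`, `chosen`, `rejected`) maps directly onto this loss.
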
| Dimension | P-Tuning | LoRA | RLHF |
|---|---|---|---|
| Trainable parameters | Very few (tens of thousands to hundreds of thousands) | Few (millions) | Many (full model + reward model) |
| GPU memory footprint | Very low | Low | High |
| Deployment overhead | None | Small extra compute | High |
| Performance gain | Moderate | High | Very high |
| Data requirement | Small set of instruction samples | Moderate number of task pairs | Large set of preference comparisons |
| Difficulty | Simple | Moderate | High (RL + human-in-the-loop) |
The reasons are as follows: