关于在SFT之后进行DPO微调的技术解析与应用指南,结合国内外最新研究及实践案例,包含代码实现与未来优化方向:
监督微调(SFT)通过高质量样本训练模型生成符合预期的输出,但其核心问题在于无法让模型在“正确”与“错误”输出间主动选择偏好。例如,在代码生成中,SFT可能生成语法正确但效率低下的代码,而缺乏对高效解决方案的偏好引导。
直接偏好优化(DPO)通过对比学习优化模型输出,无需显式奖励模型。其损失函数为:
L D P O = − log σ ( β ( log π θ ( y w ∣ x ) π r e f ( y w ∣ x ) − log π θ ( y l ∣ x ) π r e f ( y l ∣ x ) ) ) L_{DPO} = -\log \sigma\!\left( \beta \left( \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)} \right) \right) LDPO=−logσ(β(logπref(yw∣x)πθ(yw∣x)−logπref(yl∣x)πθ(yl∣x)))
其中,y_w
为偏好响应,y_l
为非偏好响应,π_ref
为参考模型(通常为SFT模型),β
为温度系数。
偏好对生成示例(以代码生成为例):
# 使用GPT-4生成偏好对
from openai import OpenAI
client = OpenAI()
prompt = "Write a Python function to reverse a string."
response_preferred = client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": prompt + " (高效实现)"}]
)
response_rejected = client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": prompt + " (低效实现)"}]
)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class DPOTrainer:
def __init__(self, model_name="gpt2"):
self.model = AutoModelForCausalLM.from_pretrained(model_name)
self.ref_model = AutoModelForCausalLM.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.beta = 0.1 # 温度系数
def compute_loss(self, batch):
# 计算偏好响应与非偏好响应的对数概率
logits_w = self.model(**batch["preferred"]).logits
log_probs_w = self._get_log_probs(logits_w, batch["preferred_labels"])
logits_l = self.model(**batch["rejected"]).logits
log_probs_l = self._get_log_probs(logits_l, batch["rejected_labels"])
# 计算参考模型概率
with torch.no_grad():
ref_logits_w = self.ref_model(**batch["preferred"]).logits
ref_log_probs_w = self._get_log_probs(ref_logits_w, batch["preferred_labels"])
ref_logits_l = self.ref_model(**batch["rejected"]).logits
ref_log_probs_l = self._get_log_probs(ref_logits_l, batch["rejected_labels"])
# DPO损失
ratios = (log_probs_w - ref_log_probs_w) - (log_probs_l - ref_log_probs_l)
loss = -torch.nn.functional.logsigmoid(self.beta * ratios).mean()
return loss
def _get_log_probs(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
return torch.gather(shift_logits.log_softmax(-1), 2, shift_labels.unsqueeze(-1)).squeeze(-1).sum(-1)
from trl import DPOTrainer
dpo_trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
beta=0.1,
)
dpo_trainer.train()
此框架可帮助开发者快速实现从SFT到DPO的升级迭代,在代码生成、长文本理解等场景中显著提升模型性能。
【哈佛博后带小白玩转机器学习】