Python-Based Natural Language Processing Series (87): RRHF Principles and Practice

✨ This post introduces RRHF (Rank Responses to Align Language Models with Human Feedback), a fine-tuning strategy driven by human feedback that is simpler and more stable than traditional RLHF and has been widely adopted in open-source projects such as AlpacaFarm and UltraFeedback.

1. Introduction to RRHF

RRHF is essentially a ranking-supervised method: it trains a language model to order its responses in line with human preferences.

The RRHF training procedure:

  1. Take an instruction (query);

  2. Provide several candidate responses;

  3. Rank the responses from best to worst using human-assigned scores;

  4. Encourage the model to assign higher log-likelihood to the higher-scoring responses;

  5. Optimize a contrastive ranking loss combined with an SFT supervision loss (see the minimal sketch right after this list).
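
A minimal sketch of the two loss terms, assuming for one query a vector of length-normalized sequence log-likelihoods seq_scores and a vector of human scores rw_scores (the names here are illustrative; the full in-Trainer implementation appears in section 8):

import torch

def rank_loss(seq_scores, rw_scores):
    # Penalize every candidate pair whose likelihood ordering contradicts the human scores.
    diff = seq_scores.unsqueeze(0) - seq_scores.unsqueeze(-1)
    rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1)
    violated = (rw_diff > 0) & (diff < 0)
    return -diff[violated].sum()

def sft_loss(token_logprobs, rw_scores):
    # Cross-entropy on the best-scoring response only.
    best = torch.argmax(rw_scores)
    return -token_logprobs[best].mean()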

Compared with reinforcement-learning approaches such as PPO (and with alternatives such as DPO), RRHF is:

  • ✅ Simpler (it only needs scores for the candidate responses, no reward model)

  • ✅ More stable to train (no KL or entropy-control hyperparameters)

  • ✅ Compatible with SFT, which makes deployment and transfer easier

2. Setup and Dependency Installation

pip install transformers trl peft datasets

Environment configuration:

import os
import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

3. Model and Data Path Configuration

from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class ModelArguments:
    model_name_or_path: str = "facebook/opt-125m"

@dataclass
class DataArguments:
    data_path: str = './data/alpaca_responses_hh.json'
    stop_response: bool = False

@dataclass
class CustomTrainingArguments(TrainingArguments):
    rrhf_weight: float = 100.0      # weight of the ranking loss relative to the SFT loss
    length_penalty: float = 1.0     # exponent for length-normalizing sequence log-likelihoods
    only_use_provide: bool = False  # carried over from the original RRHF script; unused in the simplified trainer below
    only_use_sample: bool = False   # carried over from the original RRHF script; unused in the simplified trainer below
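
In a full training script these dataclasses would typically be filled from the command line; a minimal sketch using transformers' HfArgumentParser (in the next section we simply instantiate them directly instead):

from transformers import HfArgumentParser

parser = HfArgumentParser((ModelArguments, DataArguments, CustomTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()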

4. Loading the Model and Tokenizer

from transformers import AutoModelForCausalLM, AutoTokenizer

model_args = ModelArguments()
training_args = CustomTrainingArguments(output_dir="rrhf_model")
data_args = DataArguments()

model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token  # reuse the EOS token as pad token so batched padding works

5. Defining the RRHF Dataset Class

from torch.utils.data import Dataset
import json

class ScoreDataset(Dataset):
    def __init__(self, data_path, tokenizer):
        with open(data_path, 'r') as f:
            self.data = [json.loads(line.strip()) for line in f.readlines()[:100]]  # keep only a small subset for this demo
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return dict(input_ids=self.data[i])  # each item: {query, responses, scores}
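
Each line of the data file is expected to be a standalone JSON object with query, responses, and scores fields, which is what the collator below relies on. An illustrative (made-up) line:

{"query": "Human: How do I make tea? Assistant:", "responses": [" Boil water, add a tea bag, and steep for three minutes.", " I have no idea."], "scores": [0.9, 0.1]}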

6. Defining the Collator to Build Batches

from dataclasses import dataclass
import torch
from transformers import PreTrainedTokenizer

@dataclass
class DataCollatorForSupervisedDataset:
    tokenizer: PreTrainedTokenizer
    stop_response: bool = False

    def __call__(self, instances):
        input_ids, labels, idxs, scores = [], [], [], []

        for idx, ins in enumerate(instances):
            query = ins['input_ids']['query']
            responses = ins['input_ids']['responses']
            score_list = ins['input_ids']['scores']
            idxs.append([idx] * len(score_list))
            scores.append(score_list)

            query_ids = self.tokenizer(query, return_tensors="pt", truncation=True)["input_ids"][0]
            dummy_target = torch.LongTensor([-100])

            for res in responses:
                res_ids = self.tokenizer(res + self.tokenizer.eos_token, return_tensors="pt", truncation=True)["input_ids"][0]
                input_ids.append(torch.cat([query_ids, res_ids], dim=0))
                # Labels are shifted left by one position relative to input_ids so that,
                # after the gather in the trainer, position t holds the log-probability of
                # token t+1; the query part and the final dummy position are masked with -100.
                labels.append(torch.cat([torch.LongTensor([-100] * (len(query_ids) - 1)), res_ids, dummy_target], dim=0))

        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

        return {
            "input_ids": input_ids,
            "attention_mask": input_ids.ne(tokenizer.pad_token_id),
            "labels": labels,
            "idxs": torch.LongTensor(idxs),
            "scores": torch.FloatTensor(scores),
        }
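
An optional sanity check of the collator on a single made-up instance (two candidate responses, so the padded tensors have a leading dimension of 2 and scores has shape [1, 2]):

collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
fake_instance = {"input_ids": {
    "query": "Human: Say hello. Assistant:",
    "responses": [" Hello there!", " Go away."],
    "scores": [1.0, 0.0],
}}
batch = collator([fake_instance])
print(batch["input_ids"].shape, batch["labels"].shape, batch["scores"].shape)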

7. Building the Data Module

def make_supervised_data_module(tokenizer, data_args):
    dataset = ScoreDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, stop_response=data_args.stop_response)
    return dict(train_dataset=dataset, eval_dataset=None, data_collator=collator)

8. Implementing a Custom RRHFTrainer

from transformers import Trainer
import torch.nn.functional as F

class RRHFTrainer(Trainer):
    def gather_logits_labels(self, logits, labels):
        # Per-token log-probabilities of the label tokens; positions masked with -100 contribute 0.
        mask = labels != -100
        index = labels.masked_fill(~mask, 0)  # avoid mutating the labels tensor in place
        return torch.gather(logits, -1, index.unsqueeze(-1)).squeeze(-1) * mask

    def get_score(self, logit_label, labels):
        # Length-normalized sequence log-likelihood for each candidate response.
        mask = labels != -100
        length = mask.sum(-1)
        return logit_label.sum(-1) / (length ** self.args.length_penalty)

    def rrhf_loss(self, scores, idxs, rw_scores):
        # Pairwise ranking loss: penalize every candidate pair whose likelihood ordering
        # contradicts the human scores. The trailing [0] assumes a per-device batch size
        # of 1, i.e. one query with several candidate responses.
        diff = scores.unsqueeze(0) - scores.unsqueeze(-1)
        rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1)
        aval = torch.bitwise_and(rw_diff > 0, diff < 0)[0]
        return -diff[aval].sum()

    def sft_loss(self, logit_label, idxs, rw_scores):
        # Standard SFT term: maximize the token log-likelihood of the best-scoring response.
        max_idx = torch.argmax(rw_scores)
        return -logit_label[max_idx].mean()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
        # versions of transformers pass to compute_loss.
        logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']).logits
        logits = F.log_softmax(logits, dim=-1)  # per-token log-probabilities over the vocabulary

        logit_label = self.gather_logits_labels(logits, inputs["labels"])
        scores = self.get_score(logit_label, inputs["labels"])

        rrhf_loss = self.rrhf_loss(scores, inputs["idxs"], inputs["scores"])
        sft_loss = self.sft_loss(logit_label, inputs["idxs"], inputs["scores"])
        loss = self.args.rrhf_weight * rrhf_loss + sft_loss

        return (loss, scores) if return_outputs else loss
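
As a quick sanity check of what the ranking term does, here is a toy computation with made-up numbers: two candidates where the human prefers candidate 0 but the model assigns it a lower length-normalized log-likelihood, so the violated pair contributes its margin (0.4) to the loss:

import torch

scores = torch.tensor([-1.2, -0.8])     # model sequence scores per candidate
rw_scores = torch.tensor([[1.0, 0.0]])  # human scores, shape [1, num_candidates]

diff = scores.unsqueeze(0) - scores.unsqueeze(-1)
rw_diff = rw_scores.unsqueeze(0) - rw_scores.unsqueeze(-1)
violated = torch.bitwise_and(rw_diff > 0, diff < 0)[0]
print(-diff[violated].sum())  # tensor(0.4000)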

9. Running Training

data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

trainer = RRHFTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module
)

trainer.train()
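
After training, it is common to save the model and run a quick generation to sanity-check the result (the prompt below is just a placeholder):

trainer.save_model("rrhf_model")

prompt = "Human: How can I stay focused while studying? Assistant:"
enc = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**enc, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))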

10. Summary

RRHF is a lightweight, practical method for instruction fine-tuning with human feedback. Its main advantages:

  • ✅ No RL loop and no reward model required

  • ✅ Works directly with existing instruction data

  • ✅ Supports few-shot / semi-supervised training

  • ✅ Compatible with any model that AutoModelForCausalLM supports

If you want to bring RLHF-style alignment into your own project but limited resources, unstable training, or code complexity are in the way, RRHF is an excellent choice.

If you found this post helpful, please like, bookmark, and follow me, and feel free to support me with a tip!

Stay tuned for my upcoming posts, where I will share more on artificial intelligence, natural language processing, and computer vision.

Thank you for your support!
