RLHF代码解读

参考资料

参考资料

代码

RLHF训练流程

RLHF代码解读_第1张图片

  1. SFT

  将人类打好标签的数据输入到一个预训练的模型,进行language modeling。其目的是使经过SFT后的模型能生成合适的输出

 

  1. RW

模型架构

RLHF代码解读_第2张图片

  模型训练的目的是能对chosen的回答打高分,即最大化value_chosen,最小化value_rejected。


#模型的损失函数

loss = value_chosen - value_rejected

loss = -torch.nn.functional.logsigmoid(loss).mean()

 

  1. RLHF-PPO

该阶段一共有四个主要的模型,分别是:

  • actor_model(需要训练)

  • critic_model(需要训练)

  • reward_model

  • reference_model

其中:

  • actor_model和reference_model使用的是SFT阶段训练出来的模型

  • critic_model和reward_model使用的是RW阶段训练出来的模型

 

代码分析

def get_reward(self, input_ids, attention_mask):

        value = self.get_value(input_ids, attention_mask)



        reward = []

        for i, v in zip(input_ids, value):

            end = input_ids.shape[1] - 1

            if tokenizer.eos_token_id in i:

                end = i.tolist().index(tokenizer.eos_token_id)

            reward.append(v[end])

        reward = torch.stack(reward)



  将最后一个词的value定义为reward,reward就是对完整的prompt + response的及时奖励,由reward_model产生

  1. 对ref和actor的散度的理解

prob_old = model_actor(input_ids=generate,

                        attention_mask=generate_mask).logits

prob_old = get_prob(prob_old[:, :-1], generate[:, 1:])



prob_ref = model_ref(

        input_ids=generate.to('cpu'),

        attention_mask=generate_mask.to('cpu')).logits.to('cuda')

prob_ref = get_prob(prob_ref[:, :-1], generate[:, 1:])



reward_kl = -0.1 * (prob_old - prob_ref)

  对于actor模型来说,根据自己生成的结果,在response中每一个token都会有对应prob,我们记为prob_old。

对于ref模型来说,对于actor生成的response,在response中每一个token都会有对应prob,我们记为prob_ref。

若pro_ref越大,就说明ref模型对actor模型输出的肯定性越大,即ref认可actor的输出,即actor较ref没有训练歪。从散度理解kl越小,则说明两个分布相似性越高

  1. 对get_delta函数的理解

    value_old += reward_kl



    delta = []

    for i in range(255, value_old.shape[1]):

        s = 0

        for j in range(i, value_old.shape[1]):

            s += value_old[:, j] * 0.95**(j - i)

        delta.append(s)



    return torch.stack(delta, dim=1)

  在PPO算法中,优势(advantage)是用来衡量当前策略比基线策略优越的程度。优势函数的定义为: A ( s , a ) = Q ( s , a ) − V ( s ) A(s,a)=Q(s,a)−V(s) A(s,a)=Q(s,a)V(s)

将reward和kl融入value_old中,使在后续训练中模型能考虑到。使用蒙特卡洛采样法,可以使用value_old加权求和来估计优势。

  1. 对新旧的理解

  代码中存在了对prob和value的新计算,该做法的原因是,从一个batch中获得信息,需要对四个模型进行推理,若只更新一次actor模型和critic模型,会存在浪费。

所以打算进行ppo_epoch次更新,在该代码中,ppo_epoch为1

  1. 对get_loss_actor的理解

def get_loss_actor(prob_new, prob_old, delta, generate_mask):

    prob_new = prob_new[:, 255:]

    prob_old = prob_old[:, 255:]

    generate_mask = generate_mask[:, 256:]



    #prob_new -> [4, gen_lens-256]

    #prob_old -> [4, gen_lens-256]

    #delta -> [4, gen_lens-256]

    #generate_mask -> [4, gen_lens-256]



    #对数概率,求差就是求商,所以这里求的是新旧概率的变化率

    #[4, gen_lens-256]

    ratio = ((prob_new - prob_old) * generate_mask).exp()



    #delta是估计出来的去基线Q值,以变化率来缩放Q值

    #最大化Q值,以此来寻找最优的actor

    #裁剪,防止自举,保证prob_old的约束力,一旦变化率过大,则不进行参数更新

    #[4, gen_lens-256]

    loss1 = delta * ratio

    loss2 = delta * ratio.clamp(0.8, 1.2)

    loss = torch.min(loss1, loss2) * generate_mask

    loss = loss.sum() / generate_mask.sum() / 8

    return -loss

  ratio的作用是保证新的actor能模仿开始的老actor,为了防止老actor的约束力不足,采取了裁剪,一旦ratio超过了可接受范围,ratio变为常数,则不进行对actor参数更新。

该loss的目的为最大化优势。loss的公式如下图所示

在这里插入图片描述

  1. 对get_loss_crtic函数的理解

def get_loss_critic(value_new, value_old, delta, generate_mask):

    value_new = value_new[:, 255:]

    value_old = value_old[:, 255:]

    generate_mask = generate_mask[:, 256:]



    #value_new -> [4, gen_lens-256]

    #value_old -> [4, gen_lens-256]

    #delta -> [4, gen_lens-256]

    #generate_mask -> [4, gen_lens-256]



    #delta是估计出来的去基线Q值,加上value_old后还原为Q值

    #value_new和Q值求mse loss即可,因为value都是对Q函数的估计

    #裁剪,防止自举

    #[4, gen_lens-256]

    loss1 = (value_new - delta - value_old)**2

    value_new = value_new.clamp(value_old - 0.2, value_old + 0.2)

    loss2 = (value_new - delta - value_old)**2



    #求平均

    loss = torch.max(loss1, loss2) * generate_mask

    loss = loss.sum() / 2 / generate_mask.sum() / 8



    return loss

  在这里我们同样想用旧模型去约束新模型,所以采用了old_value的一个变动范围,去约束新模型。

这个loss的作用是为了减小critic产生预估收益和实际收益的差距

你可能感兴趣的:(大模型学习,语言模型)