参考资料
参考资料
代码
RLHF训练流程
将人类打好标签的数据输入到一个预训练的模型,进行language modeling。其目的是使经过SFT后的模型能生成合适的输出
模型架构
模型训练的目的是能对chosen的回答打高分,即最大化value_chosen,最小化value_rejected。
#模型的损失函数
loss = value_chosen - value_rejected
loss = -torch.nn.functional.logsigmoid(loss).mean()
该阶段一共有四个主要的模型,分别是:
actor_model(需要训练)
critic_model(需要训练)
reward_model
reference_model
其中:
actor_model和reference_model使用的是SFT阶段训练出来的模型
critic_model和reward_model使用的是RW阶段训练出来的模型
def get_reward(self, input_ids, attention_mask):
value = self.get_value(input_ids, attention_mask)
reward = []
for i, v in zip(input_ids, value):
end = input_ids.shape[1] - 1
if tokenizer.eos_token_id in i:
end = i.tolist().index(tokenizer.eos_token_id)
reward.append(v[end])
reward = torch.stack(reward)
将最后一个词的value定义为reward,reward就是对完整的prompt + response的及时奖励,由reward_model产生
prob_old = model_actor(input_ids=generate,
attention_mask=generate_mask).logits
prob_old = get_prob(prob_old[:, :-1], generate[:, 1:])
prob_ref = model_ref(
input_ids=generate.to('cpu'),
attention_mask=generate_mask.to('cpu')).logits.to('cuda')
prob_ref = get_prob(prob_ref[:, :-1], generate[:, 1:])
reward_kl = -0.1 * (prob_old - prob_ref)
对于actor模型来说,根据自己生成的结果,在response中每一个token都会有对应prob,我们记为prob_old。
对于ref模型来说,对于actor生成的response,在response中每一个token都会有对应prob,我们记为prob_ref。
若pro_ref越大,就说明ref模型对actor模型输出的肯定性越大,即ref认可actor的输出,即actor较ref没有训练歪。从散度理解kl越小,则说明两个分布相似性越高
value_old += reward_kl
delta = []
for i in range(255, value_old.shape[1]):
s = 0
for j in range(i, value_old.shape[1]):
s += value_old[:, j] * 0.95**(j - i)
delta.append(s)
return torch.stack(delta, dim=1)
在PPO算法中,优势(advantage)是用来衡量当前策略比基线策略优越的程度。优势函数的定义为: A ( s , a ) = Q ( s , a ) − V ( s ) A(s,a)=Q(s,a)−V(s) A(s,a)=Q(s,a)−V(s)。
将reward和kl融入value_old中,使在后续训练中模型能考虑到。使用蒙特卡洛采样法,可以使用value_old加权求和来估计优势。
代码中存在了对prob和value的新计算,该做法的原因是,从一个batch中获得信息,需要对四个模型进行推理,若只更新一次actor模型和critic模型,会存在浪费。
所以打算进行ppo_epoch次更新,在该代码中,ppo_epoch为1
def get_loss_actor(prob_new, prob_old, delta, generate_mask):
prob_new = prob_new[:, 255:]
prob_old = prob_old[:, 255:]
generate_mask = generate_mask[:, 256:]
#prob_new -> [4, gen_lens-256]
#prob_old -> [4, gen_lens-256]
#delta -> [4, gen_lens-256]
#generate_mask -> [4, gen_lens-256]
#对数概率,求差就是求商,所以这里求的是新旧概率的变化率
#[4, gen_lens-256]
ratio = ((prob_new - prob_old) * generate_mask).exp()
#delta是估计出来的去基线Q值,以变化率来缩放Q值
#最大化Q值,以此来寻找最优的actor
#裁剪,防止自举,保证prob_old的约束力,一旦变化率过大,则不进行参数更新
#[4, gen_lens-256]
loss1 = delta * ratio
loss2 = delta * ratio.clamp(0.8, 1.2)
loss = torch.min(loss1, loss2) * generate_mask
loss = loss.sum() / generate_mask.sum() / 8
return -loss
ratio的作用是保证新的actor能模仿开始的老actor,为了防止老actor的约束力不足,采取了裁剪,一旦ratio超过了可接受范围,ratio变为常数,则不进行对actor参数更新。
该loss的目的为最大化优势。loss的公式如下图所示
def get_loss_critic(value_new, value_old, delta, generate_mask):
value_new = value_new[:, 255:]
value_old = value_old[:, 255:]
generate_mask = generate_mask[:, 256:]
#value_new -> [4, gen_lens-256]
#value_old -> [4, gen_lens-256]
#delta -> [4, gen_lens-256]
#generate_mask -> [4, gen_lens-256]
#delta是估计出来的去基线Q值,加上value_old后还原为Q值
#value_new和Q值求mse loss即可,因为value都是对Q函数的估计
#裁剪,防止自举
#[4, gen_lens-256]
loss1 = (value_new - delta - value_old)**2
value_new = value_new.clamp(value_old - 0.2, value_old + 0.2)
loss2 = (value_new - delta - value_old)**2
#求平均
loss = torch.max(loss1, loss2) * generate_mask
loss = loss.sum() / 2 / generate_mask.sum() / 8
return loss
在这里我们同样想用旧模型去约束新模型,所以采用了old_value的一个变动范围,去约束新模型。
这个loss的作用是为了减小critic产生预估收益和实际收益的差距