返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

LLM实践系列-详聊OpenRLHF中的各种Loss

[复制链接]
链载Ai 显示全部楼层 发表于 昨天 11:50 |阅读模式 打印 上一主题 下一主题

今天给大家来带好友知乎@ybq关于OpenRLHF的学习笔记,主要介绍其中的各种loss内容。

从这篇文章开始,我会不定期分享利用 OpenRLHF 学习 RLHF 的一些心得。我平常读代码喜欢开门见山,直接去看 loss 函数是什么形式,再去理解代码的其他环节,所以就从 loss 开始分享吧。

代码详见:https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/loss.py

基础

在研究 loss 函数前,建议把下面几个公式和图先焊死在脑子中。

SFT 家族

GPTLMLoss

classGPTLMLoss(nn.Module):
"""
GPTLanguageModelLoss
"""

def__init__(self):
super().__init__()
self.IGNORE_INDEX=-100
self.loss=nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)

defforward(self,logits:torch.Tensor,labels:torch.Tensor)->torch.Tensor:
shift_logits=logits[...,:-1,:].contiguous()
shift_labels=labels[...,1:].contiguous()
#Flattenthetokens
returnself.loss(shift_logits.view(-1,shift_logits.size(-1)),shift_labels.view(-1))

没啥多说的,最常见的 gpt loss 函数,也就是 pretrain / sft 的 loss 函数,通过 self.IGNORE_INDEX 来实现 prompt 的 loss_mask 。

KDLoss

#Adaptedfromhttps://github.com/microsoft/LMOps/blob/main/minillm/finetune.py#L166
classKDLoss(nn.Module):
"""
LanguageModelKnowledgeDistillationLoss
"""

def__init__(self):
super().__init__()
self.IGNORE_INDEX=-100

defforward(self,logits:torch.Tensor,teacher_logits:torch.Tensor,label:torch.Tensor)->torch.Tensor:
teacher_probs=F.softmax(teacher_logits,dim=-1,dtype=torch.float32)
inf_mask=torch.isinf(logits)
logprobs=F.log_softmax(logits,dim=-1,dtype=torch.float32)
prod_probs=torch.masked_fill(teacher_probs*logprobs,inf_mask,0)
x=torch.sum(prod_probs,dim=-1).view(-1)
mask=(label!=self.IGNORE_INDEX).int()
distil_loss=-torch.sum(x*mask.view(-1),dim=0)/torch.sum(mask.view(-1),dim=0)

returndistil_loss

第二种 sft 的 loss 函数:知识蒸馏的 loss 函数。需要在同源 tokenizer 的情况下,利用一个大模型的 logits 分布结果,来让小模型学习软标签。当然,embedding 毕竟只是一个线性层,可以考虑再给模型外挂一个线性层,把 model_A 的 tokenizer 映射到 model_B 的 tokenizer,进而实现利用 qwen 蒸馏 llama 的美好愿景,不知道有没有大佬做过类似的尝试。

言归正传,我们都知道知识蒸馏是用 KL 散度作为 loss 函数的,但代码里也没看见 KL 散度公式啊,不妨一起简单推导下。

其中, 是教师模型的概率分布, 是学生模型的概率分布。在实际的优化过程中,KL 散度中的第一项 是关于教师模型的熵,对于学生模型的参数优化是一个常数,可丢弃。因此,我们通常只需要最小化第二项,即交叉熵损失:

知识蒸馏本身没啥痛点,只要能解决 seq_len * vocab_size 大小的 logits 通讯问题,这就是个简单纯粹有效的优化小模型的极佳方案。不过传统的 KL 往往是 soft_label 和 hard_label 的加权组合,这在 OpenRLHF 的代码中没有体现出来,大家有需要的话可以自行实践:

lm_loss=F.cross_entropy(
logits.view(-1,logits.size(-1)),
label.view(-1),
ignore_index=self.IGNORE_INDEX
)
total_loss=alpha*lm_loss+beta*distil_loss

DPO 家族

DPOLoss

classDPOLoss(nn.Module):
"""
DPOLoss
"""

def__init__(self,beta:float,label_smoothing:float=0.0,ipo:bool=False)->None:
super().__init__()
self.beta=beta
self.label_smoothing=label_smoothing
self.ipo=ipo

defforward(
self,
policy_chosen_logps:torch.Tensor,
policy_rejected_logps:torch.Tensor,
reference_chosen_logps:torch.Tensor,
reference_rejected_logps:torch.Tensor,
)->Tuple[torch.Tensor,torch.Tensor,torch.Tensor]:
pi_logratios=policy_chosen_logps-policy_rejected_logps
ref_logratios=reference_chosen_logps-reference_rejected_logps
logits=pi_logratios-ref_logratios

ifself.ipo:
losses=(logits-1/(2*self.beta))**2#Eq.17ofhttps://arxiv.org/pdf/2310.12036v2.pdf
else:
#Eq.3https://ericmitchell.ai/cdpo.pdf;label_smoothing=0givesoriginalDPO(Eq.7ofhttps://arxiv.org/pdf/2305.18290.pdf)
losses=(
-F.logsigmoid(self.beta*logits)*(1-self.label_smoothing)
-F.logsigmoid(-self.beta*logits)*self.label_smoothing
)

loss=losses.mean()
chosen_rewards=self.beta*(policy_chosen_logps-reference_chosen_logps).detach()
rejected_rewards=self.beta*(policy_rejected_logps-reference_rejected_logps).detach()

returnloss,chosen_rewards,rejected_rewards

我们熟悉的 dpo 的 loss 函数,看上去没提供任何 trick,实践中如果 chosen_rewards 和 rejected_rewards 都下降,可以考虑给正例 / 负例再加一个系数。

除了原始的 loss 函数,OpenRLHF 为我们提供了额外两个选项:

IPO:论文中的 loss 表达式

我没有实践过这个算法就不多评价了,似乎重点是加了一个正则项。

CDPO:大概就是给 DPO 加了个 label_smoothing。

标签平滑是一种正则化方法,它通过将硬标签转换为软标签来防止模型过度自信。具体来说,对于二分类问题:原本的样本是正例就是正例,是负例就是负例,平滑后变成了:(1 - self.label_smoothing) 的概率是正例,self.label_smoothing 的概率是负例。

具体在 DPO 算法中的含义,一个 pair 对,以 (1 - self.label_smoothing) 的概率认为 good_sentence 比 bad_sentence 质量高,以 self.label_smoothing 的概率认为 bad_sentence 比 good_sentence 质量高。从而避免了模型对训练数据的过度拟合和过度自信。

理解这个平滑代码实现的关键点在于下面这两个公式,相信负例的 loss 可以动手笔划一下:

  • 相信正例的 loss:
  • 相信负例的 loss:

KTOLoss

#Adaptedfromhttps://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L770
classKTOLoss(nn.Module):
"""
KTOlossforunevensampling
"""

def__init__(
self,beta:float,desirable_weight:float,undesirable_weight:float,world_size:int,device:torch.device
)->None:
super().__init__()
self.beta=beta
self.world_size=world_size
self.device=device
self.desirable_weight=desirable_weight
self.undesirable_weight=undesirable_weight

defforward(
self,
policy_chosen_logps:torch.FloatTensor,
policy_rejected_logps:torch.FloatTensor,
policy_KL_logps:torch.FloatTensor,
reference_chosen_logps:torch.FloatTensor,
reference_rejected_logps:torch.FloatTensor,
reference_KL_logps:torch.FloatTensor,
)->Tuple[torch.FloatTensor,torch.FloatTensor,torch.FloatTensor]:
KL=(policy_KL_logps-reference_KL_logps).mean().detach()
#all_reducesumsuptheKLestimatesacrossalldevices(gradientwillalsobescaledbyworldsize)
dist.all_reduce(KL,op=dist.ReduceOp.SUM)
#takeaverage(willalsoscalegradientsappropriately)
KL=(KL/self.world_size).clamp(min=0)

ifpolicy_chosen_logps.shape[0]!=0:
chosen_logratios=policy_chosen_logps-reference_chosen_logps
chosen_losses=1-F.sigmoid(self.beta*(chosen_logratios-KL))
chosen_rewards=self.beta*chosen_logratios.detach()
else:
#importanttocasttopolicy_dtype;otherwiseerrorwilloccurduringall_gather
chosen_losses=torch.Tensor([]).to(policy_rejected_logps.dtype).to(self.device)
chosen_rewards=torch.Tensor([]).to(policy_rejected_logps.dtype).to(self.device)

ifpolicy_rejected_logps.shape[0]!=0:
rejected_logratios=policy_rejected_logps-reference_rejected_logps
rejected_losses=1-F.sigmoid(self.beta*(KL-rejected_logratios))
rejected_rewards=self.beta*rejected_logratios.detach()
else:
#importanttocasttopolicy_dtype;otherwiseerrorwilloccurduringall_gather
rejected_losses=torch.Tensor([]).to(policy_chosen_logps.dtype).to(self.device)
rejected_rewards=torch.Tensor([]).to(policy_chosen_logps.dtype).to(self.device)

losses=torch.cat(
(self.desirable_weight*chosen_losses,self.undesirable_weight*rejected_losses),0
).mean()
returnlosses,chosen_rewards,rejected_rewards,KL

kto 的 loss 函数,对标 dpo 的一个工作。这个算法我之前也没有实践过,但是“kto 不需要偏好 pair 对”这一优点吸引了我,所以也简单研究了一下它的原理和实现代码。

kto 的算法思想说是借鉴了“前景理论”,这个概念对我一个程序员来说太高深了,还是直接去讨论一下算法怎么实现的吧。

kto 的训练数据是 prompt + response + label,这个 label 就是 1 或者 -1,代表着 response 的质量是否被认可。label 是 1 的被称为正例,label 是 -1 的被称为负例。我们看到 loss 函数中要做一个判断 if policy_chosen_logps.shape[0] != 0 的操作,这是因为如果该条训练数据为负例,那么 policy_chosen_logps 这个变量就是一个空 tensor,反之亦然。和 dpo 相比最大的区别是:dpo 的每一条 prompt 需要同时具有正例和负例,kto 的每一条 prompt 则只需要有正例或负例中的一个即可。

kto 正例和负例的 loss 函数分别如下所示:

1 - sigmoid 是一个单调递减函数,这说明:kto 的 loss 函数在正例中鼓励策略模型尽量大于参考点 KL,在负例中则鼓励模型尽量小于参考点 KL,也是一个比较明显的学习正例打压负例的损失函数。self.desirable_weight 和 self.undesirable_weight 则是正向和负向样本各自的权重损失,调参用的。

kto 代码的理解难点是,这个 KL 并不是一条训练样本的 KL,而是一批样本的平均 KL(代码中的 dist.all_reduce),并且为了训练稳定这个 KL 也是不进行反向传播的(代码中的 detach),只是拿来控制损失的饱和度,并且做了 clamp(min=0) 处理。至于这么设计的原因,反正原论文就这么写的,我没具体看公式是怎么推的,不敢瞎分析,感兴趣的可以自己推推公式。

VanillaKTOLoss

#Adaptedfromhttps://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L742
classVanillaKTOLoss(nn.Module):
"""
KTOlossforevensampling
"""

def__init__(self,beta:float)->None:
super().__init__()
self.beta=beta

defforward(
self,
policy_chosen_logps:torch.FloatTensor,
policy_rejected_logps:torch.FloatTensor,
reference_chosen_logps:torch.FloatTensor,
reference_rejected_logps:torch.FloatTensor,
)->Tuple[torch.FloatTensor,torch.FloatTensor,torch.FloatTensor]:
chosen_KL=(policy_chosen_logps-reference_chosen_logps).mean().clamp(min=0)
rejected_KL=(policy_rejected_logps-reference_rejected_logps).mean().clamp(min=0)

chosen_logratios=policy_chosen_logps-reference_chosen_logps
rejected_logratios=policy_rejected_logps-reference_rejected_logps

losses=torch.cat(
(
1-F.sigmoid(self.beta*(chosen_logratios-rejected_KL)),
1-F.sigmoid(self.beta*(chosen_KL-rejected_logratios)),
),
0,
).mean()

chosen_rewards=self.beta*(policy_chosen_logps-reference_chosen_logps).detach()
rejected_rewards=self.beta*(policy_rejected_logps-reference_rejected_logps).detach()
returnlosses,chosen_rewards,rejected_rewards

kto 的变种,看代码实现的话,主要 diff 应该是去掉了参考点 KL ,并且正负样本要 1:1 均衡(OpenRLHF 代码库没写,但是注释里的 https://github.com/ContextualAI/HALOs 这个代码库写了)

乍一看,这个均匀采样 kto 的 loss 函数和 dpo 已经很相似了,但其实还是有本质区别的。dpo 的重点是 margin,也就是正例和负例的 loss 是要做减法的,均匀采样 kto 用的是 torch.cat(),也就是说正例和负例的 loss 相互之间毫无影响,各自朝着各自的 label 去优化。

需要留意的细节是,chosen_KL 和 rejected_KL 也做了 clamp(min=0) 的操作。这里给出我对 RLHF 代码的一条学习心得:不要放过任何一个 clamp / clip 操作背后的原因

RLHF 家族

PolicyLoss

classPolicyLoss(nn.Module):
"""
PolicyLossforPPO
"""

def__init__(self,clip_eps:float=0.2)->None:
super().__init__()
self.clip_eps=clip_eps

defforward(
self,
log_probs:torch.Tensor,
old_log_probs:torch.Tensor,
advantages:torch.Tensor,
action_mask:Optional[torch.Tensor]=None,
)->torch.Tensor:
ratio=(log_probs-old_log_probs).exp()
surr1=ratio*advantages
surr2=ratio.clamp(1-self.clip_eps,1+self.clip_eps)*advantages
loss=-torch.min(surr1,surr2)
loss=masked_mean(loss,action_mask,dim=-1).mean()
returnloss

rlhf 中,actor_model (也就是被优化的模型)的 loss 函数,大概是三个步骤:

  • 计算新旧策略的概率比例
  • 利用优势函数指导更新方向;
  • 限制策略更新幅度
  • -torch.min(surr1, surr2),选择未剪辑和剪辑后损失项的最小值。取负号是因为在最小化损失函数,但 PPO 的目标是最大化期望收益。

代码写的很清晰简洁,和 ppo 论文完全吻合,上面的两个公式也都是 ppo 论文的原始公式。对这里的代码实现有疑惑的,可以结合 ppo 论文一起读。

ValueLoss

classValueLoss(nn.Module):
"""
ValueLossforPPO
"""

def__init__(self,clip_eps:float=None)->None:
super().__init__()
self.clip_eps=clip_eps

defforward(
self,
values:torch.Tensor,
old_values:torch.Tensor,
returns:torch.Tensor,
action_mask:Optional[torch.Tensor]=None,
)->torch.Tensor:
ifself.clip_epsisnotNone:
values_clipped=old_values+(values-old_values).clamp(-self.clip_eps,self.clip_eps)
surr1=(values_clipped-returns)**2
surr2=(values-returns)**2
loss=torch.max(surr1,surr2)
else:
loss=(values-returns)**2

loss=masked_mean(loss,action_mask,dim=-1).mean()
return0.5*loss

rlhf 中,critic_model 的 loss 函数。

  • 如果不剪辑损失函数:
  • 如果要剪辑损失函数,便需要对价值函数的更新进行剪辑,防止价值估计发生过大的变化,和上面的 policy model 的剪辑是一个道理:
  • 剪辑的损失函数:
    。也就是说,如果有剪辑,通过 torch.max(surr1, surr2) 选择让 loss 最大化的更新策略;
  • loss 最终乘以 0.5,也算是平方误差损失函数中常见的缩放因子了。

这里我之前被一个地方绊住过,可以分享一下我曾经的疑惑点:clamp 的意义既然是防止模型进行较大的参数更新,那为什么 value function 的 loss 还要选 torch.max() 呢,不应该是 torch.min() 更合理吗?

我目前的观点:在策略函数中,模型更新幅度的大与小,和 loss 的大小并无直接关系。新的 values 距离 old_values 越近,代表着价值估计的目标更新幅度越小。显然,old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps) 对应的 values_clipped ,是一定比原始的 values 更接近 old_values 的。

surr1,surr2 分别代表使用 values 和 values_clipped 进行模型更新的 loss:

  • surr1 > surr2:说明 clip 的过分了,导致 loss 变小可能会更新不动,那就放弃 clip,选择 values 来更新;
  • surr1 < surr2:说明用更保守的更新策略 values_clipped,得到了更大的 loss。模型期望的更新幅度小,训练动力还大,没有比这更好的事情了。

PairWiseLoss

classPairWiseLoss(nn.Module):
"""
PairwiseLossforRewardModel
"""

defforward(
self,chosen_reward:torch.Tensor,reject_reward:torch.Tensor,margin:torch.Tensor=None
)->torch.Tensor:
ifmarginisnotNone:
loss=-F.logsigmoid(chosen_reward-reject_reward-margin)
else:
loss=-F.logsigmoid(chosen_reward-reject_reward)
returnloss.mean()

主角登场,reward_model 的 loss 函数,以最简单的形式干最多的活!

这里有意思的点是 OpenRLHF 提供了一个 margin 的选项。还记得文章开头给大家画出来的 求导后的曲线吗?x 越小,梯度的绝对值越大,就越能避免梯度消失。原本 positive_reward - negative_reward = 2 的时候,已经没梯度训不动了,现在 positive_reward - negative_reward = 2 + margin 的时候才会训不动。

这个 margin 和 dpo 的 refernece_model 非常类似,都是常量。我曾经疑惑过这种常量是不是没啥大用,后来动手求了求导就明白了:这些被 logsigmoid() 包裹起来的常量,会影响梯度的大小,决定梯度在什么情况下趋近于零,进而也会影响模型训练的动力。

LogExpLoss

classLogExpLoss(nn.Module):
"""
PairwiseLossforRewardModel
Details:https://arxiv.org/abs/2204.05862
"""

defforward(
self,chosen_reward:torch.Tensor,reject_reward:torch.Tensor,margin:torch.Tensor=None
)->torch.Tensor:
loss=torch.log(1+torch.exp(reject_reward-chosen_reward)).mean()
returnloss

想不到吧,reward_model 也有变种。

看上去是用 取代了 ,这妥妥的就是一个等价变化啊。

PRMLoss

classPRMLoss(nn.Module):
"""
ProcessRewardModelLoss
"""

def__init__(self,placeholder_token_id:int,reward_token_ids:Optional[list[int]]=None):
super().__init__()
self.IGNORE_INDEX=-100
self.loss=nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
self.placeholder_token_id=placeholder_token_id
self.reward_token_ids=reward_token_ids

defforward(self,inputs:torch.Tensor,logits:torch.Tensor,labels:torch.Tensor,*,return_acc:bool=False):
placeholder_mask=inputs==self.placeholder_token_id
logits=logits[placeholder_mask]
labels=labels[placeholder_mask]

iflabels.dtype==torch.float:
#softlabel
assertlen(self.reward_token_ids)==2,"reward_token_idsshouldhave2tokensforsoftlabels"
logits=logits[...,self.reward_token_ids]
positive_labels=labels.to(logits.dtype)
negative_labels=1-positive_labels
negative_labels[positive_labels!=-100]=1-positive_labels[positive_labels!=-100]
labels=torch.stack([positive_labels,negative_labels],dim=-1)
elifself.reward_token_idsisnotNone:
#hardlabelwithreward_token_idsset.(otherwisethewholevocabwillbetrainedtogether.)
logits=logits[...,self.reward_token_ids]
#thisisslow....
fori,tokeninenumerate(self.reward_token_ids):
labels=torch.where(labels==token,i,labels)

loss=self.loss(logits,labels)
ifnotreturn_acc:
returnloss

iflabels.dtype==logits.dtype:
labels=labels.argmax(dim=-1)
acc=(logits.argmax(dim=-1)==labels).float().mean()
returnloss,acc

当红炸子鸡,学界认为的通往 o1 的钥匙,process_reward_model 的 loss 函数。出于对 o1 的尊重,我逐行解读一下这个 loss 。

首先看下 PRM 训练集合的样子:

{
"inputs":"Janetpays$40/hourfor3hoursperweekofclarinetlessonsand$28/hourfor5hoursaweekofpianolessons.Howmuchmoredoesshespendonpianolessonsthanclarinetlessonsinayear?Step1:Janetspends3hours+5hours=<<3+5=8>>8hoursperweekonmusiclessons.киStep2:Shespends40*3=<<40*3=120>>120onclarinetlessonsperweek.киStep3:Shespends28*5=<<28*5=140>>140onpianolessonsperweek.киStep4:Janetspends120+140=<<120+140=260>>260onmusiclessonsperweek.киStep5:Shespends260*52=<<260*52=13520>>13520onmusiclessonsinayear.Theansweris:13520ки",
"labels":"Janetpays$40/hourfor3hoursperweekofclarinetlessonsand$28/hourfor5hoursaweekofpianolessons.Howmuchmoredoesshespendonpianolessonsthanclarinetlessonsinayear?Step1:Janetspends3hours+5hours=<<3+5=8>>8hoursperweekonmusiclessons.+Step2:Shespends40*3=<<40*3=120>>120onclarinetlessonsperweek.+Step3:Shespends28*5=<<28*5=140>>140onpianolessonsperweek.+Step4:Janetspends120+140=<<120+140=260>>260onmusiclessonsperweek.+Step5:Shespends260*52=<<260*52=13520>>13520onmusiclessonsinayear.Theansweris:13520-",
"values":["+","+","+","+","-"]
}
  • 在 inputs 中,每个 step 后面会有一个 special_token:ки
  • 在 labels 中,每个 step 后面会有一个 label_token:+ / - (代表着当前 step 的推理是否正确)
placeholder_mask=inputs==self.placeholder_token_id
logits=logits[placeholder_mask]
labels=labels[placeholder_mask]

logits 就是整个 inputs 过了一遍 llm 后得到的输出,形状为 seq_len * vocab_size (不考虑 batch_size),self.placeholder_token_id 就是 “ки” 对应的 id。使用这几行代码,上面的 case 中,logits 会变成 5 * vocab_size, label 会变成 5 * 1

logits=logits[...,self.reward_token_ids]
fori,tokeninenumerate(self.reward_token_ids):
labels=torch.where(labels==token,i,labels)

紧接着,先理解常规的 hard label,self.reward_token_ids 就是["+"对应的 id, "-"对应的 id],labels 就是["+"对应的 id, "+"对应的 id, "+"对应的 id, "+"对应的 id , "-"对应的 id]。这几行代码成功提取出了每个 step 下,两个 label 各自对应的 logits,以及每个 step 的 label 是什么。

iflabels.dtype==torch.float:
logits=logits[...,self.reward_token_ids]
assertlen(self.reward_token_ids)==2,"reward_token_idsshouldhave2tokensforsoftlabels"
positive_labels=labels.to(logits.dtype)
negative_labels=1-positive_labels
negative_labels[positive_labels!=-100]=1-positive_labels[positive_labels!=-100]
labels=torch.stack([positive_labels,negative_labels],dim=-1)

再理解非常规的 soft_label,此时 labels 不再是 id,而是 float 类型,比如 labels = [0.8, 0.85, 0.9, 0.78, 0.1],代表着每个 step 正确的概率( assert len(self.reward_token_ids) == 2 是为了确保可以通过减法算出 step 错误的概率)。

理解了 soft_label 和 hard_label 分别是如何获得的之后,后面的 loss 计算和 acc 计算就没什么好多说的了。哦对,如果用 qwen 去跑这份代码,会遇到一个 tokenizer 的 bug,请教了下代码作者朱小霖大佬,大概是说不想为了 math-shepherd 加太多冗余逻辑,后续会给出一个优化版的代码(期待ing)

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ