返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

深入解析LLM预训练与SFT对齐:Loss函数差异与代码解析

[复制链接]
链载Ai 显示全部楼层 发表于 前天 11:58 |阅读模式 打印 上一主题 下一主题


LLM(Large Language Model,大型语言模型)在预训练和对齐阶段,虽然都使用loss函数来指导模型学习,但两者在loss的设计和目标上存在显著差异。

1. 预训练阶段:

  • 目标: 学习语言的通用表示,掌握语法、语义、知识等。
  • 数据: 海量、未标注的文本数据,例如书籍、网页、代码等。
  • Loss函数: 通常使用自监督学习方法,例如:
    • Masked Language Modeling (MLM): 掩盖句子中部分词语,让模型预测被掩盖的词语。
    • Causal Language Modeling (CLM): 根据前面的词语预测下一个词语。
  • Loss特点:
    • 关注模型对语言结构和知识的理解。
    • 数值较大,因为模型需要学习大量信息。
    • 随着训练的进行,loss逐渐下降,表示模型对语言的理解能力不断提升。
  • 预训练Loss代码:
    • transformers库中的源代码,包含在trainer中的compute_loss,会在预估的prediction_step和training_step函数中被调用,实现的源代码在LabelSmoother类中,具体实现如下:
544@dataclass
545classLabelSmoother:
546"""
547Addslabel-smoothingonapre-computedoutputfromaTransformersmodel.
548
549Args:
550epsilon(`float`,*optional*,defaultsto0.1):
551Thelabelsmoothingfactor.
552ignore_index(`int`,*optional*,defaultsto-100):
553Theindexinthelabelstoignorewhencomputingtheloss.
554"""
555
556epsilon:float=0.1
557ignore_index:int=-100
558
559def__call__(self,model_output,labels,shift_labels=False):
560logits=model_output["logits"]ifisinstance(model_output,dict)elsemodel_output[0]
561ifshift_labels:
562logits=logits[...,:-1,:].contiguous()
563labels=labels[...,1:].contiguous()
564
565log_probs=-nn.functional.log_softmax(logits,dim=-1)
566iflabels.dim()==log_probs.dim()-1:
567labels=labels.unsqueeze(-1)
568
569padding_mask=labels.eq(self.ignore_index)
570#Incasetheignore_indexis-100,thegatherwillfail,sowereplacelabelsby0.Thepadding_mask
571#willignoretheminanycase.
572labels=torch.clamp(labels,min=0)
573nll_loss=log_probs.gather(dim=-1,index=labels)
574#worksforfp16inputtensortoo,byinternallyupcastingittofp32
575smoothed_loss=log_probs.sum(dim=-1,keepdim=True,dtype=torch.float32)
576
577nll_loss.masked_fill_(padding_mask,0.0)
578smoothed_loss.masked_fill_(padding_mask,0.0)
579
580#Takethemeanoverthelabeldimensions,thendividebythenumberofactiveelements(i.e.not-padded):
581num_active_elements=padding_mask.numel()-padding_mask.long().sum()
582nll_loss=nll_loss.sum()/num_active_elements
583smoothed_loss=smoothed_loss.sum()/(num_active_elements*log_probs.shape[-1])
584return(1-self.epsilon)*nll_loss+self.epsilon*smoothed_loss
  • shift_labels:是否需要位移计算,logits = logits[..., :-1, :]预估值,从第0个到倒数第二个,labels = labels[..., 1:]为label,原始文本,从第1个到结束,label中的第0个为输入,预估结果从第一开始,计算loss也是。
  • log_probs:softmax函数的计算,转换为概率分布。
  • padding_mask:padding_mask = labels.eq(self.ignore_index),label中特殊标记token(pad_token_id)为padding,可以计算出padding_mask来,刨除padding_mask外的,参加loss计算。
  • nll_loss:核心计算nll_loss = log_probs.gather(dim=-1, index=labels),log_probs是一个形状为 (batch_size, sequence_length, vocab_size) 的张量,表示模型对每个词的预测概率的对数。
    • batch_size:批处理大小,即一次处理多少个样本。
    • sequence_length:序列长度,即句子中有多少个词。
    • vocab_size:词汇表大小,即模型认识多少个不同的词。
    • 例如,log_probs[0, 2, 500] 表示模型预测第一个样本中第三个词是词汇表中第500个词的概率的对数,从log_probs到nll_loss,主要做了平滑和去掉padding。
    • labels:形状为 (batch_size, sequence_length) 的张量,每个词的真实标签 (ground truth),例如,labels[0, 2] 表示第一个样本中第三个词的真实标签。
    • gather(dim=-1, index=labels):从log_probs获取labels对应位置的值;dim=-1:表示在最后一个维度(即 vocab_size 维度)上进行操作;index=labels:使用 labels 张量作为索引来获取log_probs 中的值。


2. 对齐阶段 (SFT: Supervised Fine-Tuning):

  • 目标: 将预训练模型的能力迁移到特定任务,例如对话生成、文本摘要、机器翻译、LLM落地到垂类业务场景等。
  • 数据: 针对特定任务的标注数据,例如对话记录、摘要文章、翻译文本、垂类业务数据等。
  • Loss函数: 通常使用监督学习方法,根据具体任务选择合适的loss函数,例如:
    • Cross-Entropy Loss: 用于分类任务,例如情感分析、意图识别等。
    • Mean Squared Error (MSE) Loss: 用于回归任务,例如文本评分、机器翻译质量评估等。
  • Loss特点:
    • 关注模型在特定任务上的表现。
    • 数值相对较小,因为模型只需要微调预训练的参数。
    • 随着训练的进行,loss逐渐下降,表示模型在特定任务上的表现不断提升。
  • 对齐sft Loss代码:
    • 对齐sft loss对不同的训练框架实现稍微有些区别,但本质都是一样的,都会先对prompt部分剔除或者mask掉,然后调用预训练transormers库的loss计算,以开源框架LLaMA-Factory中的sft进行解读。
    • LLaMA-Factory中的sft loss计算代码在train/sft/trainer.py中,具体实现在prediction_step函数中,详细如下:
81@override
82defprediction_step(
83self,
84model:"torch.nn.Module",
85inputsict[str,Union["torch.Tensor",Any]],
86prediction_loss_only:bool,
87ignore_keys:Optional[List[str]]=None,
88)->Tuple[Optional[float],Optional["torch.Tensor"],Optional["torch.Tensor"]]:
89r"""
90Removesthepromptpartinthegeneratedtokens.
91
92Subclassandoverridetoinjectcustombehavior.
93"""
94labels=inputs["labels"]if"labels"ininputselseNone
95ifself.args.predict_with_generate:
96assertself.tokenizer.padding_side=="left","Thismethodonlyacceptsleft-paddedtensor."
97labels=labels.detach().clone()iflabelsisnotNoneelseNone#backuplabels
98prompt_len,label_len=inputs["input_ids"].size(-1),inputs["labels"].size(-1)
99ifprompt_len>label_len:
100inputs["labels"]=self._pad_tensors_to_target_len(inputs["labels"],inputs["input_ids"])
101iflabel_len>prompt_len:#truncatethelabelsinsteadofpaddingtheinputs(llama2fp16compatibility)
102inputs["labels"]=inputs["labels"][:,:prompt_len]
103
104loss,generated_tokens,_=super().prediction_step(#ignorethereturnedlabels(maybetruncated)
105model,inputs,prediction_loss_only=prediction_loss_only,ignore_keys=ignore_keys
106)
107ifgenerated_tokensisnotNoneandself.args.predict_with_generate:
108generated_tokens[:,:prompt_len]=self.tokenizer.pad_token_id
109generated_tokens=generated_tokens.contiguous()
110
111returnloss,generated_tokens,labels
112
113def_pad_tensors_to_target_len(self,src_tensor:"torch.Tensor",tgt_tensor:"torch.Tensor")->"torch.Tensor":
114r"""
115Padsthetensortothesamelengthasthetargettensor.
116"""
117assertself.tokenizer.pad_token_idisnotNone,"adtokenisrequired."
118padded_tensor=self.tokenizer.pad_token_id*torch.ones_like(tgt_tensor)
119padded_tensor[:,-src_tensor.shape[-1]:]=src_tensor#adoptleft-padding
120returnpadded_tensor.contiguous()#incontiguousmemory
  • 部分参数含义:因为计算sft计算loss时,prompt部分不参与,需要从labels中刨除掉,使用pad_token_id特殊token掩盖掉prompt
    • padding_side:padding_side == "left" assert为左边padding,否则报错(需要可以自行修改,但需要把一些padding的逻辑都一起改了)。
    • if prompt_len > label_len则在prompt 张量中将prompt部分用pad_token_id特殊token mask掉,具体实现在_pad_tensors_to_target_len中,if label_len > prompt_len 同理。
    • mask掉prompt部分后,调用transformers库基类Trainer的prediction_step,即可计算出输出部分的loss(详细代码看预训练部分loss)。

总结:

特征预训练对齐 (SFT)
目标学习通用语言表示迁移到特定任务
数据海量未标注数据高质量标注数据
Loss函数自监督学习 (MLM, CLM)监督学习 (Cross-Entropy, MSE)
Loss特点数值较大,关注语言理解数值较小,关注任务表现

需要注意的是,以上只是一些常见的区别,实际情况可能更加复杂。例如,有些预训练任务也会使用少量标注数据,而有些对齐任务也会使用自监督学习方法。

总的来说,预训练和对齐阶段的loss函数设计都至关重要,它们共同决定了LLM最终的性能。


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ