最近wsdm cup到了瓶颈,租卡跑算力成本太高,而lmsys比赛的微调结果也没啥可抄的了,所以只能回头看看top方案,研究了一下阳哥的《Distill is all you need》,和第二名tascj对于训练推理的科技与狠活,有些感觉,伴随着deepseek的大火,蒸馏和强化学习又被端上了台面,我对强化学习暂时没什么兴趣,不过蒸馏跟我最近看的内容相关,在网上搜了一圈关于deepseek针对蒸馏的策略,好像没有过多内容介绍,于是想着总结找到的一些资料。
什么是模型蒸馏? 模型蒸馏即知识蒸馏(Knowledge Distillation),是一种模型压缩和加速技术。在深度学习中,大型深度神经网络虽性能优异,但因计算复杂度高、存储需求大,难以部署在资源受限设备上。模型蒸馏通过构建师生架构,让小的学生模型学习大的教师模型的知识,使学生模型在保持较小规模的同时,尽可能接近教师模型的性能。其核心组件包括知识(如教师模型的 logits、中间层特征等)、蒸馏算法(用于指导知识转移)和师生架构(决定知识传递方式)。
这里可以看比较主流的一张图,出自2021年综述:《Knowledge Distillation: A Survey》,对近年的Distillation做了一个详细概括,Knowledge Distillation的流程可以理解为:
图中除了loss之后会详细说明,唯一的未知点可能在于soft targets,它是经过softmax的下一层级结果logits(原始分数),公式为:
其中是温度系数,从公式中能很明显看出当值较大时,Softmax 输出的概率分布会更加平滑,每个类别的概率值相对更接近;值较小时,概率分布会更尖锐,高概率类别的概率值远高于其他类别。这些 soft targets 会传递给学生模型,学生模型在学习过程中不仅学习真实的hard targets信息,还能从教师模型的 soft targets 中获取类别之间的关联等知识,帮助其更好地训练和泛化。
hard targets 与 soft targets的区别可以从下面的四分类图中很形象的看出:
知识蒸馏有什么意义 实现模型压缩与加速: 模型蒸馏能有效压缩模型大小、降低计算复杂度,提升推理速度。如在论文研究中,通过知识蒸馏将大模型知识转移到小模型,在 CIFAR10 和 CIFAR100 数据集上进行实验,结果表明可实现不同深度模型的压缩,使轻量级学生模型在保持较高准确率的同时,显著减少模型参数和计算量,满足在资源受限设备上的部署需求 。提升模型性能: 帮助学生模型学习到教师模型的有用知识,提高自身性能。在视觉识别、自然语言处理、语音识别等多个领域的研究中发现,知识蒸馏可提升模型在复杂任务中的表现。例如在自然语言处理中,对BERT 模型进行知识蒸馏得到的轻量级模型,在保持较高准确率的同时,推理速度大幅提升,能够高效完成多种语言任务 。解决数据相关问题: 在数据稀缺、存在隐私问题或数据难以获取时,模型蒸馏有独特优势。数据无关蒸馏方法可利用教师模型生成合成数据训练学生模型,避免对大量真实数据的依赖。在涉及敏感数据的场景中,多教师蒸馏可让多个教师模型分别处理不同子集数据,监督学生模型训练,既能保护数据隐私,又能完成模型训练。促进跨领域与跨模态学习: 跨模态蒸馏可实现不同模态间的知识转移,帮助模型更好地处理多模态数据。在一些研究中,将 RGB 图像模态的知识转移到深度图像模态,使模型在不同模态下都能取得较好的性能,拓宽了模型的应用范围。助力终身学习与持续优化: 与终身学习结合,模型蒸馏可帮助模型在新任务学习中保留旧知识,避免灾难性遗忘。在不断出现新数据和新任务的场景下,通过知识蒸馏将已有知识传递给新模型,使模型能够持续学习和优化,提升其适应性和泛化能力。如何做知识蒸馏 做知识蒸馏的方式有非常多,从训练方案流程 来看,就有离线蒸馏、在线蒸馏和自蒸馏 等,从算法更新 角度上,还有对抗蒸馏、多教师蒸馏 等,这里我就不用豆包在灌水了,想查一大片说明,直接以bert时代的蒸馏开始看。
unsetunsettinybertunsetunset TinyBERT是一种轻量级的预训练语言模型,由华为和华中科技大学提出。它通过知识蒸馏技术,将BERT模型的知识迁移到一个更小的模型中,从而实现了模型体积的大幅减小和推理速度的提升。在当时,它提出了两阶段transformer蒸馏 方案:在大规模语料上首先进行通用MLM任务的蒸馏 ,在下游任务 时,先学好老师模型,再进行蒸馏 ,具体如下图:
关于Transformer层蒸馏,主要包括注意力attn的蒸馏和隐藏层hidn的蒸馏:
关于损失函数,TinyBert的蒸馏loss为:
# 蒸馏配置 distill_config = DistillationConfig( # 设置温度系数temperature, tiny-bert论文作者使用1表现最好,一般大于1比较好 temperature=self.temperature, # 设置ground truth loss权重 hard_label_weight=self.hard_label_weight, # 设置预测层蒸馏loss(即soft label损失)为交叉熵,并稍微放大其权重 kd_loss_type=self.kd_loss_type, kd_loss_weight=self.kd_loss_weight, # 配置中间层蒸馏映射 intermediate_matches=[ # 配置hidden蒸馏映射、维度映射 {'layer_T':0,'layer_S':0,'feature':'hidden','loss':'hidden_mse','weight':1, 'proj': ['linear',312,768]}, # embedding层输出 {'layer_T':3,'layer_S':1,'feature':'hidden','loss':'hidden_mse','weight':1, 'proj': ['linear',312,768]}, {'layer_T':6,'layer_S':2,'feature':'hidden','loss':'hidden_mse','weight':1, 'proj': ['linear',312,768]}, {'layer_T':9,'layer_S':3,'feature':'hidden','loss':'hidden_mse','weight':1, 'proj': ['linear',312,768]}, {'layer_T':12,'layer_S':4,'feature':'hidden','loss':'hidden_mse','weight':1, 'proj': ['linear',312,768]}, # 配置attention矩阵蒸馏映射,注意layer序号从0开始 {"layer_T":2,"layer_S":0,"feature":"attention","loss":"attention_mse","weight":1}, {"layer_T":5,"layer_S":1,"feature":"attention","loss":"attention_mse","weight":1}, {"layer_T":8,"layer_S":2,"feature":"attention","loss":"attention_mse","weight":1}, {"layer_T":11,"layer_S":3,"feature":"attention","loss":"attention_mse","weight":1}, ] ) # 训练配置 optimizer = AdamW(self.student_model.parameters(), lr=self.lr) # 使用大一点的lr train_config = TrainingConfig( output_dir=self.student_model_dir, device=self.student_trainer.device, data_parallel=self.enable_parallel, ckpt_frequency=self.ckpt_frequency # 一个epoch存ckpt_frequency次模型 ) # 配置model中logits hiddens attentions losses的获取方法 defsimple_adaptor(batch, model_outputs): return{ 'logits': model_outputs[-1]['logits'],'hidden': model_outputs[-1]['hiddens'], 'attention': model_outputs[-1]['attentions'],'losses': model_outputs[1], } # 蒸馏 distiller = GeneralDistiller( train_config=train_config, distill_config=distill_config, model_T=self.teacher_model, model_S=self.student_model, adaptor_T=simple_adaptor, adaptor_S=simple_adaptor ) withdistiller: logger.info('start to knowledge distill ...') distiller.train(optimizer, train_dataloader, num_epochs=epoch) logger.info('distill finish')unsetunsetKL散度(**Kullback-Leibler divergence**)unsetunset KL散度的定义是建立在熵(Entropy)的基础上的。此处以离散随机变量为例,若一个离散随机变量的可能取值为,而对应的概率为,则随机变量的熵定义为:
若有两个随机变量,且其概率分布分别为,则相对的相对摘为:
之所以称之为相对熵,是因为其可以通过两随机变量的交叉嫡(Cross-Entropy)以及信息摘推导得到,针对上述离散变量的概率分布而言,其交叉摘定义为:
因此,KL散度或相对熵可通过下式得出:
在上一节中,TinyBERT在设计其蒸馏过程时采用了多种损失函数,包括词向量层损失、中间层损失和预测层损失,在大模型时代下,词向量损失不用多说,因为已经完全做了解耦,如何进行embedding我想看到这里的都知道,中间层损失的不再使用,或者说中间层蒸馏的使用变少,我理解是大模型通常已经具有足够的参数来学习复杂的特征表示,因此它的必要性相对较低,另外就是中间层叠得太厚,所能获得的收益太低,所以不如针对预测层进行相应的改进,那自然,就不得不提本节在介绍的KL散度。
那为什么作为大模型来讲,更多使用KL散度呢?我觉得可以从以下三点考虑:
知识蒸馏的需求 :大模型在进行知识蒸馏时,需要将教师模型的知识传递给学生模型。KL散度能够衡量两个概率分布之间的差异,适合用于衡量教师模型和学生模型之间的输出分布差异。通过最小化KL散度,可以使得学生模型的输出分布尽可能接近教师模型的输出分布。考虑分布的整体差异 :KL散度不仅考虑了预测分布与真实分布之间的交叉熵,还考虑了真实分布的熵。这使得KL散度能够更全面地衡量两个分布之间的差异,适合用于大模型这种需要精细调整输出分布的场景。优化目标的一致性 :在知识蒸馏中,优化KL散度等价于优化交叉熵。但是,KL散度在某些情况下能够提供更稳定的优化目标,尤其是在教师模型和学生模型的输出分布差异较大时。unsetunset前向KL(forward)和后向KL(reverse )unsetunset 上述介绍了KL散度的定义,很明显,KL损失不是一个对称形式,即,那么我们可以试图用近似分布来优化该目标:
Minimizing the forward KL: Minimizing the reverse KL: 根据上一小节的概率公式推导,可以计算出反向 (Reverse KL,RKL )为:
正向 (Forward KL,FKL )为:
其中P是teacher,Q是student,在大模型之前,似乎很多人更喜欢用FKL,正向KL散度(FKL)更受青睐的原因可能与其在传统任务上的表现有关。传统分类任务的输出空间相对较小,模式(即分布的峰值)较少,这意味着分布更倾向于单一峰值而非多峰值分布。在这种情况下,FKL表现良好,因为它倾向于让学生模型关注教师模型输出中概率较高的区域,从而产生更准确的样本。然而,对于大型语言模型(LLM)来说,输出空间更加复杂,模式更多,再使用FKL可能导致学生模型关注教师模型输出中概率较低的区域,从而产生不良样本。
如上图所示,教师模型是蓝色曲线,它的输出是可量化的,这里假设为两个高斯波峰,而黄色,是理想情况下,我们认为学生模型可以近似为正态分布来拟合教师曲线,那么会出现两种结果,一种是尽可能多的包括多峰的面积,第二种是直接拟合最高波峰的分布。所以左边是Forward KL,右边是反向。
中间的一些具体推导过程不过多赘述,近年有非常多的论文对该方案做了benchmark,比如说下图是《f-Divergence Minimization for Sequence-Level Knowledge Distillation》一文的数据:
还有《Rethinking Kullback-Leibler Divergence in Knowledge Distillation for Large Language Models》篇的数据和AKL:
另外说明一下,本节内容就是看了作者在知乎发的《LLM的知识蒸馏(KD)应该用Reverse KL?》一文才有想法撰写本节,对于想复现的小伙伴来讲,可以去看这几篇论文的github,作者还给了一些相应的可视化demo。
unsetunsettrl中的知识蒸馏unsetunset TRL(Transformer Reinforcement Learning)库是用于后续训练基础模型的综合库,专为使用监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术进行训练后的基础模型而设计。这里我们只看它里面的两种trainer——SFTtrainer和GKDtrainer。
从原理方面来讲:
SFTTrainer: SFTTrainer 即监督微调训练器,主要是对预训练语言模型进行有监督的微调。它利用给定的输入输出对数据,通过最小化模型输出与真实标签之间的损失,让模型学习到特定任务的模式,将预训练模型适配到具体的下游任务。GKDTrainer: GKDTrainer 是用于知识蒸馏的一种训练器,基于知识蒸馏原理,利用教师模型的知识来指导学生模型的训练,使学生模型学习到教师模型的知识,比如输出分布、特征表示等,以提高学生模型的性能。从损失计算方面来讲:
SFTTrainer: 通常计算模型输出与真实标签之间的交叉熵损失等,衡量模型预测结果与实际标注的差异,通过反向传播来更新模型参数,使模型输出尽可能接近真实标签。GKDTrainer: 主要计算学生模型与教师模型输出之间的散度,如 Jensen - Shannon Divergence(JSD)、Kullback - Leibler Divergence(KLD)等,让学生模型学习教师模型的输出分布等知识。这两种顺序非常直观,GKDTrainer继承自SFTTrainer,SFTTrainer继承自Trainer。那从SFTtrainer看,它的调用非常简单,trl的readme直接写了一个demo:
fromtrlimportSFTConfig, SFTTrainer fromdatasetsimportload_dataset dataset = load_dataset("trl-lib/Capybara", split="train") training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT") trainer = SFTTrainer( args=training_args, model="Qwen/Qwen2.5-0.5B", train_dataset=dataset, ) trainer.train()调用该类后,我又去看了下transformers的trainer,它的损失函数为:
defcompute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """ How the loss is computed by Trainer. By default, all models return the loss in the first element. Subclass and override for custom behavior. """ if(self.label_smootherisnotNoneorself.compute_loss_funcisnotNone)and"labels"ininputs: labels = inputs.pop("labels") else: labels =None ifself.model_accepts_loss_kwargs: loss_kwargs = {} ifnum_items_in_batchisnotNone: loss_kwargs["num_items_in_batch"] = num_items_in_batch inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) # Save past state if it exists #TODO:this needs to be fixed and made cleaner later. ifself.args.past_index >=0: self._past = outputs[self.args.past_index] iflabelsisnotNone: unwrapped_model = self.accelerator.unwrap_model(model) if_is_peft_model(unwrapped_model): model_name = unwrapped_model.base_model.model._get_name() else: model_name = unwrapped_model._get_name() # User-defined compute_loss function ifself.compute_loss_funcisnotNone: loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch) elifmodel_nameinMODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): loss = self.label_smoother(outputs, labels, shift_labels=True) else: loss = self.label_smoother(outputs, labels) else: ifisinstance(outputs, dict)and"loss"notinoutputs: raiseValueError( "The model did not return a loss from the inputs, only the following keys: " f"{','.join(outputs.keys())}. For reference, the inputs it received are{','.join(inputs.keys())}." ) # We don't use .loss here since the model may return tuples instead of ModelOutput. loss = outputs["loss"]ifisinstance(outputs, dict)elseoutputs[0] ifself.args.average_tokens_across_devicesandself.model_accepts_loss_kwargs: loss *= self.accelerator.num_processes return(loss, outputs)ifreturn_outputselseloss很显然这部分有非常多的自适应判断,根据我们上一层为SFTtrainer类,并且没有指定loss方法,所以将选用cross-entropy loss作为模型训练参数。
而GKDtrainer类的方式就不一样,由于KL散度是不对称的,在知识蒸馏中使用JSD,Jensen-Shannon Divergence 是基于KL散度改进的更平滑和对称的概率分布度量。论文中给出了其改进的计算公式:
那自然其重写了compute_loss,具体计算为generalized_jsd_loss,代码如下:
defgeneralized_jsd_loss( student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean" ): """ Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1) of https://huggingface.co/papers/2306.13649 for the definition. Args: student_logits: Tensor of shape (batch_size, sequence_length, vocab_size) teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size) labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss beta: Interpolation coefficient between 0 and 1 (default: 0.5) temperature: Softmax temperature (default: 1.0) reduction: Specifies the reduction to apply to the output (default: 'batchmean') Returns: loss: Scalar tensor with the generalized JSD loss """ # Apply temperature scaling student_logits = student_logits / temperature teacher_logits = teacher_logits / temperature # Compute log probabilities for student and probabilities for teacher student_log_probs = F.log_softmax(student_logits, dim=-1) teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) # Compute the log of the mixture distribution # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture beta = torch.tensor(beta, dtype=student_log_probs.dtype) mixture_log_probs = torch.logsumexp( torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1- beta)]), dim=0, ) # Compute KL divergences using F.kl_div # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper. kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) # Compute the Generalized Jensen-Shannon Divergence jsd = beta * kl_teacher + (1- beta) * kl_student # Masking iflabelsisnotNone: mask = labels !=-100 jsd = jsd[mask] # Apply reduction ifreduction =="batchmean": returnjsd.sum() / mask.sum()iflabelsisnotNoneelsejsd.sum() / (jsd.size(0) * jsd.size(1)) elifreduction =="sum": returnjsd.sum() elifreduction =="mean": returnjsd.mean() else: returnjsd defcompute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # compute student output outputs_student = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) # compute teacher output in eval mode self.teacher_model.eval() withtorch.no_grad(): outputs_teacher = self.teacher_model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) # slice the logits for the generated tokens using the inputs["prompts"] lengths prompt_lengths = inputs["prompts"].shape[1] shifted_student_logits = outputs_student.logits[:, prompt_lengths -1:-1, :] shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths -1:-1, :] shifted_labels = inputs["labels"][:, prompt_lengths:] # compute loss loss = self.generalized_jsd_loss( student_logits=shifted_student_logits, teacher_logits=shifted_teacher_logits, labels=shifted_labels, beta=self.beta, ) # empty cache empty_cache() # Return loss return(loss, outputs_student)ifreturn_outputselseloss对于该类好不好用,我也不知道,暂时没用过,只能说从理论来分析,JSD损失和KL损失的区别,不过与SFTtrainer类似,调用方式也很简单,可以跑几次看看情况:
fromdatasetsimportload_dataset importrandom fromtransformersimportAutoTokenizer fromtrlimport( GKDConfig, GKDTrainer, LogCompletionsCallback, ModelConfig, ScriptArguments, TrlParser, get_kbit_device_map, get_peft_config, get_quantization_config, ) ################ # Training ################ trainer = GKDTrainer( model=model_config.model_name_or_path, teacher_model=training_args.teacher_model_name_or_path, args=training_args, train_dataset=dataset[args.dataset_train_split], eval_dataset=test_data, processing_class=tokenizer, peft_config=get_peft_config(model_config), ) completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8) trainer.add_callback(completions_callback) trainer.train() # Save trainer.save_model(training_args.output_dir)lmsys方案思考 本节是对阳哥夺冠方案中关于蒸馏部分的经典总结,在这里做一个旁征博引,因为没有算力,具体我也没复现过,不过算是除了写这篇推文的初衷,本来是想做一个top方案亮点汇总,只是因为deepseek的爆火针对其中一个方向做了延展。那话不多说,github原址为:https://github.com/shyoulala/LMSYS_BlackPearl
该仓库的目录结构为:
./model_path # 预训练模型的路径,存放预训练模型的权重和配置文件 ./src_fast # 快速训练脚本的存放位置,可能包含简化的训练代码 ./src # 完整解决方案的代码目录,包含整个项目的完整训练和处理流程 ./data # 数据目录,存放训练数据和其他相关数据 ./data/oof # Out-of-Fold 数据目录,可能用于交叉验证的中间结果 ./data/processed_data # 处理后的数据目录,存放经过预处理的数据 ./data/processed_data/orgemma2fold4 # 训练集,包含用于直接蒸馏的 70b 概率数据(第4折) ./data/processed_data/orgemma2fold2 # 同上,第2折 ./data/processed_data/orgemma2fold0 # 同上,第0折 ./data/processed_data/orgemma2fold1 # 同上,第1折 ./data/processed_data/orgemma2fold3 # 同上,第3折 ./data/lmsys-chatbot-arena # 可能存放与 LMSYS Chatbot Arena 相关的数据或资源 ./sub # 输出目录,用于存放训练结果、预测结果等 ./model_save # 训练模型的保存路径,存放训练完成后的模型文件 ./model_save_or # 另一个模型保存路径,可能是用于存放原始模型或特定版本的模型 ./model_save_or/v7_ut_gemma_v7_64r128_ddgemma2_16bit # 经过后处理(如蒸馏)的模型版本,可能是 Gemma2-9B 的 16bit 版本挺难想象的,大模型时代竟然还能做交叉验证,不过lmsys是个三分类任务,依照之前逻辑也没什么问题,该方案主要是用llama3-70B和Qwen2-72B-instruct对gamma2-9B做蒸馏,所有大致流程,都通过run_pipeline.sh有显现:
#!/bin/bash set -e qwen_path=../model_path/qwen2_72b llama_path=../model_path/llama3_70b gemma_path=../model_path/Gemma2_9b qwen_path_ut=../model_save/qwen2_4bit_pretrain/epoch_0_model/adapter.bin llama_path_ut=../model_save/llama3_4bit_pretrain/epoch_0_model/adapter.bin gemma_path_ut=../model_save/gemma2_4bit_pretrain/epoch_0_model/adapter.bin fold=$1 echo run {fold} # train llama3 70b sh run_fintune.sh llama3 ${llama_path} ${llama_path_ut} ${fold} # predict train logits python predict_train.py ${llama_path} ../model_save/llama3_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/llama3fold${fold}/train.parquet ../data/oof/llama3fold${fold}_train.parquet # train qwen2 70b sh run_fintune.sh qwen2 ${qwen_path} ${qwen_path_ut} ${fold} # predict train logits python predict_train.py ${qwen_path} ../model_save/qwen2_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/qwen2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet # merge logits python merge_logits.py ../data/processed_data/gemma2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet ../data/oof/llama3fold${fold}_train.parquet ../data/processed_data/gemma2fold${fold}/train_logits.parquet # distill fintune gemma2-9b sh run_fintune_16bit_distill.sh gemma2 ${gemma_path} ${gemma_path_ut} ${fold}中间几步有挺多有趣的操作,比如是如何做post train的,以及最后merge logits,这里仅谈蒸馏之前的merge lora,因为代码足够简单:
importtime fromdataclassesimportdataclass importpickle importtorch importsklearn importnumpyasnp importpandasaspd fromtqdm.autoimporttqdm fromtransformersimportGemma2ForSequenceClassification, GemmaTokenizerFast, BitsAndBytesConfig fromtransformers.data.data_collatorimportpad_without_fast_tokenizer_warning frompeftimportget_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType lora_dir ='../model_save/gemma2fold0_16bit_load_fintune/best_val_loss_model/adapter.bin' d1 = torch.load(lora_dir) lora_dir ='../model_save/gemma2fold1_16bit_load_fintune/best_val_loss_model/adapter.bin' d2 = torch.load(lora_dir) lora_dir ='../model_save/gemma2fold2_16bit_load_fintune/best_val_loss_model/adapter.bin' d3 = torch.load(lora_dir) lora_dir ='../model_save/gemma2fold3_16bit_load_fintune/best_val_loss_model/adapter.bin' d4 = torch.load(lora_dir) lora_dir ='../model_save/gemma2fold4_16bit_load_fintune/best_val_loss_model/adapter.bin' d5 = torch.load(lora_dir) d = {} fork, vind1.items(): v = d1[k] + d2[k] + d3[k] + d4[k] + d5[k] v = v /5. d[k] = v torch.save(d,"../model_save/final_adapter.bin")代码上可见,就是对经过5次交叉验证的gamma模型权重做了加权平均合并,但我看discussion很多人提到了,它们同样想到了该方案,不过效果并不好,似乎是这些权重还需要做方差评估,如果方差过大反而会拖累加权后的结果,感兴趣有卡有算力的能进行尝试,我就不过多提了。
回到正题,最终是先得到了llama3和Qwen的模型输出,那么蒸馏即是需要考虑这两者的结果,所以蒸馏损失选择了:
loss_fun = nn.CrossEntropyLoss() divergence_loss_fn = nn.KLDivLoss(reduction='batchmean') cos_loss_fn = nn.CosineEmbeddingLoss() outputs = model(batch['input_ids'], use_cache=False)# predict gemma2 logits = outputs.logits grads = batch['grads'] grads1 = batch['grads'][:, :3]# qwen2 grads2 = batch['grads'][:,3:]# llama3 labels = batch['labels'] loss_ce = loss_fun(logits, labels) loss_grad1 = divergence_loss_fn( F.log_softmax(logits / T, dim=1), F.softmax(grads1 / T, dim=1) ) cos_loss1 = cos_loss_fn(F.softmax(grads1 / T, dim=1), F.softmax(logits / T, dim=1), torch.ones(logits.size()[0]).to(logits.device)) loss_grad2 = divergence_loss_fn( F.log_softmax(logits / T, dim=1), F.softmax(grads2 / T, dim=1) ) cos_loss2 = cos_loss_fn(F.softmax(grads2 / T, dim=1), F.softmax(logits / T, dim=1), torch.ones(logits.size()[0]).to(logits.device)) loss = (loss_ce + loss_grad1 + cos_loss1 + loss_grad2 + cos_loss2) /5.用数学公式理解,即为交叉熵和KL散度的混合:
这里刚开始我不是很理解,然后问了下deepseek懂了:
为什么同时使用交叉熵损失和 KL 散度损失?
1. 保持监督学习能力
交叉熵损失确保学生模型能够正确预测真实标签,从而保持模型的监督学习能力。如果没有交叉熵损失,学生模型可能会过度依赖教师模型的输出,而忽视真实标签的指导,导致模型在真实数据上的性能下降。
2. 学习教师模型的软目标
KL 散度损失让学生模型学习教师模型的软目标,从而捕捉到教师模型的内部表示和知识。软目标通常包含更多的信息,可以帮助学生模型更好地理解数据的分布和特征。
3. 平衡硬标签和软目标
同时使用交叉熵损失和 KL 散度损失可以平衡硬标签和软目标的贡献。硬标签(真实标签)提供了直接的监督信号,而软目标(教师模型的输出)提供了更多的上下文信息。通过调整两者的权重,可以更好地指导学生模型的学习。
其实我认为以上主要的,是因为教师模型是两个,而不是一个,KL更适合于一个,而两个加入交叉熵我的理解为桥接,更能体现泛化,但具体为啥这样安排,只有跑了才知道,所以根据github的环境说明,有8张A100以上的,可以跑一轮,等待3天以上,观看结果了。
open-r1中的蒸馏 该repo是DeepSeek-R1的开放复现版本,由huggingface的CEO亲自提出并进行,我大致看了一下,它的规划是:
步骤 1:从 DeepSeek-R1 中提取高质量语料库来复制 R1-Distill 模型。 步骤 2:复制 DeepSeek 用于创建 R1-Zero 的纯 RL 管道。这可能涉及为数学、推理和代码整理新的大规模数据集。 步骤 3:展示我们可以通过多阶段训练从基础模型转向 RL 调整。 这里重点看step 1,即它使用distilabel来对Deepseek-R1提取蒸馏数据,以下是一个简单demo:
fromdatasetsimportload_dataset fromdistilabel.modelsimportvLLM fromdistilabel.pipelineimportPipeline fromdistilabel.steps.tasksimportTextGeneration prompt_template ="""\ You will be given a problem. Please reason step by step, and put your final answer within \boxed{}: {{ instruction }}""" dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train").select(range(10)) model_id ="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"# Exchange with another smol distilled r1 withPipeline( name="distill-qwen-7b-r1", description="A pipeline to generate data from a distilled r1 model", )aspipeline: llm = vLLM( model=model_id, tokenizer=model_id, extra_kwargs={ "tensor_parallel_size":1, "max_model_len":8192, }, generation_kwargs={ "temperature":0.6, "max_new_tokens":8192, }, ) prompt_column ="problem" text_generation = TextGeneration( llm=llm, template=prompt_template, num_generations=4, input_mappings={"instruction": prompt_column}ifprompt_columnisnotNoneelse{} ) if__name__ =="__main__": distiset = pipeline.run(dataset=dataset) distiset.push_to_hub(repo_id="username/numina-deepseek-r1-qwen-7b")然后将该数据加入了sft中:
defmain(script_args, training_args, model_args): ################ # Model init kwargs & Tokenizer ################ quantization_config = get_quantization_config(model_args) model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, attn_implementation=model_args.attn_implementation, torch_dtype=model_args.torch_dtype, use_cache=Falseiftraining_args.gradient_checkpointingelseTrue, device_map=get_kbit_device_map()ifquantization_configisnotNoneelseNone, quantization_config=quantization_config, ) training_args.model_init_kwargs = model_kwargs tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True ) tokenizer.pad_token = tokenizer.eos_token ################ # Dataset ################ dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) ################ # Training ################ trainer = SFTTrainer( model=model_args.model_name_or_path, args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split]iftraining_args.eval_strategy !="no"elseNone, processing_class=tokenizer, peft_config=get_peft_config(model_args), ) trainer.train() # Save and push to hub trainer.save_model(training_args.output_dir) iftraining_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) if__name__ =="__main__": parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() main(script_args, training_args, model_args)从代码上可以看到,这个过程是从教师模型中提取知识,并将其传递给学生模型。在这个特定的情况下,知识不是以软标签的形式直接传递,而是通过生成的推理数据来传递。这种方法通常被称为数据蒸馏(Data Distillation)或示例蒸馏(Example Distillation),它是知识蒸馏的一种变体。
最后 我看到了腾讯科技发布的一场关于DeepSeek的高质量闭门会:比技术更重要的是愿景,里面的很多内容可以作为结尾:
长期来说,通过走捷径的方式,而没有自己通过愿景去想怎么做技术方案,而是直接复现,中间可能会有不知道的坑。比如在这一代技术 long context 没有质变的前提下,解决问题的上限可能会被限制。R1-zero 可能是一个正确的方向,从头就做 R1-zero 或不通过类 o1 的数据启动可能更好。照着别人的技术方案可能不太好,希望更多探索。 蒸馏的坏处是模型 diversity 下降,影响模型上限,无法超越最强的模型。但短期看,蒸馏也是一条路线。其他模型用蒸馏也能得到较好的结果,未来在模型生态里面可能就会有老师、学生的角色区分,有能力当一名好学生也是一种可以的商业模式。