返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

大模型知识蒸馏指南

[复制链接]
链载Ai 显示全部楼层 发表于 昨天 17:08 |阅读模式 打印 上一主题 下一主题

最近wsdm cup到了瓶颈,租卡跑算力成本太高,而lmsys比赛的微调结果也没啥可抄的了,所以只能回头看看top方案,研究了一下阳哥的《Distill is all you need》,和第二名tascj对于训练推理的科技与狠活,有些感觉,伴随着deepseek的大火,蒸馏和强化学习又被端上了台面,我对强化学习暂时没什么兴趣,不过蒸馏跟我最近看的内容相关,在网上搜了一圈关于deepseek针对蒸馏的策略,好像没有过多内容介绍,于是想着总结找到的一些资料。

什么是模型蒸馏?

模型蒸馏即知识蒸馏(Knowledge Distillation),是一种模型压缩和加速技术。在深度学习中,大型深度神经网络虽性能优异,但因计算复杂度高、存储需求大,难以部署在资源受限设备上。模型蒸馏通过构建师生架构,让小的学生模型学习大的教师模型的知识,使学生模型在保持较小规模的同时,尽可能接近教师模型的性能。其核心组件包括知识(如教师模型的 logits、中间层特征等)、蒸馏算法(用于指导知识转移)和师生架构(决定知识传递方式)。

这里可以看比较主流的一张图,出自2021年综述:《Knowledge Distillation: A Survey》,对近年的Distillation做了一个详细概括,Knowledge Distillation的流程可以理解为:

图中除了loss之后会详细说明,唯一的未知点可能在于soft targets,它是经过softmax的下一层级结果logits(原始分数),公式为:

其中是温度系数,从公式中能很明显看出当值较大时,Softmax 输出的概率分布会更加平滑,每个类别的概率值相对更接近;值较小时,概率分布会更尖锐,高概率类别的概率值远高于其他类别。这些 soft targets 会传递给学生模型,学生模型在学习过程中不仅学习真实的hard targets信息,还能从教师模型的 soft targets 中获取类别之间的关联等知识,帮助其更好地训练和泛化。

hard targets 与 soft targets的区别可以从下面的四分类图中很形象的看出:

知识蒸馏有什么意义

  • 实现模型压缩与加速:模型蒸馏能有效压缩模型大小、降低计算复杂度,提升推理速度。如在论文研究中,通过知识蒸馏将大模型知识转移到小模型,在 CIFAR10 和 CIFAR100 数据集上进行实验,结果表明可实现不同深度模型的压缩,使轻量级学生模型在保持较高准确率的同时,显著减少模型参数和计算量,满足在资源受限设备上的部署需求 。
  • 提升模型性能:帮助学生模型学习到教师模型的有用知识,提高自身性能。在视觉识别、自然语言处理、语音识别等多个领域的研究中发现,知识蒸馏可提升模型在复杂任务中的表现。例如在自然语言处理中,对BERT 模型进行知识蒸馏得到的轻量级模型,在保持较高准确率的同时,推理速度大幅提升,能够高效完成多种语言任务 。
  • 解决数据相关问题:在数据稀缺、存在隐私问题或数据难以获取时,模型蒸馏有独特优势。数据无关蒸馏方法可利用教师模型生成合成数据训练学生模型,避免对大量真实数据的依赖。在涉及敏感数据的场景中,多教师蒸馏可让多个教师模型分别处理不同子集数据,监督学生模型训练,既能保护数据隐私,又能完成模型训练。
  • 促进跨领域与跨模态学习:跨模态蒸馏可实现不同模态间的知识转移,帮助模型更好地处理多模态数据。在一些研究中,将 RGB 图像模态的知识转移到深度图像模态,使模型在不同模态下都能取得较好的性能,拓宽了模型的应用范围。
  • 助力终身学习与持续优化:与终身学习结合,模型蒸馏可帮助模型在新任务学习中保留旧知识,避免灾难性遗忘。在不断出现新数据和新任务的场景下,通过知识蒸馏将已有知识传递给新模型,使模型能够持续学习和优化,提升其适应性和泛化能力。

如何做知识蒸馏

做知识蒸馏的方式有非常多,从训练方案流程来看,就有离线蒸馏、在线蒸馏和自蒸馏等,从算法更新角度上,还有对抗蒸馏、多教师蒸馏等,这里我就不用豆包在灌水了,想查一大片说明,直接以bert时代的蒸馏开始看。

unsetunsettinybertunsetunset

TinyBERT是一种轻量级的预训练语言模型,由华为和华中科技大学提出。它通过知识蒸馏技术,将BERT模型的知识迁移到一个更小的模型中,从而实现了模型体积的大幅减小和推理速度的提升。在当时,它提出了两阶段transformer蒸馏方案:在大规模语料上首先进行通用MLM任务的蒸馏,在下游任务时,先学好老师模型,再进行蒸馏,具体如下图:

关于Transformer层蒸馏,主要包括注意力attn的蒸馏和隐藏层hidn的蒸馏:

关于损失函数,TinyBert的蒸馏loss为:

  1. 第一项:词向量层损失

  • 计算学生词向量和老师词向量的均方误差:
  • 因为和的维度末必一致,这里需要参数做映射
  • 第二项:中间层损失

    • 学生第 i 层多头注意力矩阵和老师第 j 层多头注意力矩阵计算MSE, K 为注意力的head数
    • 学生的第 i 层隐层输出和 老师的第 j 层隐层输出计算MSE,用做映射
    • 若学生4层,老师12层,则老师的 (3,6,9,12) 层分别蒸馏到学生的 (1,2,3,4) 层。
    • 中间层的损失由隐层均方误差损失和注意力损失组成:
    • 隐层均方误差损失:
    • 注意力损失:
  • 第三项:预测层损失

    • 学生学习老师的soft label
    • 并计算交叉熵:

    如果有不清晰的,可以去看论文原文,我就不做过多解释了,上述的内容根据论文开源的github地址,其中对于蒸馏训练的截取部分,可进行一一对照:

    # 蒸馏配置
    distill_config = DistillationConfig(
    # 设置温度系数temperature, tiny-bert论文作者使用1表现最好,一般大于1比较好
    temperature=self.temperature,
    # 设置ground truth loss权重
    hard_label_weight=self.hard_label_weight,
    # 设置预测层蒸馏loss(即soft label损失)为交叉熵,并稍微放大其权重
    kd_loss_type=self.kd_loss_type, kd_loss_weight=self.kd_loss_weight,
    # 配置中间层蒸馏映射
    intermediate_matches=[
    # 配置hidden蒸馏映射、维度映射
    {'layer_T':0,'layer_S':0,'feature':'hidden','loss':'hidden_mse','weight':1,
    'proj': ['linear',312,768]}, # embedding层输出
    {'layer_T':3,'layer_S':1,'feature':'hidden','loss':'hidden_mse','weight':1,
    'proj': ['linear',312,768]},
    {'layer_T':6,'layer_S':2,'feature':'hidden','loss':'hidden_mse','weight':1,
    'proj': ['linear',312,768]},
    {'layer_T':9,'layer_S':3,'feature':'hidden','loss':'hidden_mse','weight':1,
    'proj': ['linear',312,768]},
    {'layer_T':12,'layer_S':4,'feature':'hidden','loss':'hidden_mse','weight':1,
    'proj': ['linear',312,768]},
    # 配置attention矩阵蒸馏映射,注意layer序号从0开始
    {"layer_T":2,"layer_S":0,"feature":"attention","loss":"attention_mse","weight":1},
    {"layer_T":5,"layer_S":1,"feature":"attention","loss":"attention_mse","weight":1},
    {"layer_T":8,"layer_S":2,"feature":"attention","loss":"attention_mse","weight":1},
    {"layer_T":11,"layer_S":3,"feature":"attention","loss":"attention_mse","weight":1},
    ]
    )

    # 训练配置
    optimizer = AdamW(self.student_model.parameters(), lr=self.lr) # 使用大一点的lr
    train_config = TrainingConfig(
    output_dir=self.student_model_dir, device=self.student_trainer.device,
    data_parallel=self.enable_parallel, ckpt_frequency=self.ckpt_frequency # 一个epoch存ckpt_frequency次模型
    )

    # 配置model中logits hiddens attentions losses的获取方法
    defsimple_adaptor(batch, model_outputs):
    return{
    'logits': model_outputs[-1]['logits'],'hidden': model_outputs[-1]['hiddens'],
    'attention': model_outputs[-1]['attentions'],'losses': model_outputs[1],
    }

    # 蒸馏
    distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=self.teacher_model, model_S=self.student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor
    )
    withdistiller:
    logger.info('start to knowledge distill ...')
    distiller.train(optimizer, train_dataloader, num_epochs=epoch)
    logger.info('distill finish')

    unsetunsetKL散度(**Kullback-Leibler divergence**)unsetunset

    KL散度的定义是建立在熵(Entropy)的基础上的。此处以离散随机变量为例,若一个离散随机变量的可能取值为,而对应的概率为,则随机变量的熵定义为:

    若有两个随机变量,且其概率分布分别为,则相对的相对摘为:

    之所以称之为相对熵,是因为其可以通过两随机变量的交叉嫡(Cross-Entropy)以及信息摘推导得到,针对上述离散变量的概率分布而言,其交叉摘定义为:

    因此,KL散度或相对熵可通过下式得出:

    在上一节中,TinyBERT在设计其蒸馏过程时采用了多种损失函数,包括词向量层损失、中间层损失和预测层损失,在大模型时代下,词向量损失不用多说,因为已经完全做了解耦,如何进行embedding我想看到这里的都知道,中间层损失的不再使用,或者说中间层蒸馏的使用变少,我理解是大模型通常已经具有足够的参数来学习复杂的特征表示,因此它的必要性相对较低,另外就是中间层叠得太厚,所能获得的收益太低,所以不如针对预测层进行相应的改进,那自然,就不得不提本节在介绍的KL散度。

    那为什么作为大模型来讲,更多使用KL散度呢?我觉得可以从以下三点考虑:

    1. 知识蒸馏的需求:大模型在进行知识蒸馏时,需要将教师模型的知识传递给学生模型。KL散度能够衡量两个概率分布之间的差异,适合用于衡量教师模型和学生模型之间的输出分布差异。通过最小化KL散度,可以使得学生模型的输出分布尽可能接近教师模型的输出分布。
    2. 考虑分布的整体差异:KL散度不仅考虑了预测分布与真实分布之间的交叉熵,还考虑了真实分布的熵。这使得KL散度能够更全面地衡量两个分布之间的差异,适合用于大模型这种需要精细调整输出分布的场景。
    3. 优化目标的一致性:在知识蒸馏中,优化KL散度等价于优化交叉熵。但是,KL散度在某些情况下能够提供更稳定的优化目标,尤其是在教师模型和学生模型的输出分布差异较大时。

    unsetunset前向KL(forward)和后向KL(reverse)unsetunset

    上述介绍了KL散度的定义,很明显,KL损失不是一个对称形式,即,那么我们可以试图用近似分布来优化该目标:

    1. Minimizing the forward KL:
    2. Minimizing the reverse KL:

    根据上一小节的概率公式推导,可以计算出反向 (Reverse KL,RKL)为:

    正向 (Forward KL,FKL)为:

    其中P是teacher,Q是student,在大模型之前,似乎很多人更喜欢用FKL,正向KL散度(FKL)更受青睐的原因可能与其在传统任务上的表现有关。传统分类任务的输出空间相对较小,模式(即分布的峰值)较少,这意味着分布更倾向于单一峰值而非多峰值分布。在这种情况下,FKL表现良好,因为它倾向于让学生模型关注教师模型输出中概率较高的区域,从而产生更准确的样本。然而,对于大型语言模型(LLM)来说,输出空间更加复杂,模式更多,再使用FKL可能导致学生模型关注教师模型输出中概率较低的区域,从而产生不良样本。

    如上图所示,教师模型是蓝色曲线,它的输出是可量化的,这里假设为两个高斯波峰,而黄色,是理想情况下,我们认为学生模型可以近似为正态分布来拟合教师曲线,那么会出现两种结果,一种是尽可能多的包括多峰的面积,第二种是直接拟合最高波峰的分布。所以左边是Forward KL,右边是反向。

    中间的一些具体推导过程不过多赘述,近年有非常多的论文对该方案做了benchmark,比如说下图是《f-Divergence Minimization for Sequence-Level Knowledge Distillation》一文的数据:

    还有《Rethinking Kullback-Leibler Divergence in Knowledge Distillation for Large Language Models》篇的数据和AKL:

    另外说明一下,本节内容就是看了作者在知乎发的《LLM的知识蒸馏(KD)应该用Reverse KL?》一文才有想法撰写本节,对于想复现的小伙伴来讲,可以去看这几篇论文的github,作者还给了一些相应的可视化demo。

    unsetunsettrl中的知识蒸馏unsetunset

    TRL(Transformer Reinforcement Learning)库是用于后续训练基础模型的综合库,专为使用监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术进行训练后的基础模型而设计。这里我们只看它里面的两种trainer——SFTtrainer和GKDtrainer。

    从原理方面来讲:

    • SFTTrainer:SFTTrainer 即监督微调训练器,主要是对预训练语言模型进行有监督的微调。它利用给定的输入输出对数据,通过最小化模型输出与真实标签之间的损失,让模型学习到特定任务的模式,将预训练模型适配到具体的下游任务。
    • GKDTrainer:GKDTrainer 是用于知识蒸馏的一种训练器,基于知识蒸馏原理,利用教师模型的知识来指导学生模型的训练,使学生模型学习到教师模型的知识,比如输出分布、特征表示等,以提高学生模型的性能。

    从损失计算方面来讲:

    • SFTTrainer:通常计算模型输出与真实标签之间的交叉熵损失等,衡量模型预测结果与实际标注的差异,通过反向传播来更新模型参数,使模型输出尽可能接近真实标签。
    • GKDTrainer:主要计算学生模型与教师模型输出之间的散度,如 Jensen - Shannon Divergence(JSD)、Kullback - Leibler Divergence(KLD)等,让学生模型学习教师模型的输出分布等知识。

    这两种顺序非常直观,GKDTrainer继承自SFTTrainer,SFTTrainer继承自Trainer。那从SFTtrainer看,它的调用非常简单,trl的readme直接写了一个demo:

    fromtrlimportSFTConfig, SFTTrainer
    fromdatasetsimportload_dataset

    dataset = load_dataset("trl-lib/Capybara", split="train")

    training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
    trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    )
    trainer.train()

    调用该类后,我又去看了下transformers的trainer,它的损失函数为:

     defcompute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """
    How the loss is computed by Trainer. By default, all models return the loss in the first element.

    Subclass and override for custom behavior.
    """
    if(self.label_smootherisnotNoneorself.compute_loss_funcisnotNone)and"labels"ininputs:
    labels = inputs.pop("labels")
    else:
    labels =None
    ifself.model_accepts_loss_kwargs:
    loss_kwargs = {}
    ifnum_items_in_batchisnotNone:
    loss_kwargs["num_items_in_batch"] = num_items_in_batch
    inputs = {**inputs, **loss_kwargs}
    outputs = model(**inputs)
    # Save past state if it exists
    #TODO:this needs to be fixed and made cleaner later.
    ifself.args.past_index >=0:
    self._past = outputs[self.args.past_index]

    iflabelsisnotNone:
    unwrapped_model = self.accelerator.unwrap_model(model)
    if_is_peft_model(unwrapped_model):
    model_name = unwrapped_model.base_model.model._get_name()
    else:
    model_name = unwrapped_model._get_name()
    # User-defined compute_loss function
    ifself.compute_loss_funcisnotNone:
    loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
    elifmodel_nameinMODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
    loss = self.label_smoother(outputs, labels, shift_labels=True)
    else:
    loss = self.label_smoother(outputs, labels)
    else:
    ifisinstance(outputs, dict)and"loss"notinoutputs:
    raiseValueError(
    "The model did not return a loss from the inputs, only the following keys: "
    f"{','.join(outputs.keys())}. For reference, the inputs it received are{','.join(inputs.keys())}."
    )
    # We don't use .loss here since the model may return tuples instead of ModelOutput.
    loss = outputs["loss"]ifisinstance(outputs, dict)elseoutputs[0]

    ifself.args.average_tokens_across_devicesandself.model_accepts_loss_kwargs:
    loss *= self.accelerator.num_processes

    return(loss, outputs)ifreturn_outputselseloss

    很显然这部分有非常多的自适应判断,根据我们上一层为SFTtrainer类,并且没有指定loss方法,所以将选用cross-entropy loss作为模型训练参数。

    而GKDtrainer类的方式就不一样,由于KL散度是不对称的,在知识蒸馏中使用JSD,Jensen-Shannon Divergence 是基于KL散度改进的更平滑和对称的概率分布度量。论文中给出了其改进的计算公式:

    那自然其重写了compute_loss,具体计算为generalized_jsd_loss,代码如下:

     defgeneralized_jsd_loss(
    student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
    ):
    """
    Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
    of https://huggingface.co/papers/2306.13649 for the definition.

    Args:
    student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
    teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
    labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
    beta: Interpolation coefficient between 0 and 1 (default: 0.5)
    temperature: Softmax temperature (default: 1.0)
    reduction: Specifies the reduction to apply to the output (default: 'batchmean')

    Returns:
    loss: Scalar tensor with the generalized JSD loss
    """

    # Apply temperature scaling
    student_logits = student_logits / temperature
    teacher_logits = teacher_logits / temperature

    # Compute log probabilities for student and probabilities for teacher
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    # Compute the log of the mixture distribution
    # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
    beta = torch.tensor(beta, dtype=student_log_probs.dtype)
    mixture_log_probs = torch.logsumexp(
    torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1- beta)]),
    dim=0,
    )

    # Compute KL divergences using F.kl_div
    # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
    kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
    kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

    # Compute the Generalized Jensen-Shannon Divergence
    jsd = beta * kl_teacher + (1- beta) * kl_student

    # Masking
    iflabelsisnotNone:
    mask = labels !=-100
    jsd = jsd[mask]

    # Apply reduction
    ifreduction =="batchmean":
    returnjsd.sum() / mask.sum()iflabelsisnotNoneelsejsd.sum() / (jsd.size(0) * jsd.size(1))
    elifreduction =="sum":
    returnjsd.sum()
    elifreduction =="mean":
    returnjsd.mean()
    else:
    returnjsd

    defcompute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    # compute student output
    outputs_student = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    )

    # compute teacher output in eval mode
    self.teacher_model.eval()
    withtorch.no_grad():
    outputs_teacher = self.teacher_model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    )

    # slice the logits for the generated tokens using the inputs["prompts"] lengths
    prompt_lengths = inputs["prompts"].shape[1]
    shifted_student_logits = outputs_student.logits[:, prompt_lengths -1:-1, :]
    shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths -1:-1, :]
    shifted_labels = inputs["labels"][:, prompt_lengths:]

    # compute loss
    loss = self.generalized_jsd_loss(
    student_logits=shifted_student_logits,
    teacher_logits=shifted_teacher_logits,
    labels=shifted_labels,
    beta=self.beta,
    )

    # empty cache
    empty_cache()

    # Return loss
    return(loss, outputs_student)ifreturn_outputselseloss

    对于该类好不好用,我也不知道,暂时没用过,只能说从理论来分析,JSD损失和KL损失的区别,不过与SFTtrainer类似,调用方式也很简单,可以跑几次看看情况:

    fromdatasetsimportload_dataset
    importrandom
    fromtransformersimportAutoTokenizer
    fromtrlimport(
    GKDConfig,
    GKDTrainer,
    LogCompletionsCallback,
    ModelConfig,
    ScriptArguments,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    )

    ################
    # Training
    ################
    trainer = GKDTrainer(
    model=model_config.model_name_or_path,
    teacher_model=training_args.teacher_model_name_or_path,
    args=training_args,
    train_dataset=dataset[args.dataset_train_split],
    eval_dataset=test_data,
    processing_class=tokenizer,
    peft_config=get_peft_config(model_config),
    )
    completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
    trainer.add_callback(completions_callback)
    trainer.train()

    # Save
    trainer.save_model(training_args.output_dir)

    lmsys方案思考

    本节是对阳哥夺冠方案中关于蒸馏部分的经典总结,在这里做一个旁征博引,因为没有算力,具体我也没复现过,不过算是除了写这篇推文的初衷,本来是想做一个top方案亮点汇总,只是因为deepseek的爆火针对其中一个方向做了延展。那话不多说,github原址为:https://github.com/shyoulala/LMSYS_BlackPearl

    该仓库的目录结构为:

    ./model_path # 预训练模型的路径,存放预训练模型的权重和配置文件
    ./src_fast # 快速训练脚本的存放位置,可能包含简化的训练代码
    ./src # 完整解决方案的代码目录,包含整个项目的完整训练和处理流程
    ./data # 数据目录,存放训练数据和其他相关数据
    ./data/oof # Out-of-Fold 数据目录,可能用于交叉验证的中间结果
    ./data/processed_data # 处理后的数据目录,存放经过预处理的数据
    ./data/processed_data/orgemma2fold4 # 训练集,包含用于直接蒸馏的 70b 概率数据(第4折)
    ./data/processed_data/orgemma2fold2 # 同上,第2折
    ./data/processed_data/orgemma2fold0 # 同上,第0折
    ./data/processed_data/orgemma2fold1 # 同上,第1折
    ./data/processed_data/orgemma2fold3 # 同上,第3折
    ./data/lmsys-chatbot-arena # 可能存放与 LMSYS Chatbot Arena 相关的数据或资源
    ./sub # 输出目录,用于存放训练结果、预测结果等
    ./model_save # 训练模型的保存路径,存放训练完成后的模型文件
    ./model_save_or # 另一个模型保存路径,可能是用于存放原始模型或特定版本的模型
    ./model_save_or/v7_ut_gemma_v7_64r128_ddgemma2_16bit # 经过后处理(如蒸馏)的模型版本,可能是 Gemma2-9B 的 16bit 版本

    挺难想象的,大模型时代竟然还能做交叉验证,不过lmsys是个三分类任务,依照之前逻辑也没什么问题,该方案主要是用llama3-70B和Qwen2-72B-instruct对gamma2-9B做蒸馏,所有大致流程,都通过run_pipeline.sh有显现:

    #!/bin/bash
    set -e

    qwen_path=../model_path/qwen2_72b
    llama_path=../model_path/llama3_70b
    gemma_path=../model_path/Gemma2_9b

    qwen_path_ut=../model_save/qwen2_4bit_pretrain/epoch_0_model/adapter.bin
    llama_path_ut=../model_save/llama3_4bit_pretrain/epoch_0_model/adapter.bin
    gemma_path_ut=../model_save/gemma2_4bit_pretrain/epoch_0_model/adapter.bin


    fold=$1
    echo run{fold}
    # train llama3 70b
    sh run_fintune.sh llama3 ${llama_path} ${llama_path_ut} ${fold}
    # predict train logits
    python predict_train.py ${llama_path} ../model_save/llama3_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/llama3fold${fold}/train.parquet ../data/oof/llama3fold${fold}_train.parquet

    # train qwen2 70b
    sh run_fintune.sh qwen2 ${qwen_path} ${qwen_path_ut} ${fold}
    # predict train logits
    python predict_train.py ${qwen_path} ../model_save/qwen2_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/qwen2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet

    # merge logits
    python merge_logits.py ../data/processed_data/gemma2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet ../data/oof/llama3fold${fold}_train.parquet ../data/processed_data/gemma2fold${fold}/train_logits.parquet

    # distill fintune gemma2-9b
    sh run_fintune_16bit_distill.sh gemma2 ${gemma_path} ${gemma_path_ut} ${fold}

    中间几步有挺多有趣的操作,比如是如何做post train的,以及最后merge logits,这里仅谈蒸馏之前的merge lora,因为代码足够简单:

    importtime
    fromdataclassesimportdataclass
    importpickle
    importtorch
    importsklearn
    importnumpyasnp
    importpandasaspd
    fromtqdm.autoimporttqdm
    fromtransformersimportGemma2ForSequenceClassification, GemmaTokenizerFast, BitsAndBytesConfig
    fromtransformers.data.data_collatorimportpad_without_fast_tokenizer_warning
    frompeftimportget_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

    lora_dir ='../model_save/gemma2fold0_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d1 = torch.load(lora_dir)
    lora_dir ='../model_save/gemma2fold1_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d2 = torch.load(lora_dir)
    lora_dir ='../model_save/gemma2fold2_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d3 = torch.load(lora_dir)
    lora_dir ='../model_save/gemma2fold3_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d4 = torch.load(lora_dir)
    lora_dir ='../model_save/gemma2fold4_16bit_load_fintune/best_val_loss_model/adapter.bin'
    d5 = torch.load(lora_dir)

    d = {}
    fork, vind1.items():
    v = d1[k] + d2[k] + d3[k] + d4[k] + d5[k]
    v = v /5.
    d[k] = v
    torch.save(d,"../model_save/final_adapter.bin")

    代码上可见,就是对经过5次交叉验证的gamma模型权重做了加权平均合并,但我看discussion很多人提到了,它们同样想到了该方案,不过效果并不好,似乎是这些权重还需要做方差评估,如果方差过大反而会拖累加权后的结果,感兴趣有卡有算力的能进行尝试,我就不过多提了。

    回到正题,最终是先得到了llama3和Qwen的模型输出,那么蒸馏即是需要考虑这两者的结果,所以蒸馏损失选择了:

    loss_fun = nn.CrossEntropyLoss()
    divergence_loss_fn = nn.KLDivLoss(reduction='batchmean')
    cos_loss_fn = nn.CosineEmbeddingLoss()
    outputs = model(batch['input_ids'], use_cache=False)# predict gemma2
    logits = outputs.logits
    grads = batch['grads']
    grads1 = batch['grads'][:, :3]# qwen2
    grads2 = batch['grads'][:,3:]# llama3
    labels = batch['labels']
    loss_ce = loss_fun(logits, labels)
    loss_grad1 = divergence_loss_fn(
    F.log_softmax(logits / T, dim=1),
    F.softmax(grads1 / T, dim=1)
    )
    cos_loss1 = cos_loss_fn(F.softmax(grads1 / T, dim=1), F.softmax(logits / T, dim=1),
    torch.ones(logits.size()[0]).to(logits.device))

    loss_grad2 = divergence_loss_fn(
    F.log_softmax(logits / T, dim=1),
    F.softmax(grads2 / T, dim=1)
    )
    cos_loss2 = cos_loss_fn(F.softmax(grads2 / T, dim=1), F.softmax(logits / T, dim=1),
    torch.ones(logits.size()[0]).to(logits.device))

    loss = (loss_ce + loss_grad1 + cos_loss1 + loss_grad2 + cos_loss2) /5.

    用数学公式理解,即为交叉熵和KL散度的混合:

    这里刚开始我不是很理解,然后问了下deepseek懂了:

    为什么同时使用交叉熵损失和 KL 散度损失?

    1. 保持监督学习能力

    交叉熵损失确保学生模型能够正确预测真实标签,从而保持模型的监督学习能力。如果没有交叉熵损失,学生模型可能会过度依赖教师模型的输出,而忽视真实标签的指导,导致模型在真实数据上的性能下降。

    2. 学习教师模型的软目标

    KL 散度损失让学生模型学习教师模型的软目标,从而捕捉到教师模型的内部表示和知识。软目标通常包含更多的信息,可以帮助学生模型更好地理解数据的分布和特征。

    3. 平衡硬标签和软目标

    同时使用交叉熵损失和 KL 散度损失可以平衡硬标签和软目标的贡献。硬标签(真实标签)提供了直接的监督信号,而软目标(教师模型的输出)提供了更多的上下文信息。通过调整两者的权重,可以更好地指导学生模型的学习。

    其实我认为以上主要的,是因为教师模型是两个,而不是一个,KL更适合于一个,而两个加入交叉熵我的理解为桥接,更能体现泛化,但具体为啥这样安排,只有跑了才知道,所以根据github的环境说明,有8张A100以上的,可以跑一轮,等待3天以上,观看结果了。

    open-r1中的蒸馏

    该repo是DeepSeek-R1的开放复现版本,由huggingface的CEO亲自提出并进行,我大致看了一下,它的规划是:

    • 步骤 1:从 DeepSeek-R1 中提取高质量语料库来复制 R1-Distill 模型。
    • 步骤 2:复制 DeepSeek 用于创建 R1-Zero 的纯 RL 管道。这可能涉及为数学、推理和代码整理新的大规模数据集。
    • 步骤 3:展示我们可以通过多阶段训练从基础模型转向 RL 调整。

    这里重点看step 1,即它使用distilabel来对Deepseek-R1提取蒸馏数据,以下是一个简单demo:

    fromdatasetsimportload_dataset
    fromdistilabel.modelsimportvLLM
    fromdistilabel.pipelineimportPipeline
    fromdistilabel.steps.tasksimportTextGeneration


    prompt_template ="""\
    You will be given a problem. Please reason step by step, and put your final answer within \boxed{}:
    {{ instruction }}"""

    dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train").select(range(10))

    model_id ="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"# Exchange with another smol distilled r1

    withPipeline(
    name="distill-qwen-7b-r1",
    description="A pipeline to generate data from a distilled r1 model",
    )aspipeline:

    llm = vLLM(
    model=model_id,
    tokenizer=model_id,
    extra_kwargs={
    "tensor_parallel_size":1,
    "max_model_len":8192,
    },
    generation_kwargs={
    "temperature":0.6,
    "max_new_tokens":8192,
    },
    )
    prompt_column ="problem"
    text_generation = TextGeneration(
    llm=llm,
    template=prompt_template,
    num_generations=4,
    input_mappings={"instruction": prompt_column}ifprompt_columnisnotNoneelse{}
    )


    if__name__ =="__main__":
    distiset = pipeline.run(dataset=dataset)
    distiset.push_to_hub(repo_id="username/numina-deepseek-r1-qwen-7b")

    然后将该数据加入了sft中:

    defmain(script_args, training_args, model_args):
    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,
    torch_dtype=model_args.torch_dtype,
    use_cache=Falseiftraining_args.gradient_checkpointingelseTrue,
    device_map=get_kbit_device_map()ifquantization_configisnotNoneelseNone,
    quantization_config=quantization_config,
    )
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    ################
    # Training
    ################
    trainer = SFTTrainer(
    model=model_args.model_name_or_path,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    eval_dataset=dataset[script_args.dataset_test_split]iftraining_args.eval_strategy !="no"elseNone,
    processing_class=tokenizer,
    peft_config=get_peft_config(model_args),
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    iftraining_args.push_to_hub:
    trainer.push_to_hub(dataset_name=script_args.dataset_name)


    if__name__ =="__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)

    从代码上可以看到,这个过程是从教师模型中提取知识,并将其传递给学生模型。在这个特定的情况下,知识不是以软标签的形式直接传递,而是通过生成的推理数据来传递。这种方法通常被称为数据蒸馏(Data Distillation)或示例蒸馏(Example Distillation),它是知识蒸馏的一种变体。

    最后

    我看到了腾讯科技发布的一场关于DeepSeek的高质量闭门会:比技术更重要的是愿景,里面的很多内容可以作为结尾:

    1. 长期来说,通过走捷径的方式,而没有自己通过愿景去想怎么做技术方案,而是直接复现,中间可能会有不知道的坑。比如在这一代技术 long context 没有质变的前提下,解决问题的上限可能会被限制。R1-zero 可能是一个正确的方向,从头就做 R1-zero 或不通过类 o1 的数据启动可能更好。照着别人的技术方案可能不太好,希望更多探索。
    2. 蒸馏的坏处是模型 diversity 下降,影响模型上限,无法超越最强的模型。但短期看,蒸馏也是一条路线。其他模型用蒸馏也能得到较好的结果,未来在模型生态里面可能就会有老师、学生的角色区分,有能力当一名好学生也是一种可以的商业模式。

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ