最近wsdm cup到了瓶颈,租卡跑算力成本太高,而lmsys比赛的微调结果也没啥可抄的了,所以只能回头看看top方案,研究了一下阳哥的《Distill is all you need》,和第二名tascj对于训练推理的科技与狠活,有些感觉,伴随着deepseek的大火,蒸馏和强化学习又被端上了台面,我对强化学习暂时没什么兴趣,不过蒸馏跟我最近看的内容相关,在网上搜了一圈关于deepseek针对蒸馏的策略,好像没有过多内容介绍,于是想着总结找到的一些资料。
模型蒸馏即知识蒸馏(Knowledge Distillation),是一种模型压缩和加速技术。在深度学习中,大型深度神经网络虽性能优异,但因计算复杂度高、存储需求大,难以部署在资源受限设备上。模型蒸馏通过构建师生架构,让小的学生模型学习大的教师模型的知识,使学生模型在保持较小规模的同时,尽可能接近教师模型的性能。其核心组件包括知识(如教师模型的 logits、中间层特征等)、蒸馏算法(用于指导知识转移)和师生架构(决定知识传递方式)。
这里可以看比较主流的一张图,出自2021年综述:《Knowledge Distillation: A Survey》,对近年的Distillation做了一个详细概括,Knowledge Distillation的流程可以理解为:
图中除了loss之后会详细说明,唯一的未知点可能在于soft targets,它是经过softmax的下一层级结果logits(原始分数),公式为:
其中是温度系数,从公式中能很明显看出当值较大时,Softmax 输出的概率分布会更加平滑,每个类别的概率值相对更接近;值较小时,概率分布会更尖锐,高概率类别的概率值远高于其他类别。这些 soft targets 会传递给学生模型,学生模型在学习过程中不仅学习真实的hard targets信息,还能从教师模型的 soft targets 中获取类别之间的关联等知识,帮助其更好地训练和泛化。
hard targets 与 soft targets的区别可以从下面的四分类图中很形象的看出:
做知识蒸馏的方式有非常多,从训练方案流程来看,就有离线蒸馏、在线蒸馏和自蒸馏等,从算法更新角度上,还有对抗蒸馏、多教师蒸馏等,这里我就不用豆包在灌水了,想查一大片说明,直接以bert时代的蒸馏开始看。
TinyBERT是一种轻量级的预训练语言模型,由华为和华中科技大学提出。它通过知识蒸馏技术,将BERT模型的知识迁移到一个更小的模型中,从而实现了模型体积的大幅减小和推理速度的提升。在当时,它提出了两阶段transformer蒸馏方案:在大规模语料上首先进行通用MLM任务的蒸馏,在下游任务时,先学好老师模型,再进行蒸馏,具体如下图:
关于Transformer层蒸馏,主要包括注意力attn的蒸馏和隐藏层hidn的蒸馏:
关于损失函数,TinyBert的蒸馏loss为:
第一项:词向量层损失
第二项:中间层损失
第三项:预测层损失
如果有不清晰的,可以去看论文原文,我就不做过多解释了,上述的内容根据论文开源的github地址,其中对于蒸馏训练的截取部分,可进行一一对照:
# 蒸馏配置
distill_config = DistillationConfig(
# 设置温度系数temperature, tiny-bert论文作者使用1表现最好,一般大于1比较好
temperature=self.temperature,
# 设置ground truth loss权重
hard_label_weight=self.hard_label_weight,
# 设置预测层蒸馏loss(即soft label损失)为交叉熵,并稍微放大其权重
kd_loss_type=self.kd_loss_type, kd_loss_weight=self.kd_loss_weight,
# 配置中间层蒸馏映射
intermediate_matches=[
# 配置hidden蒸馏映射、维度映射
{'layer_T':0,'layer_S':0,'feature':'hidden','loss':'hidden_mse','weight':1,
'proj': ['linear',312,768]}, # embedding层输出
{'layer_T':3,'layer_S':1,'feature':'hidden','loss':'hidden_mse','weight':1,
'proj': ['linear',312,768]},
{'layer_T':6,'layer_S':2,'feature':'hidden','loss':'hidden_mse','weight':1,
'proj': ['linear',312,768]},
{'layer_T':9,'layer_S':3,'feature':'hidden','loss':'hidden_mse','weight':1,
'proj': ['linear',312,768]},
{'layer_T':12,'layer_S':4,'feature':'hidden','loss':'hidden_mse','weight':1,
'proj': ['linear',312,768]},
# 配置attention矩阵蒸馏映射,注意layer序号从0开始
{"layer_T":2,"layer_S":0,"feature":"attention","loss":"attention_mse","weight":1},
{"layer_T":5,"layer_S":1,"feature":"attention","loss":"attention_mse","weight":1},
{"layer_T":8,"layer_S":2,"feature":"attention","loss":"attention_mse","weight":1},
{"layer_T":11,"layer_S":3,"feature":"attention","loss":"attention_mse","weight":1},
]
)
# 训练配置
optimizer = AdamW(self.student_model.parameters(), lr=self.lr) # 使用大一点的lr
train_config = TrainingConfig(
output_dir=self.student_model_dir, device=self.student_trainer.device,
data_parallel=self.enable_parallel, ckpt_frequency=self.ckpt_frequency # 一个epoch存ckpt_frequency次模型
)
# 配置model中logits hiddens attentions losses的获取方法
defsimple_adaptor(batch, model_outputs):
return{
'logits': model_outputs[-1]['logits'],'hidden': model_outputs[-1]['hiddens'],
'attention': model_outputs[-1]['attentions'],'losses': model_outputs[1],
}
# 蒸馏
distiller = GeneralDistiller(
train_config=train_config, distill_config=distill_config,
model_T=self.teacher_model, model_S=self.student_model,
adaptor_T=simple_adaptor, adaptor_S=simple_adaptor
)
withdistiller:
logger.info('start to knowledge distill ...')
distiller.train(optimizer, train_dataloader, num_epochs=epoch)
logger.info('distill finish')
KL散度的定义是建立在熵(Entropy)的基础上的。此处以离散随机变量为例,若一个离散随机变量的可能取值为,而对应的概率为,则随机变量的熵定义为:
若有两个随机变量,且其概率分布分别为,则相对的相对摘为:
之所以称之为相对熵,是因为其可以通过两随机变量的交叉嫡(Cross-Entropy)以及信息摘推导得到,针对上述离散变量的概率分布而言,其交叉摘定义为:
因此,KL散度或相对熵可通过下式得出:
在上一节中,TinyBERT在设计其蒸馏过程时采用了多种损失函数,包括词向量层损失、中间层损失和预测层损失,在大模型时代下,词向量损失不用多说,因为已经完全做了解耦,如何进行embedding我想看到这里的都知道,中间层损失的不再使用,或者说中间层蒸馏的使用变少,我理解是大模型通常已经具有足够的参数来学习复杂的特征表示,因此它的必要性相对较低,另外就是中间层叠得太厚,所能获得的收益太低,所以不如针对预测层进行相应的改进,那自然,就不得不提本节在介绍的KL散度。
那为什么作为大模型来讲,更多使用KL散度呢?我觉得可以从以下三点考虑:
上述介绍了KL散度的定义,很明显,KL损失不是一个对称形式,即,那么我们可以试图用近似分布来优化该目标:
根据上一小节的概率公式推导,可以计算出反向 (Reverse KL,RKL)为:
正向 (Forward KL,FKL)为:
其中P是teacher,Q是student,在大模型之前,似乎很多人更喜欢用FKL,正向KL散度(FKL)更受青睐的原因可能与其在传统任务上的表现有关。传统分类任务的输出空间相对较小,模式(即分布的峰值)较少,这意味着分布更倾向于单一峰值而非多峰值分布。在这种情况下,FKL表现良好,因为它倾向于让学生模型关注教师模型输出中概率较高的区域,从而产生更准确的样本。然而,对于大型语言模型(LLM)来说,输出空间更加复杂,模式更多,再使用FKL可能导致学生模型关注教师模型输出中概率较低的区域,从而产生不良样本。
如上图所示,教师模型是蓝色曲线,它的输出是可量化的,这里假设为两个高斯波峰,而黄色,是理想情况下,我们认为学生模型可以近似为正态分布来拟合教师曲线,那么会出现两种结果,一种是尽可能多的包括多峰的面积,第二种是直接拟合最高波峰的分布。所以左边是Forward KL,右边是反向。
中间的一些具体推导过程不过多赘述,近年有非常多的论文对该方案做了benchmark,比如说下图是《f-Divergence Minimization for Sequence-Level Knowledge Distillation》一文的数据:
还有《Rethinking Kullback-Leibler Divergence in Knowledge Distillation for Large Language Models》篇的数据和AKL:
另外说明一下,本节内容就是看了作者在知乎发的《LLM的知识蒸馏(KD)应该用Reverse KL?》一文才有想法撰写本节,对于想复现的小伙伴来讲,可以去看这几篇论文的github,作者还给了一些相应的可视化demo。
TRL(Transformer Reinforcement Learning)库是用于后续训练基础模型的综合库,专为使用监督微调 (SFT)、近端策略优化 (PPO) 和直接偏好优化 (DPO) 等先进技术进行训练后的基础模型而设计。这里我们只看它里面的两种trainer——SFTtrainer和GKDtrainer。
从原理方面来讲:
从损失计算方面来讲:
这两种顺序非常直观,GKDTrainer继承自SFTTrainer,SFTTrainer继承自Trainer。那从SFTtrainer看,它的调用非常简单,trl的readme直接写了一个demo:
fromtrlimportSFTConfig, SFTTrainer
fromdatasetsimportload_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
调用该类后,我又去看了下transformers的trainer,它的损失函数为:
defcompute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
if(self.label_smootherisnotNoneorself.compute_loss_funcisnotNone)and"labels"ininputs:
labels = inputs.pop("labels")
else:
labels =None
ifself.model_accepts_loss_kwargs:
loss_kwargs = {}
ifnum_items_in_batchisnotNone:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# Save past state if it exists
#TODO:this needs to be fixed and made cleaner later.
ifself.args.past_index >=0:
self._past = outputs[self.args.past_index]
iflabelsisnotNone:
unwrapped_model = self.accelerator.unwrap_model(model)
if_is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
# User-defined compute_loss function
ifself.compute_loss_funcisnotNone:
loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
elifmodel_nameinMODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
else:
ifisinstance(outputs, dict)and"loss"notinoutputs:
raiseValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are{','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"]ifisinstance(outputs, dict)elseoutputs[0]
ifself.args.average_tokens_across_devicesandself.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return(loss, outputs)ifreturn_outputselseloss
很显然这部分有非常多的自适应判断,根据我们上一层为SFTtrainer类,并且没有指定loss方法,所以将选用cross-entropy loss作为模型训练参数。
而GKDtrainer类的方式就不一样,由于KL散度是不对称的,在知识蒸馏中使用JSD,Jensen-Shannon Divergence 是基于KL散度改进的更平滑和对称的概率分布度量。论文中给出了其改进的计算公式:
那自然其重写了compute_loss,具体计算为generalized_jsd_loss,代码如下:
defgeneralized_jsd_loss(
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
):
"""
Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
of https://huggingface.co/papers/2306.13649 for the definition.
Args:
student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
beta: Interpolation coefficient between 0 and 1 (default: 0.5)
temperature: Softmax temperature (default: 1.0)
reduction: Specifies the reduction to apply to the output (default: 'batchmean')
Returns:
loss: Scalar tensor with the generalized JSD loss
"""
# Apply temperature scaling
student_logits = student_logits / temperature
teacher_logits = teacher_logits / temperature
# Compute log probabilities for student and probabilities for teacher
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
# Compute the log of the mixture distribution
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
mixture_log_probs = torch.logsumexp(
torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1- beta)]),
dim=0,
)
# Compute KL divergences using F.kl_div
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
# Compute the Generalized Jensen-Shannon Divergence
jsd = beta * kl_teacher + (1- beta) * kl_student
# Masking
iflabelsisnotNone:
mask = labels !=-100
jsd = jsd[mask]
# Apply reduction
ifreduction =="batchmean":
returnjsd.sum() / mask.sum()iflabelsisnotNoneelsejsd.sum() / (jsd.size(0) * jsd.size(1))
elifreduction =="sum":
returnjsd.sum()
elifreduction =="mean":
returnjsd.mean()
else:
returnjsd
defcompute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
# compute student output
outputs_student = model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# compute teacher output in eval mode
self.teacher_model.eval()
withtorch.no_grad():
outputs_teacher = self.teacher_model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
)
# slice the logits for the generated tokens using the inputs["prompts"] lengths
prompt_lengths = inputs["prompts"].shape[1]
shifted_student_logits = outputs_student.logits[:, prompt_lengths -1:-1, :]
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths -1:-1, :]
shifted_labels = inputs["labels"][:, prompt_lengths:]
# compute loss
loss = self.generalized_jsd_loss(
student_logits=shifted_student_logits,
teacher_logits=shifted_teacher_logits,
labels=shifted_labels,
beta=self.beta,
)
# empty cache
empty_cache()
# Return loss
return(loss, outputs_student)ifreturn_outputselseloss
对于该类好不好用,我也不知道,暂时没用过,只能说从理论来分析,JSD损失和KL损失的区别,不过与SFTtrainer类似,调用方式也很简单,可以跑几次看看情况:
fromdatasetsimportload_dataset
importrandom
fromtransformersimportAutoTokenizer
fromtrlimport(
GKDConfig,
GKDTrainer,
LogCompletionsCallback,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
################
# Training
################
trainer = GKDTrainer(
model=model_config.model_name_or_path,
teacher_model=training_args.teacher_model_name_or_path,
args=training_args,
train_dataset=dataset[args.dataset_train_split],
eval_dataset=test_data,
processing_class=tokenizer,
peft_config=get_peft_config(model_config),
)
completions_callback = LogCompletionsCallback(trainer, trainer.generation_config, num_prompts=8)
trainer.add_callback(completions_callback)
trainer.train()
# Save
trainer.save_model(training_args.output_dir)
本节是对阳哥夺冠方案中关于蒸馏部分的经典总结,在这里做一个旁征博引,因为没有算力,具体我也没复现过,不过算是除了写这篇推文的初衷,本来是想做一个top方案亮点汇总,只是因为deepseek的爆火针对其中一个方向做了延展。那话不多说,github原址为:https://github.com/shyoulala/LMSYS_BlackPearl
该仓库的目录结构为:
./model_path # 预训练模型的路径,存放预训练模型的权重和配置文件
./src_fast # 快速训练脚本的存放位置,可能包含简化的训练代码
./src # 完整解决方案的代码目录,包含整个项目的完整训练和处理流程
./data # 数据目录,存放训练数据和其他相关数据
./data/oof # Out-of-Fold 数据目录,可能用于交叉验证的中间结果
./data/processed_data # 处理后的数据目录,存放经过预处理的数据
./data/processed_data/orgemma2fold4 # 训练集,包含用于直接蒸馏的 70b 概率数据(第4折)
./data/processed_data/orgemma2fold2 # 同上,第2折
./data/processed_data/orgemma2fold0 # 同上,第0折
./data/processed_data/orgemma2fold1 # 同上,第1折
./data/processed_data/orgemma2fold3 # 同上,第3折
./data/lmsys-chatbot-arena # 可能存放与 LMSYS Chatbot Arena 相关的数据或资源
./sub # 输出目录,用于存放训练结果、预测结果等
./model_save # 训练模型的保存路径,存放训练完成后的模型文件
./model_save_or # 另一个模型保存路径,可能是用于存放原始模型或特定版本的模型
./model_save_or/v7_ut_gemma_v7_64r128_ddgemma2_16bit # 经过后处理(如蒸馏)的模型版本,可能是 Gemma2-9B 的 16bit 版本
挺难想象的,大模型时代竟然还能做交叉验证,不过lmsys是个三分类任务,依照之前逻辑也没什么问题,该方案主要是用llama3-70B和Qwen2-72B-instruct对gamma2-9B做蒸馏,所有大致流程,都通过run_pipeline.sh有显现:
#!/bin/bash
set -e
qwen_path=../model_path/qwen2_72b
llama_path=../model_path/llama3_70b
gemma_path=../model_path/Gemma2_9b
qwen_path_ut=../model_save/qwen2_4bit_pretrain/epoch_0_model/adapter.bin
llama_path_ut=../model_save/llama3_4bit_pretrain/epoch_0_model/adapter.bin
gemma_path_ut=../model_save/gemma2_4bit_pretrain/epoch_0_model/adapter.bin
fold=$1
echo run
{fold}
# train llama3 70b
sh run_fintune.sh llama3 ${llama_path} ${llama_path_ut} ${fold}
# predict train logits
python predict_train.py ${llama_path} ../model_save/llama3_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/llama3fold${fold}/train.parquet ../data/oof/llama3fold${fold}_train.parquet
# train qwen2 70b
sh run_fintune.sh qwen2 ${qwen_path} ${qwen_path_ut} ${fold}
# predict train logits
python predict_train.py ${qwen_path} ../model_save/qwen2_4bit_load_fintune/epoch_0_model/adapter.bin ../data/processed_data/qwen2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet
# merge logits
python merge_logits.py ../data/processed_data/gemma2fold${fold}/train.parquet ../data/oof/qwen2fold${fold}_train.parquet ../data/oof/llama3fold${fold}_train.parquet ../data/processed_data/gemma2fold${fold}/train_logits.parquet
# distill fintune gemma2-9b
sh run_fintune_16bit_distill.sh gemma2 ${gemma_path} ${gemma_path_ut} ${fold}
中间几步有挺多有趣的操作,比如是如何做post train的,以及最后merge logits,这里仅谈蒸馏之前的merge lora,因为代码足够简单:
importtime
fromdataclassesimportdataclass
importpickle
importtorch
importsklearn
importnumpyasnp
importpandasaspd
fromtqdm.autoimporttqdm
fromtransformersimportGemma2ForSequenceClassification, GemmaTokenizerFast, BitsAndBytesConfig
fromtransformers.data.data_collatorimportpad_without_fast_tokenizer_warning
frompeftimportget_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
lora_dir ='../model_save/gemma2fold0_16bit_load_fintune/best_val_loss_model/adapter.bin'
d1 = torch.load(lora_dir)
lora_dir ='../model_save/gemma2fold1_16bit_load_fintune/best_val_loss_model/adapter.bin'
d2 = torch.load(lora_dir)
lora_dir ='../model_save/gemma2fold2_16bit_load_fintune/best_val_loss_model/adapter.bin'
d3 = torch.load(lora_dir)
lora_dir ='../model_save/gemma2fold3_16bit_load_fintune/best_val_loss_model/adapter.bin'
d4 = torch.load(lora_dir)
lora_dir ='../model_save/gemma2fold4_16bit_load_fintune/best_val_loss_model/adapter.bin'
d5 = torch.load(lora_dir)
d = {}
fork, vind1.items():
v = d1[k] + d2[k] + d3[k] + d4[k] + d5[k]
v = v /5.
d[k] = v
torch.save(d,"../model_save/final_adapter.bin")
代码上可见,就是对经过5次交叉验证的gamma模型权重做了加权平均合并,但我看discussion很多人提到了,它们同样想到了该方案,不过效果并不好,似乎是这些权重还需要做方差评估,如果方差过大反而会拖累加权后的结果,感兴趣有卡有算力的能进行尝试,我就不过多提了。
回到正题,最终是先得到了llama3和Qwen的模型输出,那么蒸馏即是需要考虑这两者的结果,所以蒸馏损失选择了:
loss_fun = nn.CrossEntropyLoss()
divergence_loss_fn = nn.KLDivLoss(reduction='batchmean')
cos_loss_fn = nn.CosineEmbeddingLoss()
outputs = model(batch['input_ids'], use_cache=False)# predict gemma2
logits = outputs.logits
grads = batch['grads']
grads1 = batch['grads'][:, :3]# qwen2
grads2 = batch['grads'][:,3:]# llama3
labels = batch['labels']
loss_ce = loss_fun(logits, labels)
loss_grad1 = divergence_loss_fn(
F.log_softmax(logits / T, dim=1),
F.softmax(grads1 / T, dim=1)
)
cos_loss1 = cos_loss_fn(F.softmax(grads1 / T, dim=1), F.softmax(logits / T, dim=1),
torch.ones(logits.size()[0]).to(logits.device))
loss_grad2 = divergence_loss_fn(
F.log_softmax(logits / T, dim=1),
F.softmax(grads2 / T, dim=1)
)
cos_loss2 = cos_loss_fn(F.softmax(grads2 / T, dim=1), F.softmax(logits / T, dim=1),
torch.ones(logits.size()[0]).to(logits.device))
loss = (loss_ce + loss_grad1 + cos_loss1 + loss_grad2 + cos_loss2) /5.
用数学公式理解,即为交叉熵和KL散度的混合:
这里刚开始我不是很理解,然后问了下deepseek懂了:
为什么同时使用交叉熵损失和 KL 散度损失?
1. 保持监督学习能力
交叉熵损失确保学生模型能够正确预测真实标签,从而保持模型的监督学习能力。如果没有交叉熵损失,学生模型可能会过度依赖教师模型的输出,而忽视真实标签的指导,导致模型在真实数据上的性能下降。
2. 学习教师模型的软目标
KL 散度损失让学生模型学习教师模型的软目标,从而捕捉到教师模型的内部表示和知识。软目标通常包含更多的信息,可以帮助学生模型更好地理解数据的分布和特征。
3. 平衡硬标签和软目标
同时使用交叉熵损失和 KL 散度损失可以平衡硬标签和软目标的贡献。硬标签(真实标签)提供了直接的监督信号,而软目标(教师模型的输出)提供了更多的上下文信息。通过调整两者的权重,可以更好地指导学生模型的学习。
其实我认为以上主要的,是因为教师模型是两个,而不是一个,KL更适合于一个,而两个加入交叉熵我的理解为桥接,更能体现泛化,但具体为啥这样安排,只有跑了才知道,所以根据github的环境说明,有8张A100以上的,可以跑一轮,等待3天以上,观看结果了。
该repo是DeepSeek-R1的开放复现版本,由huggingface的CEO亲自提出并进行,我大致看了一下,它的规划是:
这里重点看step 1,即它使用distilabel来对Deepseek-R1提取蒸馏数据,以下是一个简单demo:
fromdatasetsimportload_dataset
fromdistilabel.modelsimportvLLM
fromdistilabel.pipelineimportPipeline
fromdistilabel.steps.tasksimportTextGeneration
prompt_template ="""\
You will be given a problem. Please reason step by step, and put your final answer within \boxed{}:
{{ instruction }}"""
dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train").select(range(10))
model_id ="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"# Exchange with another smol distilled r1
withPipeline(
name="distill-qwen-7b-r1",
description="A pipeline to generate data from a distilled r1 model",
)aspipeline:
llm = vLLM(
model=model_id,
tokenizer=model_id,
extra_kwargs={
"tensor_parallel_size":1,
"max_model_len":8192,
},
generation_kwargs={
"temperature":0.6,
"max_new_tokens":8192,
},
)
prompt_column ="problem"
text_generation = TextGeneration(
llm=llm,
template=prompt_template,
num_generations=4,
input_mappings={"instruction": prompt_column}ifprompt_columnisnotNoneelse{}
)
if__name__ =="__main__":
distiset = pipeline.run(dataset=dataset)
distiset.push_to_hub(repo_id="username/numina-deepseek-r1-qwen-7b")
然后将该数据加入了sft中:
defmain(script_args, training_args, model_args):
################
# Model init kwargs & Tokenizer
################
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=model_args.torch_dtype,
use_cache=Falseiftraining_args.gradient_checkpointingelseTrue,
device_map=get_kbit_device_map()ifquantization_configisnotNoneelseNone,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
################
# Training
################
trainer = SFTTrainer(
model=model_args.model_name_or_path,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split]iftraining_args.eval_strategy !="no"elseNone,
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
)
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
iftraining_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
if__name__ =="__main__":
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)
从代码上可以看到,这个过程是从教师模型中提取知识,并将其传递给学生模型。在这个特定的情况下,知识不是以软标签的形式直接传递,而是通过生成的推理数据来传递。这种方法通常被称为数据蒸馏(Data Distillation)或示例蒸馏(Example Distillation),它是知识蒸馏的一种变体。
我看到了腾讯科技发布的一场关于DeepSeek的高质量闭门会:比技术更重要的是愿景,里面的很多内容可以作为结尾:
| 欢迎光临 链载Ai (https://www.lianzai.com/) | Powered by Discuz! X3.5 |