当p(x)较大时,qθ(x) 也需要比较大且比p(x)相对更大,否则公式右边很大的情况下,FKL整体就无法达到最小;
当p(x)较小时,p(x) 在 log 外趋于0占主导,FKL整体总是能比较小,跟qθ(x)关系较小。所以在优化的时候,qθ(x) 会覆盖p(x)的所有mode,即便此时有可能导致高估 p(x) 很小的部分,对应上述图中的橙色部分。
当qθ(x) 较大时,为了在优化时候降低RKL,p(x) 必须较大,因此 p(x) 概率最大的 mode 也要对应 qθ(x) 概率最大的地方,p(x) 概率很小的地方必须对应 qθ(x) 概率为0的地方,也就是说 qθ(x) 拟合了 p(x) 概率最大的部分。对应前述图中的绿色部分。
当 qθ(x) 等于0时,p(x) 取什么样的值都不影响优化。
由此可以看下MiniLLM具体的方法图:
论文中的另一个视角个人觉得特别好,就是将RKL和逆强化学习进行对比,并给出了数学说明,可以看一下
公式说明:这里的公式序号均来自论文本身,目的是结合论文一起看可能更好,不破坏原有公式顺序。
既然可以这么类比,RKL约等于逆强化学习,FKL等价于模仿学习,而在实际应用和理论说明中,逆强化学习的效果都会比模仿学习更优,虽然更加难以训练,但其泛化性能,理论上限肯定会更高,所以结论是MiniLLM的RKL理论上是更优的。模仿学习和逆强化学习这个说明可以查看:https://www.zhihu.com/question/470949607/answer/2450111740?utm_id=0
上面两个部分,其实都在说明,MiniLLM理论证明上是更优的蒸馏方法。所以我们可以去进行大胆尝试。实际的训练过程,类似于RLHF的训练方式,教师模型在训练中只推理,作为奖励信号去训练模型。作者也提供了类似ranking loss的更简单的平替方式去优化,对比传统Bert时代的蒸馏方法都会有提升。感谢作者!
# https://github.com/microsoft/LMOps/blob/main/minillm/finetune.py#L166
# 这里是实际蒸馏loss的计算
def get_distil_loss(args, tokenizer, model, teacher_model, model_batch, no_model_batch, logits):
with torch.no_grad():
teacher_model.eval()
teacher_outputs = teacher_model(**model_batch, use_cache=False) # 教师模型推理
teacher_logits = teacher_outputs.logits # 获取教师分布logits
if args.model_parallel:
distil_losses = mpu.parallel_soft_cross_entropy_loss(logits.float(), teacher_logits.float())
distil_losses = distil_losses.view(-1)
loss_mask = no_model_batch["loss_mask"].view(-1)
distil_loss = (distil_losses * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32) #
inf_mask = torch.isinf(logits)
#log_softmax实际上是在教师和学生的交叉熵;交叉熵损失在形式上等价于KL散度减去一个常数项(分布P 的熵)在最小化KL散度时可以忽略
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
mask = (no_model_batch["label"] != -100).int()
distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
return distil_loss
一般来说,实际训练中还会加上sft数据的loss,确保不跑偏,类似rlhf中的reference model的作用:
output = self.model(inputs, attention_mask=attention_mask, return_output=True)
sft_loss = self.loss_fn(output.logits, labels)
需要注意教师模型和学生模型需要使用同源的模型。即相同的tokenizer,对于国产模型来说,qwen、deepseek、yi等都有相同tokenizer不同尺寸的模型可供选择。
整体说明:
人类认知系统中的两种推理系统,系统1和系统2,系统1被认为是无意识的,能够快速识别和迅速判断,也叫做快思考,系统2被认为是处理复杂问题如数学和逻辑问题,需要深思熟虑,也叫做慢思考。
在大模型中,可以将中间的流程如多次调用大模型、中间的思考tokens类比为深思熟虑的过程,这些方法如cot、RaR等等带来更好的推理效果,但与此同时,耗时问题会导致这些方法很难用于生产落地。于是很多方法都在尝试将系统2的效果蒸馏到系统1当中(毕竟自2023年 gpt4出来后,应该有非常多的黑盒蒸馏gpt4数据训练到各家系统中;还有很久之前llama2的ghost attention:在每一句中都加入system prompt让 Llama 2 有效地遵循多轮指令,都是一些蒸馏的有效形式)。
这篇论文的主要与之前差异点在于,显式的提出System2的推理能力蒸馏到System1中,并做了很多实验进行验证。可以理解为论文提供了非常好的一种数据合成的范式,通过使用这些数据进行指令微调等方法,提升System1的推理能力。
以下几个公式是对System1和System2的形式化说明:
也就是说,通过上述公式3可以得到的大量训练数据,但是实际会存在质量问题。论文主要通过一致性标准进行过滤。
输出一致性:输入不变,对输出进行N次采样,通过投票实现,少数服从多数
输入扰动下的一致性:输出不变,对输入增加扰动,比如选择题改顺序但答案没变化,不一致则过滤
但猜测实际可能有更多更精细化的方式实现。
然后就是四种方式在不同数据集上的效果,我觉得给出Prompt可能是最好的方法体现形式
Prompt:
"{question}"\nRephrase and expand the question, and respond.
让模型先改写,改写可能提供更丰富的文本信息,然后再回答,能让大模型用自己的知识体现理解问题,回答问题。
让大模型过滤无效信息,去除有偏信息和不相干上下文,然后再改写基础上进行回答
论文通过这四种System2的方式,蒸馏到System1当中,做了很多实验,结果就不一一贴了,都是差的也不会发paper,总结下来,整体是有效的,如RaR蒸馏可用于澄清任务指令相关任务、S2A能有效提升有偏任务,Branch-Solve-Merge蒸馏能作为LLM-Judge评估任务,但是在复杂推理任务上的蒸馏,目前还做得不好。这可能也是一个共识,需要持续研究。
总结,不管是黑盒蒸馏,还是白盒蒸馏,都是现如今非常好的将更大模型的知识注入到较小模型中去的方式,不断提升小模型的知识密度,这样可以再更多的落地场景中应用。期待这个方向后续更多的工作。
| 欢迎光临 链载Ai (https://www.lianzai.com/) | Powered by Discuz! X3.5 |