返回顶部
热门问答 更多热门问答
技术文章 更多技术文章

从零开始学大模型,什么,GAN也能用于知识蒸馏?知识蒸馏算法之Adversarial distillation!!

[复制链接]
链载Ai 显示全部楼层 发表于 昨天 09:30 |阅读模式 打印 上一主题 下一主题

ingFang SC", Cambria, Cochin, Georgia, Times, "Times New Roman", serif;display: table;border-bottom: 2px solid rgb(15, 76, 129);color: rgb(63, 63, 63);visibility: visible;">引言

Adversarial distillation,对抗性知识蒸馏,结合了对抗学习的理念和传统的知识蒸馏方法,以促进学生模型(简化模型)更好地模仿教师模型(复杂模型)的行为和知识。这种方法的核心是通过对抗的方式,提高学生模型对数据分布和教师模型特征的学习能力。

ingFang SC", Cambria, Cochin, Georgia, Times, "Times New Roman", serif;border-left: 3px solid rgb(15, 76, 129);color: rgb(63, 63, 63);">基本原理

对抗性知识蒸馏通常包含以下几个步骤:


  1. 教师模型和学生模型的建立:首先,需要一个已经训练好的教师模型和一个结构简化的学生模型。

  2. 生成器和鉴别器的使用:

  • 生成器:在一些方法中,生成器用于生成逼真的数据样本,这些样本用来训练学生模型,使其输出更加接近教师模型。

  • 鉴别器:用来判断输出或特征来自教师模型还是学生模型,通过优化鉴别器,间接地推动学生模型更好地模仿教师模型的行为。

  • 对抗性优化:通过迭代优化生成器和鉴别器,不断调整学生模型的参数,使得学生模型,在鉴别器难以区分其与教师模型之间的差异时,取得最佳性能。


  • 对抗性知识蒸馏,通常有三种形式,如下图所示,
    a)基于生成器的对抗性知识蒸馏,在这种方法中,生成器(教师模型也可以用来充当鉴别器,不需要有一个独立的鉴别器)不仅仅是生成数据样本,而是专门生成训练数据或特征,更好地模拟教师模型的输出。生成器试图生成逼真的训练数据,学生模型则尝试根据这些数据进行学习,目标是使学生模型的输出尽可能接近教师模型的输出。


    b)基于鉴别器的对抗性知识蒸馏,鉴别器用来区分学生模型和教师模型的输出或特征。通常,鉴别器的任务是,判断给定的输出或特征是否来自教师模型,在这类方法中,学生模型作为生成器来参与训练。学生模型的训练目标是欺骗鉴别器,使其不能正确区分两者的差异,从而逼近教师模型的性能。
    c)基于联合优化的在线对抗性知识蒸馏,教师模型和学生模型是同时训练的,这种方法也被称为在线蒸馏。使用一个或多个鉴别器,来评估和对比教师和学生模型的表现,通过联合优化过程,学生和教师模型不断调整自身参数,以最小化鉴别器的判别能力,最终目标是使鉴别器难以区分学生和教师的输出。这种方法特别适合于实时系统和需要快速适应新数据的场景。


    ingFang SC", Cambria, Cochin, Georgia, Times, "Times New Roman", serif;border-left: 3px solid rgb(15, 76, 129);color: rgb(63, 63, 63);">Pytorch实现demo

    假设我们已经有了一个预训练好的教师模型和一个未训练的学生模型。

    import torchimport torch.nn as nn
    # 定义教师模型和学生模型class TeacherModel(nn.Module):def __init__(self):super(TeacherModel, self).__init__()self.conv = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)self.relu = nn.ReLU()self.fc = nn.Linear(16*14*14, 10)
    def forward(self, x):x = self.relu(self.conv(x))x = x.view(x.size(0), -1)return self.fc(x)
    class StudentModel(nn.Module):def __init__(self):super(StudentModel, self).__init__()self.conv = nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)self.relu = nn.ReLU()self.fc = nn.Linear(8*14*14, 10)
    def forward(self, x):x = self.relu(self.conv(x))x = x.view(x.size(0), -1)return self.fc(x)
    teacher = TeacherModel()student=StudentModel()


    定义鉴别器

    class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.fc = nn.Linear(10, 1)
    def forward(self, x):return torch.sigmoid(self.fc(x))

    ingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;font-size: 15px;letter-spacing: 0.1em;color: rgb(63, 63, 63);">训练过程中,我们需要同时优化学生模型和鉴别器

    # 损失函数和优化器criterion = nn.BCELoss()optimizer_student = torch.optim.Adam(student.parameters(), lr=0.001)optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=0.001)
    for epoch in range(num_epochs):for data in dataloader:inputs, _ = data# 教师和学生模型的预测teacher_outputs = teacher(inputs)student_outputs = student(inputs)# 真实标签和假标签real_labels = torch.ones(inputs.size(0), 1)fake_labels = torch.zeros(inputs.size(0), 1)# 训练鉴别器discriminator_real = discriminator(teacher_outputs.detach())discriminator_fake = discriminator(student_outputs.detach())real_loss = criterion(discriminator_real, real_labels)fake_loss = criterion(discriminator_fake, fake_labels)discriminator_loss = (real_loss + fake_loss) / 2optimizer_discriminator.zero_grad()discriminator_loss.backward()optimizer_discriminator.step()
    # 训练学生模型outputs = discriminator(student_outputs)student_loss = criterion(outputs, real_labels)optimizer_student.zero_grad()student_loss.backward()optimizer_student.step()


回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

链载AI是专业的生成式人工智能教程平台。提供Stable Diffusion、Midjourney AI绘画教程,Suno AI音乐生成指南,以及Runway、Pika等AI视频制作与动画生成实战案例。从提示词编写到参数调整,手把手助您从入门到精通。
  • 官方手机版

  • 微信公众号

  • 商务合作

  • Powered by Discuz! X3.5 | Copyright © 2025-2025. | 链载Ai
  • 桂ICP备2024021734号 | 营业执照 | |广西笔趣文化传媒有限公司|| QQ