超好用!——知识蒸馏中即插即用的对抗性调度器以及调整向量Vector

一.前言

本设计思路来源于论文《Dynamic Data-Free Knowledge Distillation by Easy-to-Hard Learning Strategy》。

1.1 原理

总体架构图如下。在常规的知识蒸馏中,一般不会考虑知识的难度先后,按照我们人类的思维,肯定是先学习容易的再学习难一点的知识(总不能小学就学高数吧哈哈)。一个模型的理想状态也应该如此。

在本论文的设计图中,可以看到Generator负责生成伪数据提供给教师模型和学生模型,但是这个Generator是受到一个随时间变化的Adversarial Scheduler调节的,这个Adversarial Scheduler的作用就是让Generator随时间生成从易到难的知识(也就是图像)提供给学生模型和教师模型。

有着一个contribution还不够,作者又在最后加上了一个Reweighting
Vector,这个向量能够修改不同样本的影响力,开始时强调更简单的样本,并逐渐包含更难的样本。具体来说,对于那些模型预测相对准确(即易于模型学习)的样本,会分配较低的权重;而对于模型预测不准确(即难以模型学习)的样本,会分配较高的权重。这样,模型初期会更多地关注易于学习的样本,随着训练的进行,逐渐增加对难样本的关注,从而实现由易到难的学习策略。

超好用!——知识蒸馏中即插即用的对抗性调度器以及调整向量Vector_第1张图片
通过这两个模块,可以非常简单高效地让学生模型从易到难地学习。

1.2 优势

那这样做的优势除了人之常理还有没有更有说服力的解释呢?

1.2.1 Adversarial Scheduler

通过设置Adversarial Scheduler,模型能够在早期阶段避免过度的对抗性扰动,从而保证基础知识的有效学习,而在后期阶段逐渐增加对抗性挑战,促进模型在更复杂的情况下的性能提升。这种策略有效地平衡了学习效率和模型鲁棒性之间的关系,为模型提供了一个更平滑的学习曲线,并最终实现了更好的泛化能力。

1.2.2 Reweighting Vector

  • 动态调整学习难度:Reweighting Vector在学生模型的早期阶段,倾向于学习更容易的样本,并且可以避免学生模型陷入局部最小值。随着学生模型能力的提高,Reweighting Vector逐渐增加难度,引入更接近决策边界的样本,以促进学生模型的泛化能力。

  • 防止灾难性遗忘:在知识蒸馏过程中,学生模型可能会遗忘早期学到的知识。Reweighting Vector通过在训练过程中逐渐引入新的、更难的样本,有助于学生模型在保持旧知识的同时学习。

二 代码讲解

源代码链接click here

2.1 Adversarial Scheduler部分

if args.method == 'cudfkd':
     if epoch  > int(args.epochs * args.begin_fraction) and epoch < int(args.epochs * args.end_fraction) and args.curr_option != 'none': 
         synthesizer.adv += args.grad_adv

这个部分是在每个epoch结束时加入的,其目的在于动态调整synthesizer(生成器的关键部件)的adv值。

  # Negative Divergence.
  if self.adv > 0:
      s_out = self.student(samples, l=l)
      if self.adv_type == 'js':
          l_js = jsdiv(s_out, t_out, T=3)
          loss_adv = 1.0-torch.clamp(l_js, 0.0, 1.0)
      if self.adv_type == 'kl':
          mask = (s_out.max(1)[1]==t_out.max(1)[1]).float()
          loss_adv = -(kldiv(s_out, t_out, reduction='none', T=3).sum(1) * mask).mean()
  else:
      loss_adv = torch.zeros(1).to(self.device)
            
  loss = self.lmda_ent * ent + self.adv * loss_adv+ self.oh * loss_oh + self.act * loss_act + self.bn * loss_bn

这是synthesizer类中涉及到adv计算部分的代码,可以看到adv的值很大程度上决定了最终反向传播的loss的值,所以通过调整adv的值来控制loss进而调整Generator的参数θ是十分有效的。

2.1 Reweighting Vector部分

if args.method == 'cudfkd':
    if args.dataset == 'cifar10':
         alpha = 0.0001
    else:
         alpha = 0.00002
    lamda = datafree.datasets.utils.lambda_scheduler(args.lambda_0, global_iter, alpha=alpha)
with torch.no_grad():
    g,v = datafree.datasets.utils.curr_v(l=real_loss_s, lamda=lamda, spl_type=args.curr_option.split('_')[1])

这是λ-scheduler来调控lamda的值,进而将lamda的值传入curr_v函数中,从而得到我们想要的重权向量。

具体的lambda_scheduler和curr_v函数如下:

def lambda_scheduler(lambda_0, iter, alpha=0.0001, iter_0=500000000):
    if iter < iter_0:
        lamda = lambda_0 + alpha * iter
    else:
        lamda = lambda_0 + alpha * iter_0
    return lamda

lambda_scheduler很简单地对lamda进行调控,iter是每个epoch中正在蒸馏的轮次。(每个epoch可能要蒸馏上百个iter)

def curr_v(l, lamda, spl_type='hard'):
    if spl_type == 'hard':
        v = (l < lamda).float()
        g = -lamda * (v.sum())
    elif spl_type == 'soft':
        v = (l < lamda).float()
        v *= (1 - l / lamda)
        g = 0.5 * lamda * (v * v - 2 * v).sum()
    elif spl_type == 'log':
        v = (1 + math.exp(-lamda)) / (1 + (l - lamda).exp())
        mu = 1 + math.exp(-lamda) - v
        g = (mu * mu.log() + v * (v+1e-8).log() - lamda * v)
        # print(g, v.min(), v)

    else:
        raise NotImplementedError('Not implemented of spl type {}'.format(spl_type))

    return g, v

参数中l是本次iter中的real_loss,lamda是上问的lamda,type是按照自己的需要进行选择的。(该算法的数学原理不再解释)

三.小结

为什么要讲这两个插件模块?因为他们真的几乎能用到所有可用的知识蒸馏模型之中,并且实现起来简洁高效。有需要的伙伴快来食用!

码字不易,还请点赞收藏支持,谢谢!!!

你可能感兴趣的:(机器学习,人工智能,深度学习,python)