Paper link: https://arxiv.org/pdf/2205.10536.pdf

Official Implementation: https://github.com/hunto/DIST_KD

 

Introduction

Existing challenges:

  • Heavy deep learning models are difficult to deploy due to computational and memory limitations.
  • Knowledge distillation (KD) has been used to boost student model performance, but vanilla KD struggles when the discrepancy between teacher and student grows large, e.g., when the teacher is much larger or trained with stronger strategies.
  • There is a lack of thorough analysis of how the training strategies used to obtain stronger teachers affect KD.

Contributions of this paper:

  • The paper systematically studies prevalent strategies for designing and training deep neural networks to understand stronger teachers and their effect on KD.
  • The Pearson correlation coefficient is proposed as a new matching criterion to replace the exact KL-divergence match, addressing the difficulties vanilla KD has with stronger teachers.
  • The paper introduces distilling intra-class relations for further performance improvement.
  • The proposed method (DIST) is simple, efficient, and practical, and outperforms existing KD methods in various tasks, such as image classification, object detection, and semantic segmentation.

The DIST (Distillation from A Stronger Teacher) method relaxes the exact match in vanilla KD by maximizing linear correlation to preserve the inter-class and intra-class relations of teacher and student predictions. Specifically, it uses the Pearson correlation coefficient to measure the inter-class relation (between multiple classes for each instance) and intra-class relation (between multiple instances for each class).
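
Concretely, for a pair of prediction vectors $\mathbf{u}$ (student) and $\mathbf{v}$ (teacher), the paper measures their mismatch with the Pearson distance, i.e., one minus the Pearson correlation coefficient, so maximizing the linear correlation is equivalent to minimizing

$$d_p(\mathbf{u}, \mathbf{v}) = 1 - \rho_p(\mathbf{u}, \mathbf{v}) = 1 - \frac{\operatorname{Cov}(\mathbf{u}, \mathbf{v})}{\operatorname{Std}(\mathbf{u})\,\operatorname{Std}(\mathbf{v})}.$$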

Here is a summary of the DIST method:

  1. For the inter-class relation, maximize the correlation row-wise between each corresponding pair of rows of the prediction matrices Y(s) and Y(t) (one row per instance).
  2. For the intra-class relation, maximize the correlation column-wise between each corresponding pair of columns of Y(s) and Y(t) (one column per class).
  3. The overall training loss Ltr is a weighted sum of the classification loss, the inter-class KD loss, and the intra-class KD loss; the formulation is written out below.
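
Written out with the paper's notation, where $Y^{(s)}$ and $Y^{(t)}$ are the $B \times C$ student and teacher prediction matrices ($B$ instances, $C$ classes) and $d_p$ is the Pearson distance above:

$$\mathcal{L}_{\text{inter}} = \frac{1}{B}\sum_{i=1}^{B} d_p\big(Y^{(s)}_{i,:},\, Y^{(t)}_{i,:}\big), \qquad \mathcal{L}_{\text{intra}} = \frac{1}{C}\sum_{j=1}^{C} d_p\big(Y^{(s)}_{:,j},\, Y^{(t)}_{:,j}\big)$$

$$\mathcal{L}_{\text{tr}} = \alpha\,\mathcal{L}_{\text{cls}} + \beta\,\mathcal{L}_{\text{inter}} + \gamma\,\mathcal{L}_{\text{intra}}$$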

 

 

Code Implementation

import torch
import torch.nn as nn

def pearson_distance(y_s, y_t, eps=1e-8):
    # Returns 1 - Pearson correlation, computed row-wise between y_s and y_t.
    y_s_mean = torch.mean(y_s, dim=1, keepdim=True)
    y_t_mean = torch.mean(y_t, dim=1, keepdim=True)
    y_s_std = torch.std(y_s, dim=1, unbiased=False)  # biased std to match the covariance below
    y_t_std = torch.std(y_t, dim=1, unbiased=False)
    cov = torch.mean((y_s - y_s_mean) * (y_t - y_t_mean), dim=1)
    corr = cov / (y_s_std * y_t_std + eps)  # eps guards against zero variance
    return 1 - corr

def dist_loss(y_s, y_t, target, alpha, beta, gamma):
    # y_s, y_t: student/teacher logits of shape (batch, num_classes)
    # target: ground-truth class indices of shape (batch,)
    p_s = y_s.softmax(dim=1)  # student prediction matrix Y(s)
    p_t = y_t.softmax(dim=1)  # teacher prediction matrix Y(t)
    Lcls = nn.CrossEntropyLoss()(y_s, target)  # classification loss on ground-truth labels
    Linter = pearson_distance(p_s, p_t).mean()  # row-wise: inter-class relation
    Lintra = pearson_distance(p_s.t(), p_t.t()).mean()  # column-wise: intra-class relation
    Ltr = alpha * Lcls + beta * Linter + gamma * Lintra
    return Ltr

# Example usage:
# y_s: student logits, y_t: teacher logits, target: ground-truth labels
# alpha, beta, gamma: factors for balancing the three losses
# loss = dist_loss(y_s, y_t, target, alpha, beta, gamma)
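
A minimal smoke test of the sketch above, using random logits and labels in place of real model outputs (the shapes and the alpha/beta/gamma values here are placeholder assumptions, not settings from the paper):

# Hypothetical toy inputs: a batch of 8 samples with 10 classes.
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)  # teacher outputs carry no gradient
labels = torch.randint(0, 10, (8,))

loss = dist_loss(student_logits, teacher_logits, labels,
                 alpha=1.0, beta=1.0, gamma=1.0)
loss.backward()  # gradients flow only into the student logits
print(loss.item())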

---

A note from applying DIST knowledge distillation to a custom dataset that exhibits class imbalance, multi-task classification, and domain variability:

In this setting, DIST showed significant performance degradation compared to Hinton's vanilla KD.

 

 

 
