stMind

about Tech, Computer vision and Machine learning

Chapter 3.2.1 of Efficient Deep Learning: Distillation

Distillationは、事前学習済みの大規模ネットワーク(教師モデル)を用いて、小規模ネットワーク(生徒モデル)を学習する仕組み。教師モデルと、生徒モデルは同じ入力を受け取り、教師モデルの方は生徒モデルのためのSoft labelを生成する。Soft labelは、正解もしくは不正解のクラスの情報を含むHard labelと異なり、例えばトラックは猫よりも車クラスの方が近いといったようなクラス間の関係の知識を含むと考えられ、これはHard labelによる学習だけでは得られない知識と言えます。

生徒モデルは、Hard labelに対するCross entropy lossと、Soft labelに対するDistillation loss(こちらもCross entropy loss)を用いて学習していきます。このプロセスにおいて、通常は教師モデルのパラメータは更新せず、生徒モデルのパラメータのみ更新を行います。

f:id:satojkovic:20210718172250p:plain

Distilling the Knowledge in a Neural Networkでは、音声認識タスクで、単一のdistillation modelが10個のアンサンブルモデルと同等精度になったと報告されている。また、DistilBERT, a distilled version of BERTでは、BERT-baseの97%程度の精度を保ちつつ、40%小規模なモデルで、かつCPU実行が60%高速化されたそうです。(DistillBERTはhuggingfaceのチームが著者)

以上の基本的なDistillationの派生系として、Softmax後の値ではなく、中間の特徴マップを使ってDistillationしたり(Paying more attention to attention / Mobilebert)、Unlabeled dataに対して教師モデルで推論した結果のPseudo labelを用いて生徒モデルを学習し、生徒モデルの精度を向上するアプローチもある。