機械学習学習中のmizukickerです。
今回は@HiroZeldaさんの
Kerasによる知識の蒸留 (knowledge distillation) ~TGS(kaggle)~
https://qiita.com/HiroZelda/items/0ba24788c78540046bcd
を参考に学習して躓いた点を解説・共有します。
難しい話は苦手なので、イメージでざっくりつかめるように整えました。
蒸留とは?
ハイスペックコンサルがつくったマニュアルを
アルバイトの大学生でも実行できるようにすることです。
わかりやすいように下にたとえ話を続けます。
蒸留のたとえ話し
とある行列のできるラーメン屋。
このカリスマ店長のつくるラーメンは
この店舗の売上が1日50万!
その噂を聞きつけたコンサルが
ぜひともチェーン展開をめざしたいとやってきます。
このコンサルはディプラーニングを駆使し、
門外不出のラーメンを秘伝のレシピに仕上げマニュアル化することができる優秀な高給取りです。
(月給1億5000万)
彼はカリスマ店長のデータを逐一解析し、秘伝のマニュアル化に成功しました!!
このマニュアルに忠実に実行すれば、99.9%の再現度でラーメンを作ることが出来ます!
これでラーメンチェーンを展開すれば億万長者です!
しかし、ここで問題が発生します。
ハイスペックコンサルにラーメンを作らせていると、人件費がかかりすぎて全く利益が出ません。
そこでマニュアルはできたので、アルバイトの大学生にやらせて儲けようとしました。
これで黒字化するかと思いきや!
このハイスペコンサル用マニュアルだとアルバイトくんは読めなくて困ってしまいます。
正確には読めるのですが、英語で書いてあったり、複雑な数式が書いてあるので
時間がかかりすぎてしまい商売になりません。
そこで
コンサルがマニュアルを実行する様子をもとに、
アルバイト君にマニュアルを渡して、見よう見まねで試してもらいます。
コンサルトアルバイトが作ったラーメンに差が出ないように教えます。
この作業が蒸留です。
このとき、バイト君はマニュアルのどこが大事なのかわかりません。
一緒にコンサルがマニュアルを見ながら体で覚えます。
ここでコンサルさんがマニュアルに付箋を貼って
どこか重要か目印つけます。
これでバイト君でも学習しやすくなります。
これが蒸留時の温度パラメータです。
(後半で更に詳しく解説します。)
これでやっと、カリスマの味を再現したラーメンをアルバイト君が作れるようになりました。
人件費が抑えられて、どんどんチェーン展開し
安くて美味しくてラーメンチェーンが繁盛しました!
めでたしめでたし
温度パラメータの補足
なんで温度パラメータをTで割るといい感じなんでしょうか?
そもそもソフトマックスってなんでしたっけ?
ソフトマックスの復習
いろいろ数字を入れると全体の量に対する割合を
確率にして返してくれる便利な関数です!
具体的に見ていきましょう。
[a1, a2, a3] = [0.9,0.02,0.08]の場合は
\alpha1=\frac{\exp(0.9)}{\exp(0.9)+\exp(0.02)+\exp(0.08)}=0.00667641
\alpha2=\frac{\exp(0.02)}{\exp(0.9)+\exp(0.02)+\exp(0.08)}=0.99086747
\alpha3=\frac{\exp(0.08)}{\exp(0.9)+\exp(0.02)+\exp(0.08)}=0.00245611
となり、コードで書くとこんな感じです。
import numpy as np
def softmax(a):
c = np.max(a)
exp_a = np.exp(a - c)
sum_exp_a = np.sum(exp_a)
y = exp_a / sum_exp_a
return y
X = np.asarray([2,7,1])
print(X)
# [0.00667641 0.99086747 0.00245611]
print(X[0]+X[1]+X[2])
# 1
この数値をラーメン屋に例えると、
湿度が高い日は
0.00667641の確率で麺の湯で具合は10秒
0.99086747の確率で麺の湯で具合は20秒
0.00245611の確率で麺の湯で具合は30秒
茹でると美味しくなる!
とマニュアルに書いてある感じですね
こうすると、アルバイト君は湿度の高い日は何でもかんでも20秒茹でちゃいますね。
引き継ぎ資料を自分で作成するときに(付箋を貼るときに)
他の時間を試さないで柔軟な対応ができなくなってしまいます。
温度パラメータのTで割ると、どうなるか?
T = 10
Y = X/T
print(Y)
# [0.2 0.7 0.1]
soft_Y = softmax(Y)
print(soft_Y)
# [0.28140804, 0.46396343, 0.25462853]
湿度が高い日は
0.28140804の確率で麺の湯で具合は10秒
0.46396343の確率で麺の湯で具合は20秒
0.25462853の確率で麺の湯で具合は30秒
とマニュアルに付箋を貼ってわかりやすくする感じです。
こうすると、アルバイト君は湿度の高い日は何でもかんでも20秒茹でちゃうのではなく、
引き継ぎ資料を作成するときに
自分で色々と工夫した上で湯で時間を考えるようになるので
自発的に動けるアルバイトとして、店の味を守れるようになります!
最後に
ざっくりとイメージをつかめたら幸いです。
これを読んでから、もっと詳しい記事や実装を試すとスムーズに学習できるかと思います!