こんにちは。
普段あまり意識しないで使っている損失関数ですが、今一度原点に立ち返って見つめなおしてみたいと思います。
損失関数とは
損失関数は別名誤差関数とも呼ばれ、正解と不正解の差=誤差をできる限り小さくしてモデルを改善していくための指標として設定します。
この正解に近づいていくことを、不確実性を減らす とも言い換えらます。
不確実性100%の状態とは、サイコロのような全部の面が出る確率が同じである状態で、不確実0%の状態は、あるクラス(ラベル)が100%の確率で出るという状態です。
この「不確実性」がキーワードであり、損失関数は不確実性を数値化するものです。
情報理論の世界では、エントロピー によって不確実性を表現します。機械学習も同様で、モデルが出力する確率分布(どのクラスがどの確率で出現するかの分布)を表しています。
エントロピーは、クロード・シャノンさんが情報理論において導入したものだそうで、確率分布に基づいてその情報源の平均的な情報量を測る指標です。式をもう少し詳しく見ていくと以下のように分布全体の不確実性を測るために期待値を取っていることが分かります。
対数を取るのは、確率の積を足し算で表すことができるからです。
また、マイナスをつけているのは、確率が0~1に値を取るのでlogpは必ず0より小さくなってしまうからです。不確実性の大きさを測る指標としては、正の値であってほしいので反転させています。
エントロピーが何であるか分かったところで、機械学習の損失関数で用いられる クロスエントロピー や バイナリークロスエントロピー について説明します。
クロスエントロピー
エントロピーは一つの分布の不確実性を測ります。
では、真の確率分布とモデルが出した確率分布を比べたいときに使うのが、クロスエントロピーです。
ここで「符号化」という言葉が分かりづらいですが、情報を数値に置き換えることを指します。
ある事象について、Yes / Noしか選択肢がなければ、1 / 0 の1ビットで表せます。
たいして、春夏秋冬や東西南北という情報は00 / 01 / 10 / 11 の2ビットで表せます。このように事象の確率が高いほど短い符号で済みますが、低いほど長い符号が必要になります。
では、最適な符号長は?というと真の分布pによって決まるので、事象xの符号長は以下のように表せます。これは先ほど示した情報量I(x)と同じ式であるとおり、情報量と符号長は同じ式で解釈できます。
ただし、実際の学習では真の分布pは分からないので、モデルが出力する予測分布qを使って「この事象はq(x)だ」とみなして符号化します。
よって、真の分布pに従って事象が発生し、それをqで符号化するときの平均の符号長は以下で表せます。
バイナリークロスエントロピー
バイナリークロスエントロピーはその名のとおりクロスエントロピーの二値分類版です。
二値分類は0か1しかとりませんので、分布は以下のように表されます。
- 真の分布:($y$, $1-y$)
- 予測分布:($\hat{y}$, $1-\hat{y}$)
この分布にクロスエントロピーを適用します。
こうやってみるとシグマで表されていた式が展開されただけと分かります。
KLダイバージェンス
分布そのものを近づけたいときに使います。
特に蒸留学習で教師モデルのクラス間の確率分布に生徒モデルの確率分布を近づけたいときに使います。
この確率分布を近づけるということは、クラス間の関係性(距離)を学ばせることにつながります。
例えば、3クラス分類をする場合、モデルの出力が(猫 : 0.7, 犬 : 0.25, ウサギ : 0.05)だとしても、ラベルで表すと(猫 : 1.0, 犬 : 0.0, ウサギ : 0.0)となり、クラス間の関係が分かりません。このonehot表現では得られない「犬にも近くて、ウサギとは遠い」という関係を学ばせること蒸留の狙いです。
ちなみに、クロスエントロピーの関係は項を移動させると分かります。
ここでH(p)は真の分布だけで決まる定数なので、学習では無視できます。
つまりクロスエントロピーを最小化することは、KLを最小化することと同じことです。
クロスエントロピーとKLダイバージェンスのコード上の違い
数式上は似ていますが、PyTorchで使う際は注意が必要です。
クロスエントロピー
2通りのやり方がありますが、クロスエントロピーは、内部でlog_softmax
が働くのでlogits
で渡します。
- targetにクラス番号
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True) # 形=(B=3, C=5) の logits
target = torch.empty(3, dtype=torch.long).random_(5) # 形=(B,) の整数ラベル
output = loss(input, target)
output.backward()
第一引数input
はSoftmaxを適用する前の生の出力となっています。
第二引数target
は整数ラベルになっています。
- Targetにクラス確率
input = torch.randn(3, 5, requires_grad=True) # logits
target = torch.randn(3, 5).softmax(dim=1) # 形=(B,C) の分布(各行が和=1)
output = loss(input, target)
output.backward()
第二引数target
は確率分布(one-hotでもsoft-labelでも可)です。
KLダイバージェンス
こちらも2通りの書き方がありますが、自分でlog_softmax
(logits → log確率に変換)してから入力します。
- targetは確率分布、inputはlog確率
kl_loss = nn.KLDivLoss(reduction="batchmean")
input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1) # log q
target = F.softmax(torch.rand(3, 5), dim=1) # p
output = kl_loss(input, target)
第一引数input
はlog確率(log_softmax済み)で、
第二引数target
は確率分布になっています。
- targetもinputもlog確率
kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True) # log q
log_target = F.log_softmax(torch.rand(3, 5), dim=1) # log p
output = kl_loss(input, log_target)
第一引数input
も第二引数log_target
もlog確率となっています。
target にも log確率を入れる指定として、log_target=True
でtargetもlog確率だと明示しています。
ちなみに実際に蒸留で使いたいときは、温度パラメータにより教師モデルの出力の「柔らかさ」を調整します。これにより、確率分布をなめらかにします。
T = 4.0
log_qT = F.log_softmax(student_logits / T, dim=1)
pT = F.softmax(teacher_logits / T, dim=1)
loss_kd = F.kl_div(log_qT, pT, reduction="batchmean") * (T*T)
以上です。読んでいただきありがとうございました。