1.概要
独学で機械学習・データサイエンスの勉強を始めてから約3年が経ちました。いろいろなデータに触れてきましたが、その中でも回帰問題を解く機会が多かったように思います。その経験上個人的に盲点だなと感じたのが損失関数をいかに定義するかという部分です。機械学習というとやはりなんのモデル使うのかだったりとか、いかにデータを整えるかに注目が行きがちです。しかし、自分の目的(=どんな回帰がしたいか)を照らし合わせて、適切な損失関数を定義しなくては意味のないものを学習させてしまいます。例えるなら、英会話ができるようになりたいのに、大学受験の英語を勉強しているかのようです。そのため、今回は自分の持っているデータに対応した回帰問題の損失関数を定義する方法を紹介します。ここで紹介する方法は一般化線形モデルという頻度主義統計学に基づいた統計モデルの線形性の仮定を取り除き、機械学習に拡張させた方法とも捉えることができます。また、この記事の最後には任意の確率分布に一般化した損失関数となるBregman Divergenceについて紹介できればと思います。
2.対象
この記事は機械学習フレームワークの実装経験がある方を対象としています。
また、この記事では以下の事項についてはここでは説明しません。
- 機械学習フレームワークでの実装方法
- 基本的な確率・統計の知識(確率分布・最尤推定法など)
3.はじめに
統計的機械学習で回帰問題を解く場合、モデル$g(X|\lambda)$のパラメータ$\lambda$に対して、目的に沿った損失関数$L$を定義し、データセット$X\in \mathbb{R}^{N \times m}, Y \in \mathbb{R}^{N}$ を用いて$L$を最小にする$\lambda$を求めます。数式を使うと、
\min_{\lambda} L(Y,g(X|\lambda))
と表されます。DeepLearningでは、
- $g(X|\lambda)$ : ネットワーク全体
- $\lambda$ : $W,b$など
に相当し、LightGBMなどのアンサンブルツリーモデルでは、
- $g(X|\lambda)$ : 決定木の集合全体
- $\lambda$ : 分岐や葉の重みなど
に相当します。そして、ごく単純な回帰問題では、
L_{MSE}(Y,g(X|\lambda)) = \sum_i\left(y_i - g(X_i|\lambda)\right)^2
と表されるMSEが損失関数として用いられます。また、ロジスティック回帰のように二値分類を回帰で解くときは、
L_{CE}(Y,g(X|\lambda)) = \sum_i-y_i\log(g(X_i|\lambda))-(1-y_i)\log\left(1-g(X_i|\lambda)\right)
と表されるCross Entropyがよく用いられます($X_i,y_i$は$X,Y$の$i$番目のサンプル)。ただ、このCEの定義は有名であるものの、$g(X|\lambda)$の値域が[0,1]という制約があります。この記事では機械学習モデル$g(X|\lambda)$の値域を一律に実数全体として扱いたいので、sigmoid関数を使って書き直した、
L_{CE}(Y,g(X|\lambda)) = \sum_i-y_i\log\left(\frac{1}{1+\exp(g(X_i|\lambda))}\right)-(1-y_i)\log\left(1-\frac{1}{1+\exp(g(X_i|\lambda))})\right)
をCEの定義とします。
主に機械学習のタスクは分類と回帰に分かれており、回帰ではMSEとCEくらい覚えておけばモデルは動きます。しかし、以下のデータを扱うときはMSEやCrossEntropyが損失関数として適切でない場合があります。
- N段階の順序を表すデータ
- 枚数や個数を表すデータ
- 所用時間を表すデータ
- 発生確率を表すデータ
なぜなら、上記のデータは値の範囲が限られていたり、連続でない値をとったりするからです。そのため、二値でない値だからといって何も考えずにMSEを使っていては実際には起こり得ない値(0より小さい値など)を予測値として出力してしまいます。また、MSEでは100個と101個の差も、0個と1個の差も同じ1として扱うため、適切な「近さ」を反映できていない場合があります。よって、このような問題を解決するには相応な工夫が必要になります。以下からは上記のような、データに対する適切な損失の導き方を紹介します。
4.Cross Entropyとはなにか
MSEは直感的に「本当の値に近づける」という感がありますが、Cross Entropyはなぜ回帰問題に使われるのでしょうか。そこでまずはCross Entropyがなぜ回帰問題の損失関数として用いられているのかを例にとって説明することで、同じフォーマットで上記のデータに対応する損失関数を導きたいと思います。
4-1.モデルの値域の変換
まずはモデルの値域の変換についてです、3.で定義したように$g(X|\lambda)$の値域は実数全体としているため、値域を制限する必要があります。ここではその方法について確認します。
二値分類とは目的変数$Y$が0(負例)または1(正例)のみをとるときに、モデル$g(X|\lambda)$によってサンプル$X_i$が正例である確率$p$を求める問題です。しかしモデルは一般に値域が実数全体であるため、確率として適切な値域[0,1]に変換する必要があります。
その場合によく使われる関数がsigmoid関数です。sigmoid関数は、
\operatorname{sig}(x)=\frac{1}{(1+\exp(-x))}
のように表され、二値分類問題では$p = \operatorname{sig}(g(X|\lambda))$とすることが一般的です。
4-2.ベルヌーイ分布
しかし、このように値域を指定したからと言ってMSEを使って良いという話ではありません。次に考えるべきは、$g(X|\lambda)$において、$\lambda$はどのようなパラメータであるべきかです。
これを説明するためにここで二値分類のモデルについてもう少し詳細に説明します。統計的機械学習で二値分類を解く場合、予測モデルへの入力$X_i$に対応して、ベルヌーイ分布($\operatorname{Bernoulli}(y|\theta_i)$)のパラメータ$p_i\in [0,1]$が決定すると考えます。すなわち$p_i = \operatorname{sig}(g(X_i|\lambda))$となります。そしてこのベルヌーイ分布に基づいて確率的に$y_i$が観測されるとみなします。つまり、モデルは
y_i \sim \operatorname{Bernoulli}(y | p = g(X_i|\lambda))
と表すのが厳密には正しいと言えます。そして、このような統計モデルを仮定するとき、頻度主義統計学1では求めるべきパラメーターを最尤推定法によって求めます。最尤推定法によって求めたパラメータは最尤推定量と呼び、サンプル数が十分なとき、一致性,漸近的正規性,漸近的有効性など、統計的に良い性質を持ちます。そのため、「良い」パラメータ$\lambda$は、$p_i = \operatorname{sig}(g(X_i|\lambda))$としたベルヌーイ分布の尤度最大化によって求められます。実はこの尤度に対数をとって(-1)をかけた「負の対数尤度」がCross Entropyに一致します。
具体的にCross Entropyを求めてみます。ベルヌーイ分布の尤度は
l(p) = p^y(1-p)^{(1-y)}
と表されるので、上記のモデルに従うNサンプルの同時尤度は
l(\lambda) = \prod_i p_i^{y_i}(1-p_i)^{(1-y_i)}
と表せます。さらに対数をとって、$p_i = \operatorname{sig}(g(X_i|\lambda))$を代入すれば、
L(\lambda)=-\log(l(\lambda)) = \sum_i-y_i\log\left(\frac{1}{1+\exp(g(X_i|\lambda))}\right)-(1-y_i)\log\left(1-\frac{1}{1+\exp(g(X_i|\lambda))})\right)
となり、3.で定義したCross Entropyと一致しました。(-1)をかけているので、尤度の最大化がこのCross Entropyの最小化に一致します。
4-3.Cross Entropyまとめ
以上を要約すると、
- sigmoid関数によって値域を変換
- ベルヌーイ分布に基づいた尤度最大化(= 負の対数尤度の最小化)
によって二値分類問題の損失関数Cross Entropyが導出できたと言えます。この2点が以下での鍵となってきます。
5.任意の確率分布での定式化
4.では、sigmoid関数による値域の変換と、ベルヌーイ分布に基づく最尤推定法によって導くことを説明しました。この値域の変換と最尤推定法こそが3.で挙げたデータに対応する損失関数を導くツールとなります。以下からはこれらを用いて任意の確率分布への拡張をし、損失関数を定義したいと思います。
5-1.定式化
$\theta \in \mathbf{D}$をパラメータとする確率分布$P(y|\theta)$を仮定した場合の損失関数を導きます。
まず、値域の変換は、関数$T:\mathbb{R} \mapsto \mathbf{D}$によって行うとします。そうすると、等式
\theta_i = T(g(X_i|\lambda))
が成り立ちます。これを用いるとNサンプルの同時尤度は
l(\lambda) = \prod_i P(y_i|\theta_i) = \prod_i P(y_i|\theta_i = T(g(X_i|\lambda)))
と表せるので、これに対数をとって(-1)をかければ、
L(\lambda) = -\log(l(\lambda)) = -\sum_i \log\left(P(y_i|\theta_i = T(g(X_i|\lambda)))\right)
となり、最小化したい損失関数を導出できました。もう一度振り返ると、$T$をsigmoid関数、$P(y|\theta)$をベルヌーイ分布とすればこの損失関数がCross Entropyとなります。また、この考え方はMSEについても同様に言えて、$T$を$f(x)=x$、$P(y|\theta)$を(分散一定の)正規分布としたときの損失関数がMSEとなります。
5-2.例(カウントデータ)
5-1について、3で挙げた扱いの難しいデータのうちカウントデータを一つ例として説明します。
目的変数がカウントデータ(非負離散値)である場合、$P(y|\theta)$としてポアソン分布を用いることができます。ポアソン分布は、
P(y|\theta) = \frac{\theta^y\exp(-\theta)}{y!}
と表されるので、Nサンプルの同時尤度は、
l(\lambda) = \prod_i\frac{\theta_i^{y_i}\exp(-\theta_i)}{y_i!}
と表されます。また、ポアソン分布において$\theta$は$\theta\in(0,\infty)$を満たすので、値域の変換は、指数関数$e^x$を用います。つまり、
\theta_i = T(g(X_i|\lambda))=\exp(g(X_i|\lambda))
ということです。これを上の式に代入し、対数をとって(-1)をかければ、
L(\lambda) = -\log l(\lambda) = \exp(g(X_i|\lambda))-y_ig(X_i|\lambda)+C
となり、カウントデータ=ポアソン分布を仮定したときの損失関数が定義できました。この損失関数はポアソン損失と呼ばれます。(Cは定数)
5-3. そのほかの確率分布
5-3ではカウントデータを例に、ポアソン分布の尤度関数からポアソン損失を導きました。ここで詳細は割愛しますが、この考えを使うことで、3.で挙げたデータを全て扱うことができます。具体的には順序を表すデータには二項分布、時間を表すデータには指数分布、確率を表すデータにはベータ分布などが適用できます。
また、値域の変換に関しても幅広く対応できます。例えばポアソン分布のように負の値を取らないなどで、上界または下界のどちらかのみを指定したいときには指数関数$f(x) = a \pm e^x$とすれば、値域を$x \gtrless a$とすることができます。また、ベルヌーイ分布のように上界下界どちらも指定したいときは、sigmoid関数を拡大/平行移動した、
f(x)=(b - a)\frac{1}{(1+\exp(-x))} + a
を用いることで、値域を(a,b)とすることが可能です。
5-4.一般化線形モデルとの関係
このように、5-1を使うことで簡単に損失関数を導くことができました。この値域の変換と最尤推定法による回帰は、一般化線形モデルと深い関わりがあります。
というのも、一般化線形モデルという統計モデルは、単純な重回帰モデル
y = a + \mathbf{b}^T\mathbf{x}
を任意の確率分布に拡張したモデルです。一般化線形モデルはここでは深く説明しませんが、実はこの統計モデルの定式化は、5-1で説明した$g(X|\lambda)$に線形性の仮定を追加したモデルと一致するのです。また、値域の変換は一般化線形モデルで言うところのlink関数(の逆関数)の役割を果たしています。
一般化線形モデルについては、みどり本として有名な「データ解析のための統計モデリング入門(久保拓弥)」で詳しく説明されています。
5-5. 定式化のまとめ
このセクションでは、モデルの値域の変換と確率分布の仮定をするとによって、「良い」損失関数を導けることを説明しました。ここまで読んでいただけた方々は、もうどんな目的変数でも対応する適切な損失関数を導くことができるようになっているはずです。
6.(おまけ)自然指数型分布族と一般化Bregman Divergenceによってさらなる一般化を目指す。
5.までの内容で基本的な確率分布は抑えることができたはずです。しかし、「わざわざ確率分布から損失を計算するのは面倒」とか、「もっと複雑な確率分布を扱いたい」とかいう要望もあるかもしれません(そう思わない方は読み飛ばしてください)。そんな要望に応えるには、自然指数型分布族と一般化Bregman Divergenceを用いる必要があります。この二つの概念はあまり有名ではないので、以下に定義および、それら二つから導かれる定理を記述しておきます。
6-1.【定義】自然指数型分布族 (Natural exponential family)
1変数関数$F$をパラメータとする確率分布$p_F(y|\theta)$が、
p_F(y|\theta) = p_0(y)\operatorname{exp}(\theta y -F(\theta))\\
F(\theta) = \operatorname{log}\int_{-\infty}^{\infty}p_0(x)\operatorname{exp}(\theta x) \operatorname{dx}
と表されるとき、自然指数型分布族に属する。また、ここで$\theta$は自然パラメータと呼ぶ。
6-2.【定義】一般化Bregman Divergence
凸関数$F: \mathbb{R} \rightarrow \mathbb{R}$に対して,変数$z,y$の一般化Bregman Divergenceは、
\mathbb{D}_{F}(\theta \| y)=F(\theta)+F^{*}(y)-\theta y
と定義される。ただし、$F^\ast$は$F$の凸双対で$F^{*}(\mu)=\sup _{\theta \in \operatorname{dom} F}[\langle\theta, \mu\rangle- F(\theta)]$を満たす。
6-3.【定理】自然指数型分布族の損失関数はBregman Divergenceに一致する
$\theta$に関する自然指数型分布族の尤度最大化はBregman Divergence$\mathbb{D}_{F}(\theta |x)$の最小化と一致する。
[証明]
上記の自然指数型分布の定義式に対数をとると、
\log p_{F}(x | \theta)=\log p_{0}(x)+\theta x - F(\theta)
となる。そこにBregman Divergenceの定義式を変形した
\theta x - F(\theta) = F^{*}(x)-\mathbb{D}_{F}(\theta \|x)
を代入すると、
\log p_{F}(x | \theta)=\log p_{0}(x)+F^{*}(x)-\mathbb{D}_{F}(\theta \|x)
となる。$\theta$に関して$\log p_{0}(x)$と$F^{*}(x)$は定数であるので、自然指数型分布族の対数尤度の最大化Bregman Divergence$\mathbb{D}_{F}(\theta |x)$の最小化に一致する。
[証明終了]
6-4.損失関数としてのBregman Divergence
上記をまとめると、
- 任意の凸関数$F$をパラメーターとする確率分布をまとめて自然指数型分布族とする。
- 関数$F$に関してBregman Divergenceという指標がある。
という二つの定義から、
- 自然指数型分布族の確率分布に基づく損失関数は関数$F$に関するBregman Divergenceと一致する
という定理が得られるということになります。また、$\mathbb{D}_{F}$において$F^{*}(x)$は$\theta$に依存しないので無視でき、損失関数は実質的に
L(\theta) = \theta x - F(\theta)
となり、こんなに簡単に損失関数が書けるのです。こんな嬉しい性質を持つ自然指数型分布族ですが、実は正規分布、ポアソン分布、指数分布などはこの自然指数型分布族に属しています(wikipedia)。そのため、これらに対応した$F$さえ覚えておけば、Bregman Divergenceによって簡単に損失関数が求められます。具体的には、
- 正規分布: $F(\theta) = \theta^2, \theta \in \mathbb{R}$
- ポアソン分布: $F(\theta) = \exp(\theta),\theta \in (0,\infty)$
- ベルヌーイ分布: $F(\theta)=\operatorname{log}(1+\exp(\theta)),\theta \in (0,1)$
- 指数分布: $F(\theta)=-\log(-\theta),\theta \in (-\infty,0)$
という対応です。($\theta$の定義域を確認してモデルの値域を変換する必要が残ることに注意)
また、このほかでも自分が使いたい確率分布を自然指数型分布族として表せば同じように考えられます。ちなみに、この$F(\theta)$はキュムラント母関数(cumulant generating function)や分配関数(log-partition function)などと呼ばれ、モーメント母関数(積率母関数、moment generating function)に対数をとった関数です。すなわち、ここなどから分布を特定してmoment generating functionを参照し、それに対数をとれば$F$は計算しなくてもわかるのです。$F$さえわかってしまえば、やはり簡単に損失関数が定義できます。
7.おわりに
今回は、頻度主義統計学に基づく統計的機械学習理論から適切な損失関数を定義する方法について説明しました。この方法は統計モデルとして有名な一般化線形モデルの機械学習への拡張と捉えることもできます。実務でも使える内容なはずなので手元に良いデータがある方は是非使ってみてください。そして最後は、少々難解な数学を用いて、確率分布について一般化した自然指数型分布族に基づく損失関数を紹介しました。この方法は損失関数を求める操作をかなり単純化させるので、自分もこの方法を知ったときは驚愕しました。なかなか使うチャンスは少ないかもしれませんが、よければこちらも使っていただけると嬉しいです。また、この記事ではBregman Divergenceについては必要最低限しか紹介しなかったので、今後機会があればそちらも詳細を記事にしたいと思います。
最後までご覧いただきありがとうございました。
-
ベイズ統計学なら事後分布最大化法などがとられる ↩