AIモデルの精度確認や学習時の最適化のために誤差を評価する時に使う絶対誤差と二乗誤差について、2乗の方が微分が出来るし計算がしやすいから二乗誤差を使う、という説明が多いと思います。それなら二乗誤差のみが残って絶対誤差を使う場面はなくなってしまうと思うのですが、それぞれ収束する先が違うからどちらも残っているのだろうと思っています。
どこに収束していくはずだろうかというのを調べてみます。
損失関数
実際の値を $\boldsymbol{y}$ (各要素は $y_i$)、予測値を $\boldsymbol{\hat{y}}$ (各要素は $\hat{y_i}$) として記載します。
また、その個数は $n$ とします。
誤差はその差のことなので、 $\boldsymbol{y} - \boldsymbol{\hat{y}}$ (各要素は $y_i-\hat{y_i}$)のことを指していて、それが小さい方が良いです。
ただ、この値は多次元なので例えば誤差が $\left( \begin{array}{c} -2\\7 \end{array} \right)$ だった予測Aと誤差が $\left( \begin{array}{c} 4\\-6 \end{array} \right)$ だった予測Bとだと、どちらが誤差が小さいのかの比較ができないため、比較ができるようにするためのものです。
平均絶対誤差(mean absolute error)
maeと省略される関数で計算されます。誤差の各要素の絶対値の平均を使います。
\mathrm{mae}(\boldsymbol{y},\boldsymbol{\hat{y}})=|y_i-\hat{y_i}|の平均=\frac{1}{n}\sum_{i=1}^n|y_i-\hat{y_i}|
平均二乗誤差(mean squared error)
mseと省略される関数で計算されます。誤差の各要素の2乗の平均を使います。
\mathrm{mse}(\boldsymbol{y},\boldsymbol{\hat{y}})=(y_i-\hat{y_i})^2の平均=\frac{1}{n}\sum_{i=1}^n(y_i-\hat{y_i})^2
収束先
普通はそれぞれ別の入力に対する予測値と実際の値があるので、そもそもどちらも一定ではないものだと思います。
しかし、ここでは入力が同じものを集めてきたという仮定で考えてみます。
すると、予測値はランダムな要素を入れていない予測モデルならば一定の予測値になります。
対して、実際の値は誤差項によるブレがあるので一定とは限りません。
この時、予測値が誤差項に対してどのように影響を受けるのかを考えることで、どこに収束する傾向があるのかを考ええてみます。
仮定を式の形で記載すると、とある予測値 $t$ があり、任意の $i$ に対して $\hat{y_i}=t$ ということになります。
平均絶対誤差の場合
$\mathrm{mae}(\boldsymbol{y},\boldsymbol{\hat{y}})$ が1番小さくなるように $t$ を求めます。
\begin{eqnarray}
\mathrm{mae}(\boldsymbol{y},\boldsymbol{\hat{y}})
&=&\frac{1}{n}\sum_{i=1}^n|y_i-\hat{y_i}|\\
&=&\frac{1}{n}\sum_{i=1}^n|y_i-t|\\
&=&\frac{1}{n} \left( \sum_{\substack{1 \leq i \leq n\\y_i<t}}|y_i-t|+ \sum_{\substack{1 \leq i \leq n\\y_i=t}}|y_i-t|+ \sum_{\substack{1 \leq i \leq n\\y_i>t}}|y_i-t| \right)\\
&=&\frac{1}{n} \left( \sum_{\substack{1 \leq i \leq n\\y_i < t}}(t-y_i)+ \sum_{\substack{1 \leq i \leq n\\y_i > t}}(y_i-t) \right)
\end{eqnarray}
※補足
シグマ記号の部分が見慣れない書き方だと思いますが、これはシグマの下の条件式を and 条件として見て、真になる項たちを足していきますよ、という意味になります。
また、場所の都合で省略しましたが、iは整数という条件で足しています。
脱線しますが、$\displaystyle{\sum_{\substack{1 \leq k \leq 10\\kが偶数}}k^2}$ を数式の形とpythonのコードで求める方法を書くので、分かりやすい方で理解いただければ幸いです。
\begin{eqnarray}
\sum_{\substack{1 \leq k \leq 10\\kが偶数}}k^2
&=& (1から10の間で偶数のみ取り出して、k^2にして足す)\\
&=& (2,4,6,8,10を順番にkに代入して、k^2にして足す)\\
&=&2^2+4^2+6^2+8^2+10^2\\
&=&220
\end{eqnarray}
summary = 0
for k range(1,11): # 1~10の範囲を考える
if k % 2 == 0: # 範囲の条件以外の条件文もtrueならば
summary += k**2 # シグマの中身の値を足し合わせる
print(summary)
では平均絶対誤差の最小値探しに戻ります。
最小値を求める場合、微分を使って増減表を書くと分かりやすいと思っています。
そのためこの式を $t$ の関数だと考えて微分してみます。ここで、 $t$ の関数だと考える理由は、入力の値や $\boldsymbol{y}$ は固定をしたと仮定しているので定数だとみなしているのと、 $t$ に着目をしたいという意図があるから $t$ のみを動かすと考えています。
また、 $t$ がいづれかの $y_i$ と重なる場所では微分不可能なのですが、それ以外の場所での傾きを調べるため、定義域は $y_i$ に重ならないところとして微分をします。
\begin{eqnarray}
\frac{d}{dt} \mathrm{mae}(\boldsymbol{y},\boldsymbol{\hat{y}})
&=&\frac{d}{dt} \left( \frac{1}{n} \left( \sum_{\substack{1 \leq i \leq n\\y_i<t}}(t-y_i)+ \sum_{\substack{1 \leq i \leq n\\y_i>t}}(y_i-t) \right) \right)\\
&=& \frac{1}{n} \left( \frac{d}{dt} \left( \sum_{\substack{1 \leq i \leq n\\y_i<t}}(t-y_i)+ \sum_{\substack{1 \leq i \leq n\\y_i>t}}(y_i-t) \right) \right)\\
&=& \frac{1}{n} \left( \frac{d}{dt} \left( \sum_{\substack{1 \leq i \leq n\\y_i<t}}(t-y_i) \right) + \frac{d}{dt} \left( \sum_{\substack{1 \leq i \leq n\\y_i>t}}(y_i-t) \right) \right)\\
&=& \frac{1}{n} \left( \sum_{\substack{1 \leq i \leq n\\y_i<t}}\frac{d}{dt}(t-y_i) + \sum_{\substack{1 \leq i \leq n\\y_i \geq t}}\frac{d}{dt}(y_i-t) \right) \\
&=& \frac{1}{n} \left( \sum_{\substack{1 \leq i \leq n\\y_i<t}}1 + \sum_{\substack{1 \leq i \leq n\\y_i>t}}(-1) \right) \\
&=& \frac{1}{n} ((y_i < t になる個数) - (y_i>tになる個数)) \\
\end{eqnarray}
最後が突然日本語になりましたが、「集合に含まれる要素の数」という記号を導入するよりは分かりやすいと思うので許してください。
では増減表を書いていきますが、要素の個数 $n$ が奇数か偶数かで分かれてしまうので、2通りに分けて書きます。
(1)$n$ が奇数で、 $n=2k+1$ と表せるとき
$t$ が $y_i$を小さい方から数えて$k+1$番目の値より大きいとすると、($y_it$) になる個数が $k$ 個以下になるので、微分した値は正になります。
同じように $t$ が $y_i$を小さい方から数えて$k+1$番目の値より小さいとすると、($y_it$) になる個数が $k+1$ 個以上になるので、微分した値は負になります。
$t$の値 | -∞ | … | $y_i$を小さい方から数えて$k+1$番目の値 | … | ∞ |
---|---|---|---|---|---|
微分した値 | -n | 負の値 | - | 正の値 | n |
グラフの傾き | ↘ | ↘ | - | ↗ | ↗ |
なので、$t$ が最小になるのは、$y_i$を小さい方から数えて$k+1$番目の値のときになります。
(2)$n$ が偶数で、 $n=2k$ と表せるとき
$t$ が $y_i$を小さい方から数えて$k+1$番目の値より大きいとすると、($y_i\lt t$) になる個数が $k+1$ 個以上で、($y_i>t$) になる個数が $k-1$ 個以下になるので、微分した値は正になります。
同じように $t$ が $y_i$を小さい方から数えて$k$番目の値より小さいとすると、($y_i\lt t$) になる個数が $k-1$ 個以下で、($y_i>t$) になる個数が $k+1$ 個以上になるので、微分した値は負になります。
また、ちょうど $t$ が $y_i$を小さい方から数えて$k$番目の値と$k+1$番目の値との間にある場合は、($y_i\lt t$) になる個数も($y_i>t$) になる個数も $k$ 個になるので、微分した値は0になります。
$t$の値 | -∞ | … | $y_i$を小さい方から数えて$k$番目の値 | … | $y_i$を小さい方から数えて$k+1$番目の値 | … | ∞ |
---|---|---|---|---|---|---|---|
微分した値 | -n | 負の値 | - | 0 | - | 正の値 | n |
グラフの傾き | ↘ | ↘ | - | → | - | ↗ | ↗ |
なので、$t$ が最小になるのは、$y_i$を小さい方から数えて$k$番目の値と$k+1$番目の値との間にいるときになります。
ただし、$k$番目の値と$k+1$番目の値がちょうど同じになる場合は幅はなく、小さい方から数えて$k$番目のことだと考えます。
(1)(2)を合わせて考えると、中央値の時が一番平均絶対誤差が小さくなるので、中央値に収束するだろうと考えられます。
平均二乗誤差の場合
maeと同じく $\mathrm{mse}(\boldsymbol{y},\boldsymbol{\hat{y}})$ が1番小さくなるように $t$ を求めますが、こちらは全ての $t$ の値で微分が可能なので、もう少し簡単に式変形が出来ます。
\begin{eqnarray}
\mathrm{mse}(\boldsymbol{y},\boldsymbol{\hat{y}})
&=&\frac{1}{n}\sum_{i=1}^n(y_i-\hat{y_i})^2\\
&=&\frac{1}{n}\sum_{i=1}^n(y_i-t)^2\\
&=&\frac{1}{n}\sum_{i=1}^n(y_i^2-2y_it+t^2)\\
\frac{d}{dt} \mathrm{mse}(\boldsymbol{y},\boldsymbol{\hat{y}})
&=&\frac{d}{dt} \left( \frac{1}{n} \left( \sum_{i=1}^n(y_i^2-2y_it+t^2) \right) \right)\\
&=&\frac{1}{n} \left( \frac{d}{dt} \left( \sum_{i=1}^n(y_i^2-2y_it+t^2) \right) \right)\\
&=&\frac{1}{n} \left( \sum_{i=1}^n \left( \frac{d}{dt}(y_i^2-2y_it+t^2) \right) \right)\\
&=&\frac{1}{n} \left( \sum_{i=1}^n (-2y_i+2t) \right)\\
&=&\frac{1}{n} \left( \sum_{i=1}^n (-2y_i)+\sum_{i=1}^n (2t) \right)\\
&=&\frac{1}{n} \left(-2 \sum_{i=1}^n y_i+2t \sum_{i=1}^n 1 \right)\\
&=&\frac{1}{n} \left(-2 \sum_{i=1}^n y_i+2tn \right)\\
&=& -2 \frac{\sum_{i=1}^n y_i}{n} +2t
\end{eqnarray}
$\frac{\sum_{i=1}^n y_i}{n}$ は $\boldsymbol{y}$ の平均値であることに注意して増減表を書いていきます。
$t$ が $\boldsymbol{y}$ の平均値と一致しているときちょうど微分した値は0になります。
もし平均値よりも大きい場合は正に、平均値よりも小さい場合は負になります。
$t$の値 | -∞ | … | $\boldsymbol{y}$の平均値 | … | ∞ |
---|---|---|---|---|---|
微分した値 | -∞ | 負の値 | 0 | 正の値 | ∞ |
グラフの傾き | ↘ | ↘ | → | ↗ | ↗ |
なので、$t$ が最小になるのは、$\boldsymbol{y}$ の平均値と一致しているときになります。
平均値の時が一番平均二乗誤差が小さくなるので、平均値に収束するだろうと考えられます。
具体例
(1,2,3,6,8) という5つのデータだった場合に $t=2,3,4,5$ それぞれの平均絶対誤差と平均二乗誤差を求めてみます。
ただし、このデータは中央値が3, 平均値が4になっています。
(1) t=2の時
各データとの誤差は $(1-2,2-2,3-2,6-2,8-2)=(-1,0,1,4,6)$ となります。
平均絶対誤差 = $\frac{1}{5}(|-1|+|0|+|1|+|4|+|6|) = \frac{1}{5}(1+0+1+4+6) = \frac{12}{5} = 2.4$
平均二乗誤差 = $\frac{1}{5}((-1)^2+0^2+1^2+4^2+6^2) = \frac{1}{5}(1+0+1+16+36) = \frac{54}{5} = 10.8$
(2) t=3の時
各データとの誤差は $(1-3,2-3,3-3,6-3,8-3)=(-2,-1,0,3,5)$ となります。
平均絶対誤差 = $\frac{1}{5}(|-2|+|-1|+|0|+|3|+|5|) = \frac{1}{5}(2+1+0+3+5) = \frac{11}{5} = 2.2$
平均二乗誤差 = $\frac{1}{5}((-2)^2+(-1)^2+0^2+3^2+5^2) = \frac{1}{5}(4+1+0+9+25) = \frac{39}{5} = 7.8$
(3) t=4の時
各データとの誤差は $(1-4,2-4,3-4,6-4,8-4)=(-3,-2,-1,2,4)$ となります。
平均絶対誤差 = $\frac{1}{5}(|-3|+|-2|+|-1|+|2|+|4|) = \frac{1}{5}(3+2+1+2+4) = \frac{12}{5} = 2.4$
平均二乗誤差 = $\frac{1}{5}((-3)^2+(-2)^2+(-1)^2+2^2+4^2) = \frac{1}{5}(9+4+1+4+16) = \frac{34}{5} = 6.8$
(4) t=5の時
各データとの誤差は $(1-5,2-5,3-5,6-5,8-5)=(-4,-3,-2,1,3)$ となります。
平均絶対誤差 = $\frac{1}{5}(|-4|+|-3|+|-2|+|1|+|3|) = \frac{1}{5}(4+3+2+1+3) = \frac{13}{5} = 2.6$
平均二乗誤差 = $\frac{1}{5}((-4)^2+(-3)^2+(-2)^2+1^2+3^2) = \frac{1}{5}(16+9+4+1+9) = \frac{39}{5} = 7.8$
tの値 | 2 | 3 | 4 | 5 |
---|---|---|---|---|
平均絶対誤差 | 2.4 | 2.2 | 2.4 | 2.6 |
平均二乗誤差 | 10.8 | 7.8 | 6.8 | 7.8 |
それぞれ最小になるところがずれてはいますが、(少なくとも整数の範囲では)中央値や平均値のときが一番小さくなることが確認できました。
最後に
平均絶対誤差と平均二乗誤差が目指す最小の点に注目して違いを調べてみました。
損失関数の性質から、データのばらつきに対して「平均値のあたりを予測してほしい」なのか「中央値のあたりを予測してほしい」なのかでどちらを使うのかを判断してもよさそうです。
外れ値だったりどちらかにデータが偏っているような分布だったりするときにはこの2つの代表値はずれる傾向にあるので、学習データの分布の確認の仕方の1つの参考にするとよいかもしれません。
参考文献
回帰の誤差、2乗するか絶対値をとるか
適切な誤差指標の選び方
【評価指標】MAE とは
【評価指標】平均二乗誤差 (MSE)とは