座標を出力する物体検出モデルにおける、MSE損失の罠

はじめに

画像から物体検出し、ロボットアームでその物体をつかむ、といったタスクを考えてみます。
物体検出では対象の座標を知りたいので、画像を入力、座標を出力とする回帰問題として扱えばよさそうです。
画像認識さえできれば簡単に思えます。

しかしながら、損失関数の選択に罠があります。
回帰問題でよく使われるMean Squared Error(MSE)やMean Absolute Error(MAE)を使う場合、対象物体が見分けやすい時には問題ないのですが、見分けにくいダミー物体がある場合にはうまく行きません。
「二兎を追う者一兎をも得ず」のような状況が発生してしまうのです。

ちなみに、物体検出に限らず、回帰問題においては同様の問題が発生することがあります。
YOLOの実験をしているときにもこのような問題に遭遇しました。
WaveNetやMaskRCNN(キーポイント検出)で、連続値であるはずの出力をあえて分類として扱っているのも、このような状況に対応するためと思われます。

今回は、上記のような問題(名前はついているんでしょうか?)を、損失関数の設計で回避する方法を紹介します。
問題を再現するシンプルなタスク設定で、損失関数の動きを見てみましょう。

コードと実験結果はこちらにあります:
https://gist.github.com/stnk20/58352ae84a4208229d1d8306f7fae8cc

タスク設定

画像上に見分けのつかない複数の点があります。どれかひとつが対象物体でほかはダミーです。対象物体の座標を答えることが目標です。
image.png
今回、画像の部分は本質的ではないので、画像は小さない16x16とし、黒い背景に白い点があるだけの単純なものとします。
ランダムに点を置いて上のような画像をつくり、学習データとします。

回帰モデル

画像を入力、座標を出力とする回帰問題とします。
適当にモデルをつくります:

import keras
from keras.models import Model
from keras.layers import Input,Flatten,Dense
from keras.layers.convolutional import Conv2D

x = Input(shape=(size,size,1))
h = Conv2D(2,(5,5))(x)
h = Flatten()(h)
h = Dense(16,activation="relu")(h)
y = Dense(2)(h) # x,y座標
model = Model(inputs=[x],outputs=[y])

以下、損失関数を変えながら学習させます。

学習結果の見方ですが、出力座標を色付きの点で示しています。予測したピクセルに物体が存在すれば緑色、そうでない場合は水色としています。

結果1 MSE損失関数を使った場合

image.png
image.png
image.png

結果2 MAE損失関数を使った場合

image.png
image.png
image.png

結果1,2の考察

各点の間を出力してしまっています。
モデルの気持ちになるとすれば、
「どっちか正解かわかんないし、とりあえず真ん中で」
といったところでしょうか。優柔不断です。

こうなってしまう原因は損失関数の形状にあります。
この問題では、同じ入力画像に対して、正解とダミーが入れ替わった2通りの結果になる可能性があります。ですので、両方の可能性を考慮した損失関数、つまり各場合の平均が実際の損失関数とみなせます。

MSE,MAEそれぞれのケースを見てみます。

MSEのグラフ@google検索
MSEは下向きの放物面形状であり、放物面は何回足しても放物面です。
これが最小になるのは正解点とダミー点のちょうど中間となります。結果とあっています。

MAEのグラフ@google検索
MAEは下向きの四角錐形状であり、2つ足すと角ばったコップのような形になります。
これは2つの点の間の「底」の部分が平坦になっており、その底のどこかを出力するのが最適です。これも結果とあいますね。

MSEやMAEのまずいところは、損失関数が対象物体以外のところで最小値をとりうる事です。
ということは、多峰性がある損失関数であれば問題を解決できそうです。

結果3 カスタム損失関数を使った場合

一例として、ガウシアンを逆さにした形の損失関数を試してみます。点同士が十分に離れていれば、多峰性の損失関数を表現できます(パラメータσで調整できます)。
ガウシアンは遠方で勾配消失するので、安定のためにMAEに足すかたちにしておきます。
こんな形です

def custom_loss(y_true,y_pred):
    # MAE + negative gaussian
    sigma = 0.1
    return K.mean(K.abs(y_pred - y_true), axis=-1) + 1-K.exp( - ( (y_pred[:,0]-y_true[:,0])**2+(y_pred[:,1]-y_true[:,1])**2 )/sigma**2 )

image.png
image.png
image.png

いいですね。

結果3の考察

パラメータσは概ね見分けるべき点間隔を目安に0.1に設定しました(画像サイズの1/10相当)が、それでうまくいっているようです。0.3や0.02では精度がでませんでした。
このパラメータσやMAEに加える比率は、問題によって変える必要があるでしょう。

おわりに

トイプロブレムではありますが、MSE・MAE損失関数が多峰性を表現できないことによる問題と、その解消方法を見ることができました。
なお、今回は対象が全く見分けられないという特殊なケースでしたが、実用上そこまで極端な状況は少ないと思うので、MAEでも十分な場合もあると思います。

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.