LoginSignup
4
8

More than 5 years have passed since last update.

座標を出力する物体検出モデルにおける、MSE損失の罠

Posted at

はじめに

画像から物体検出し、ロボットアームでその物体をつかむ、といったタスクを考えてみます。
物体検出では対象の座標を知りたいので、画像を入力、座標を出力とする回帰問題として扱えばよさそうです。
画像認識さえできれば簡単に思えます。

しかしながら、損失関数の選択に罠があります。
回帰問題でよく使われるMean Squared Error(MSE)やMean Absolute Error(MAE)を使う場合、対象物体が見分けやすい時には問題ないのですが、見分けにくいダミー物体がある場合にはうまく行きません。
「二兎を追う者一兎をも得ず」のような状況が発生してしまうのです。

ちなみに、物体検出に限らず、回帰問題においては同様の問題が発生することがあります。
YOLOの実験をしているときにもこのような問題に遭遇しました。
WaveNetやMaskRCNN(キーポイント検出)で、連続値であるはずの出力をあえて分類として扱っているのも、このような状況に対応するためと思われます。

今回は、上記のような問題(名前はついているんでしょうか?)を、損失関数の設計で回避する方法を紹介します。
問題を再現するシンプルなタスク設定で、損失関数の動きを見てみましょう。

コードと実験結果はこちらにあります:
https://gist.github.com/stnk20/58352ae84a4208229d1d8306f7fae8cc

タスク設定

画像上に見分けのつかない複数の点があります。どれかひとつが対象物体でほかはダミーです。対象物体の座標を答えることが目標です。
image.png
今回、画像の部分は本質的ではないので、画像は小さない16x16とし、黒い背景に白い点があるだけの単純なものとします。
ランダムに点を置いて上のような画像をつくり、学習データとします。

回帰モデル

画像を入力、座標を出力とする回帰問題とします。
適当にモデルをつくります:

import keras
from keras.models import Model
from keras.layers import Input,Flatten,Dense
from keras.layers.convolutional import Conv2D

x = Input(shape=(size,size,1))
h = Conv2D(2,(5,5))(x)
h = Flatten()(h)
h = Dense(16,activation="relu")(h)
y = Dense(2)(h) # x,y座標
model = Model(inputs=[x],outputs=[y])

以下、損失関数を変えながら学習させます。

学習結果の見方ですが、出力座標を色付きの点で示しています。予測したピクセルに物体が存在すれば緑色、そうでない場合は水色としています。

結果1 MSE損失関数を使った場合

image.png
image.png
image.png

結果2 MAE損失関数を使った場合

image.png
image.png
image.png

結果1,2の考察

各点の間を出力してしまっています。
モデルの気持ちになるとすれば、
「どっちか正解かわかんないし、とりあえず真ん中で」
といったところでしょうか。優柔不断です。

こうなってしまう原因は損失関数の形状にあります。
この問題では、同じ入力画像に対して、正解とダミーが入れ替わった2通りの結果になる可能性があります。ですので、両方の可能性を考慮した損失関数、つまり各場合の平均が実際の損失関数とみなせます。

MSE,MAEそれぞれのケースを見てみます。

MSEのグラフ@google検索
MSEは下向きの放物面形状であり、放物面は何回足しても放物面です。
これが最小になるのは正解点とダミー点のちょうど中間となります。結果とあっています。

MAEのグラフ@google検索
MAEは下向きの四角錐形状であり、2つ足すと角ばったコップのような形になります。
これは2つの点の間の「底」の部分が平坦になっており、その底のどこかを出力するのが最適です。これも結果とあいますね。

MSEやMAEのまずいところは、損失関数が対象物体以外のところで最小値をとりうる事です。
ということは、多峰性がある損失関数であれば問題を解決できそうです。

結果3 カスタム損失関数を使った場合

一例として、ガウシアンを逆さにした形の損失関数を試してみます。点同士が十分に離れていれば、多峰性の損失関数を表現できます(パラメータσで調整できます)。
ガウシアンは遠方で勾配消失するので、安定のためにMAEに足すかたちにしておきます。
こんな形です

def custom_loss(y_true,y_pred):
    # MAE + negative gaussian
    sigma = 0.1
    return K.mean(K.abs(y_pred - y_true), axis=-1) + 1-K.exp( - ( (y_pred[:,0]-y_true[:,0])**2+(y_pred[:,1]-y_true[:,1])**2 )/sigma**2 )

image.png
image.png
image.png

いいですね。

結果3の考察

パラメータσは概ね見分けるべき点間隔を目安に0.1に設定しました(画像サイズの1/10相当)が、それでうまくいっているようです。0.3や0.02では精度がでませんでした。
このパラメータσやMAEに加える比率は、問題によって変える必要があるでしょう。

おわりに

トイプロブレムではありますが、MSE・MAE損失関数が多峰性を表現できないことによる問題と、その解消方法を見ることができました。
なお、今回は対象が全く見分けられないという特殊なケースでしたが、実用上そこまで極端な状況は少ないと思うので、MAEでも十分な場合もあると思います。

4
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
8