Edited at

Kerasのモデルでtfdbgを使う時の罠

More than 1 year has passed since last update.


学習中にnanやinfが出た時

どこでそれらが発生したかを追いかけるのにtfdbgという機能が使えます。基本的な使い方は公式のドキュメントか、tfdbg を使って Keras の nan や inf を潰すなどを見ればおおよそ分かるかと思います。わりと理解しやすいデバッガだと思います。


明らかに関係ない所で見つかるinf

Kerasのモデルでtfdbgを利用する際、ある条件を満たすと has_inf_or_nan で明らかに不必要な inf が見つかります。

初見だとそこに問題があるのかと思って悩まされます。

例えばこんなモデル

inputs = Input(shape=(5,))

sqrts = Lambda(lambda x: K.sqrt(x))(inputs)
outputs = Activation('softmax')(sqrts)
model = Model(inputs, outputs)

Kerasで書いたことのある人なら分かると思いますが、Trainableな重みは1つも無い構造です。なので fit しようが何しようアウトプットはインプットにしか依存しません。

しかしこのモデルに対してtfdbgを用い、run -f has_inf_or_nan を実行すると inf を見つけて来てストップします。

このモデルであれば明らかに inf なんか入るわけ無いだろと分かるわけですが、大体の人はnaninfが発生しているモデルに対してデバッグするためにtfdbgを使うわけですから、infが見つかった箇所付近を重点的にチェックすると思います。

しかし残念ながらこの inf は本来の原因とは一切関係ありません。

実はここにあるようにK.sqrtの実装の中にinfが登場します。そのため現状ではkeras.backend.sqrtをネットワーク内に含んでいるだけでhas_inf_or_nanが常にヒットします。

has_inf_or_nanは学習中のどこかで発生するinfnanを見つけやすくしてくれるはずなのに、こいつのせいでrunは毎回止まるのでいくらなんでも邪魔すぎます。

なのでIssueにして本家の意見を聞いてみました。

https://github.com/keras-team/keras/issues/9161

結果的には、まだ解決はしてないのですが少なくともこの事実を知っていないと意味のないinf発生源にずっと悩まされる人がたくさん出てきそうなので記事にしてみました。Issueに進展があったら追記します。

【追記 2018/01/24】

上記のIssueにKerasの作者から直々に返信をいただけました!

現状既に解決している問題のために敢えてそのような実装になっているようなので K.sqrt() ではなく、import tensorflow as tfのようにTensorflowを読み込んだ上でtf.sqrt()を直に使って問題無さそうという事でした。

という事で解決😋