学習中にnanやinfが出た時
どこでそれらが発生したかを追いかけるのにtfdbgという機能が使えます。基本的な使い方は公式のドキュメントか、tfdbg を使って Keras の nan や inf を潰すなどを見ればおおよそ分かるかと思います。わりと理解しやすいデバッガだと思います。
明らかに関係ない所で見つかるinf
Kerasのモデルでtfdbgを利用する際、ある条件を満たすと has_inf_or_nan
で明らかに不必要な inf
が見つかります。
初見だとそこに問題があるのかと思って悩まされます。
例えばこんなモデル
inputs = Input(shape=(5,))
sqrts = Lambda(lambda x: K.sqrt(x))(inputs)
outputs = Activation('softmax')(sqrts)
model = Model(inputs, outputs)
Kerasで書いたことのある人なら分かると思いますが、Trainableな重みは1つも無い構造です。なので fit
しようが何しようアウトプットはインプットにしか依存しません。
しかしこのモデルに対してtfdbgを用い、run -f has_inf_or_nan
を実行すると inf
を見つけて来てストップします。
このモデルであれば明らかに inf
なんか入るわけ無いだろと分かるわけですが、大体の人はnan
やinf
が発生しているモデルに対してデバッグするためにtfdbgを使うわけですから、inf
が見つかった箇所付近を重点的にチェックすると思います。
しかし残念ながらこの inf
は本来の原因とは一切関係ありません。
実はここにあるようにK.sqrt
の実装の中にinf
が登場します。そのため現状ではkeras.backend.sqrt
をネットワーク内に含んでいるだけでhas_inf_or_nan
が常にヒットします。
has_inf_or_nan
は学習中のどこかで発生するinf
やnan
を見つけやすくしてくれるはずなのに、こいつのせいでrunは毎回止まるのでいくらなんでも邪魔すぎます。
なのでIssueにして本家の意見を聞いてみました。
結果的には、まだ解決はしてないのですが少なくともこの事実を知っていないと意味のないinf
発生源にずっと悩まされる人がたくさん出てきそうなので記事にしてみました。Issueに進展があったら追記します。
【追記 2018/01/24】
上記のIssueにKerasの作者から直々に返信をいただけました!
現状既に解決している問題のために敢えてそのような実装になっているようなので K.sqrt()
ではなく、import tensorflow as tf
のようにTensorflowを読み込んだ上でtf.sqrt()
を直に使って問題無さそうという事でした。
という事で解決😋