はじめに
ディープラーニングのモジュールの1つであるChainerですが、自前のネットワークを実装すると「計算中にNaNが発生」→「以降のlossがNaN」→「うまく学習できない」といった事態がよく起きます。
開発中にNaNが発生するのは健全なことで、このNaNを潰していくことがニューラルネットワークのプログラミングの1つと言えるでしょう。
このようなNaNがforward中に発生するのであれば、学習時に出力されるRanTimeErrorを参考にして、printで実際のデータの中身を見つつ、地道にnanの原因を潰して行けばいいのですが、backword中に発生すると、どこでnanの計算が発生しているのか全く検討が付きません。
そもそもforward中でも出力されるRanTimeErrorだけでは、どこで問題が生じているか一目ではわかりませんし、printによるデバッグといった原始的な方法もできれば避けたいです。
というわけで、NaN潰しに必要不可欠なChainerで使用できるデバッグモードを紹介します。
Chainerのデバッグモード
import chainer
class YourNetwork(chainer.Chain):
(中略 init部分)
def forward(self, xs, ys):
with chainer.using_config('debug', True):
なんやかんや
ネットワーク内のforwardにwith chainer.using_config('debug', True):
を書き足すだけです。
Chainerのデバッグモードの力を試してみる
私が現在実装中のKNRMでデバッグモードを試してみました。
このプログラムでは以下のような正規化の処理を含んでいます。
def normalize(self, x):
n = self.xp.linalg.norm(x.data, axis=-1)
n = n[:, :, None]
n = F.tile(n, (1, 1, x.shape[-1])) # + self.minute_num 本来はNaN回避のため極小の数値を加算する
return x / n
normalizeの入力となる x は(batch_size, sequence, word_vector)から成るクラス3のテンソルですが、ミニバッチに含まれる元々のsequenceの長さが異なるため、ゼロpaddingしています。
したがって、x / n の計算でゼロでの除算によるNaNが発生します。
では、このnormalizeを含むプログラムの実行をデバッグモードあり/なしで比較してみましょう。
なお途中経過がわかるように、1iteration毎にlossとaccuracyを出力しています。
デバッグモードなし
(上略)
/Users/aizakku/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/chainer/functions/math/basic_math.py:385: RuntimeWarning: invalid value encountered in true_divide
return utils.force_array(x[0] / x[1]),
/Users/aizakku/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/chainer/functions/math/exponential.py:59: RuntimeWarning: divide by zero encountered in log
return utils.force_array(numpy.log(x[0])),
/Users/aizakku/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/numpy/core/_methods.py:26: RuntimeWarning: invalid value encountered in reduce
return umr_maximum(a, axis, None, out, keepdims)
variable(nan) variable(0.1953125)
/Users/aizakku/workspace/nets.py:106: RuntimeWarning: invalid value encountered in greater
m = self.xp.sum(self.xp.absolute(x.data), axis=-1) > 0.
variable(nan) variable(0.6796875)
variable(nan) variable(0.5625)
(以下略)
予定通りlossがNaNになっていますね。RuntimeWarning: invalid value encountered in true_divide
から、どこかでゼロの除算が発生していることがわかりますが、どこかは判断できません。
また発生したnanによる、後々の演算でのWarningも表示されており、非常に読みづらいです。
そして、lossがnanになっているにもかかわらず、プログラムが走り続けます。
デバッグモードあり
では、with chainer.using_config('debug', True):
を書き足してデバッグモードをオンにしてみましょう
(上略)
File "/Users/aizakku/workspace/nets.py", line 116, in cross_match
x1, x2 = self.normalize(x1), self.normalize(x2)
File "/Users/aizakku/workspace/nets.py", line 103, in normalize
return x / n
File "/Users/aizakku/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/chainer/functions/math/basic_math.py", line 460, in div
return Div().apply((self, rhs))[0]
File "/Users/aizakku/.pyenv/versions/anaconda3-5.1.0/lib/python3.6/site-packages/chainer/function_node.py", line 288, in apply
raise RuntimeError(msg)
RuntimeError: NaN is detected on forward computation of _ / _
# ここでプログラムが止まる
上記のとおり、デバッグモードでは明示的にどこで問題が発生しているのか示してくれます。
今回は、nets.pyの116行目のself.normalize(x1)
を呼んでいるところで、そのself.normalize中のreturn x / n
でRuntimeErrorが発生していることを教えてくれています。
そして「RuntimeError: NaN is detected on forward computation of _ / _」とあるように、問題がforwardで発生しているのか、backwordで発生しているのかが判別できます。
さらに、NaNが生じるRuntimeErrorによってプログラム自体が停止します。(特殊な処理を記述していなければ、)NaNが発生した時点でそれ以降の学習は基本的に無意味ですので、無駄なリソース消費を避けられるという利点があります。
(ただ、問題のないRuntimeErrorでも止まってしまうので、実際に学習させるときは切っておいたほうが無難です)
まとめ
Chainerではwith chainer.using_config('debug', True):
をネットワークのforword内に書くだけでデバッグモードが利用できます。
本記事で紹介したとおり、NaN潰しなどのデバッグが非常にやりやすくなるため、実装中は必ず書き足しておき、より効率的な実装を目指しましょう!