KerasやTensorFlowを使っているときに、突然損失関数でnanが出てその特定にとても困ることがあります。ディープラーニングはブラックボックスになりがちなので、普通プログラムのデバッグよりもかなり大変です。この方法は、データに問題がある場合限定の方法ですが、単純ではあるもののかなり有力な方法だと思ったので記事にしておきます。
データに問題があるのか、ニューラルネットワークに問題があるのかを切り離す
まずやらなければいけないのは、普通のデバッグと同じで問題の箇所の特定です。損失関数でnanが起きるのは、主に2つあります。
ニューラルネットワークの失敗経験については、詳しくはこちらにケーススタディがまとまっていますのでぜひご覧ください
Neural Networkでの失敗経験やアンチパターンを語る
http://nonbiri-tereka.hatenablog.com/entry/2016/03/10/073633
データに問題があるケース
- データそのものにnan(欠損値)やinf(無限大)が入っていた
- 前処理の過程でこれらの値が発生した
nanはPandasを使っていれば、fillnaやdropnaといった便利な関数があるので気づきやすいのですが、infは忘れがちです。今回の方法はこのinfに対してよく効く方法です(後述)。
データそのものにinfが入っていることもありますが、infは前処理の過程でも起こります。気づきにくいのはfloat型のキャストでしょう。例えば1e100という非常に大きな値に対して、float64→float32へのキャストを行うとinfが発生します。
>>> a=np.array([1e100],dtype=np.float64)
>>> a
array([1.e+100])
>>> a.astype(np.float32)
array([inf], dtype=float32)
Pandasのfloatのデフォルトは64ビットなので、Numpyへのキャストで気づかない間にinfが発生していたなんてこともあり得ます。
ニューラルネットワークに問題があるケース
- 学習率が高すぎて損失が発散した
- 勾配が爆発した
- 外れ値のデータが入ってきて係数があらぬ方向に学習された
1点目は学習率を下げればいいだけですし、損失推移を見ていれば指数関数的に大きくなっているのでわかりやすいです。
2点目はRNNでよく起きることですが、勾配のクリッピングをすればいいです。これはフレームワーク側で対応できます。
3点目はデータに問題があるケースと複合的なのですが、標準化したり、データ側でクリッピングしたり、フレームワーク側でクリッピングしたり、またはアルゴリズムを工夫したりやり方はあります。ここは今回掘り下げません。
しかし、大事なのは、これらのニューラルネットワークへの解決のアプローチを適用しても、実際はデータに問題がある場合は解決しないということなのです。なので、データに問題がないということの確証が得られるまでは、まずデータを疑ったほうがいいのではないかと自分は思います。
infによく効く「np.errstate(all="raise")+標準化」
元ネタはKerasのissueからですが、**Numpyで標準化(平均を引き、標準偏差で割る)**をするとエラーが改善したり、Numpyがエラーメッセージを吐いてくれて、問題箇所を特定することができたりします。
例を見てみます。実際はこんなわかりやすいinfがあることはないですが、あくまで説明用です。
import numpy as np
# 模擬的なデータ
X = np.arange(24, dtype=np.float32).reshape(4,3,2)
# infの入ったデータ
infX = X.copy()
infX[1,0,0] = np.inf
def normalize(X):
mean = np.mean(X, axis=(1,2), keepdims=True)
sd = np.std(X, axis=(1,2), keepdims=True)
return (X-mean)/(sd+1e-7)
normalize(X)
infが入ったデータに対して標準化をしようとすると、平均で引いたときにエラーがでます。
…/numpy\core\_methods.py:117: RuntimeWarning: invalid value encountered in subtract
x = asanyarray(arr - arrmean)
…\module1.py:15: RuntimeWarning:
invalid value encountered in subtract
return (X-mean)/(sd+1e-7)
ただし、ジェネレーターの中ではエラーが抑止されていることもあるので、ジェネレータの中では**with np.errstate(all="raise")
**構文を使ってエラーを明示的に表示してあげると良いでしょう。
def normalize_with_warning(X):
with np.errstate(all="raise"):
mean = np.mean(X, axis=(1,2), keepdims=True)
sd = np.std(X, axis=(1,2), keepdims=True)
return (X-mean)/(sd+1e-7)
こうするともう少し詳細にエラーを出してくれます。
Traceback (most recent call last):
File "…\module1.py", line 32, in <module>
normalize_with_warning(infX)
File "…\module1.py", line 20, in normalize_with_warning
sd = np.std(X, axis=(1,2), keepdims=True)
File "…\numpy\core\fromnumeric.py", line 3038, in std
**kwargs)
File "…\numpy\core\_methods.py", line 140, in _std
keepdims=keepdims)
File "…\numpy\core\_methods.py", line 117, in _var
x = asanyarray(arr - arrmean)
FloatingPointError: invalid value encountered in subtract
さらに掘り下げて、try~exceptを使ってデバッグコードを差し込みましょう。こうすることで、データのどこで問題が発生しているのか即特定できます。
def normalize_with_warning_dubugging(X):
with np.errstate(all="raise"):
try:
mean = np.mean(X, axis=(1,2), keepdims=True)
sd = np.std(X, axis=(1,2), keepdims=True)
return (X-mean)/(sd+1e-7)
except:
print(np.where(np.isinf(X))) # エラーを特定するためのデバッグコード
raise
(array([1], dtype=int64), array([0], dtype=int64), array([0], dtype=int64))
(あとは先程と同じ)
ジェネレーター全体に対してnp.isinfをやるとログが流れてしまって大変なので、エラーを明示的に出させてあとは「try~except」すると綺麗に行きます。
nanに対しては効かない
残念なことにnanに対してはこの方法効きません。Pandasやnp.isnan()などを使ってしらみつぶしにするしかありません。
まとめ
損失関数でnanが出る場合は、データに問題があるケース、ネットワークに問題があるケースがあり、問題箇所の切り離しを行うのが大切。
データに問題があるケースでは、値にnanやinfが発生しているケースがあり、infの場合はこの「np.errstate(all="raise")+標準化」でエラー箇所を特定できる、ということでした。参考になれば幸いです。