LoginSignup
32
27

More than 5 years have passed since last update.

Tensorflowの精度低下はlog(0)のせいだった

Last updated at Posted at 2016-10-04

最近TensorFlowを使い始めたのですが,学習中に突然精度が低下して変わらなくなる問題が起きていました.
以下だと70ステップ目から突然精度が低下してます.

.
.
.
step:67 train:0.894584 test:0.756296
step:68 train:0.900654 test:0.756944
step:69 train:0.897526 test:0.758796
step:70 train:0.361345 test:0.333333
step:71 train:0.361345 test:0.333333
step:72 train:0.361345 test:0.333333
step:73 train:0.361345 test:0.333333
.
.
.

重みを見てみると以下のようになっていました.

(pdb) w1
array([[[[ nan,  nan,  nan, ...,  nan,  nan,  nan],
         [ nan,  nan,  nan, ...,  nan,  nan,  nan],
         [ nan,  nan,  nan, ...,  nan,  nan,  nan]],

        [[ nan,  nan,  nan, ...,  nan,  nan,  nan],
         [ nan,  nan,  nan, ...,  nan,  nan,  nan],
         [ nan,  nan,  nan, ...,  nan,  nan,  nan]], 
.
.
.

NaNになっているということで,"tensorflow nan"で検索してみると解決方法が出てきました.
http://stackoverflow.com/questions/33712178/tensorflow-nan-bug

問題箇所は交差エントロピーの計算部分であり,以下のようにしていました(y_convはsoftmax関数による各ラベルの確率).

cross_entropy = -tf.reduce_sum(labels*tf.log(y_conv))

このままだとlog(0)になり,NaNが出てくる可能性があります.
そこで,以下のように1e-10~1.0の範囲に正規化してからlogをとることで解決できました.

cross_entropy = -tf.reduce_sum(labels*tf.log(tf.clip_by_value(y_conv,1e-10,1.0)))

tf.nn.softmax_cross_entropy_with_logitsという関数があり,それを使って以下のようにした方がいいみたいです. →この方法だとうまくいきませんでした.

32
27
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
32
27