LoginSignup
4
0

機械学習の学習時にロスがNaNになるときにすること

Last updated at Posted at 2022-05-26

まずデータを確認する。
次にネットワークの前後処理を確認する。
その後にネットワークを疑う。

データセットにNaNが混ざっている

意外とよくやる。
対処方法はNaNが含まれるデータは弾いてしまうか、または他の値に置き換える

欠損値を弾く

pandasを使っているならとても簡単でdf.dropna()を使えば欠損値を含む行を削除してくれます。

df.dropna()

欠損値が含まれる列を削除したいときは

df.dropna(axis=1)

また、データ成形中にうっかりnumpy配列にNaNができてしまったら弾くようにしています。

def nan_checker(array):
    """
    input ndarray
    '''
    return
    配列にNaNが含まれない   -> True
    配列にNaNが含まれる     -> False
    """
    if np.any(np.isnan(array)):
        return False
    else:
        return True

pytorchならdataloaderに、tensorflowならTERecordにする前にこのnan_checker()を挟んどくと良いかもしれません。

欠損値を置き換える

データの数が限られていて一つでも失うのが惜しいときには、欠損値を他のデータに置き換えます。
よくやるのは平均値補完

from sklearn.impute import SimpleImputer
import numpy as np

imputer = SimpleImputer(missing_value=np.nan, strategy='mean')
imputer = imputer.fit(df.values)
imputed_array = imputer.transform(df.values) #補完

シンプルにゼロ割している

このパターンがほとんどな気がします

x = x / torch.max(x) # NG
x = x / (torch.max(x) + 1e-12) # OK
x = x / (torch.max(x) + torch.finfo(torch.float32).eps) # OK

sqrtにゼロが入っている

sqrtにゼロか限りなく小さい数が入力されると発散します。

torch.sqrt(x) # NG
torch.sqrt(x + 1e-10) # OK
torch.sqrt(x + torch.finfo(torch.float32).eps) # OK

交差エントロピー誤差のlog()にゼロが入っている

logにゼロが入って発散してしまうとNaNになるので、predictの値にちっちゃい値(1e-12)を足す

cross_entropy = -tf.reduce_mean(label * tf.log(predict + 1e-12))

結果の値が小さすぎる

sigmoid関数にちっちゃすぎる値を入れると、0として返されることがあります。
また64bitから32bitへのキャストの間に小さい値が0になってしますこともあります。

学習率がでかすぎる

学習率がでかいとNaNがでることがあります...が、根本的な対策にはなってないことがほとんどなのでその場しのぎ的な対処法だと思ってください。

今後もNanに出会ったら追記していきます

4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0