LoginSignup
24
14

More than 5 years have passed since last update.

ちょwwwPyTorchでnanが出て困ってるんだがww【解決してやるよ】

Last updated at Posted at 2018-12-21

YouTubeみたいなタイトルをつけたことは後悔してない(してる)。

さて今回はご紹介するのはこちら!
複雑なDeep Learningのネットワークで学習していたら出てくるnanさん〜(BGM)。

うざいですよね〜。しかも一回出るとやり直さざる得ない。
でも奴の出現する法則がよくわからない。

うまくいっている風でもなんか出ます。

検索しても「てめえの学習が悪いんじゃね?(意訳)」と言われます...(実際悪いんでしょうけど)。

とはいえ、それを直すのも泥沼なので対処療法で直しましょう!!

注意点として、nanの回避はできますが、学習がうまくいくかは保証できません。

環境としては

  • PyTorch >= 0.4.1
  • Python3.7

とします。
ChainerやTensorflow使いのかたは申し訳ありません...。

サクッと解決法

model = ...
optimizer = ...
loss =  ...

if torch.isnan(loss):
    model = prev_model
    optimizer = torch.optim.Adam(model.parameters())
    optimizer.load_state_dict(prev_optimizer.state_dict())
else:
    prev_model = copy.deepcopy(model)
    prev_optimizer = copy.deepcopy(optimizer)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Step by step

1.nan検出をしよう

まずはこれが重要ですね。
torch.isnan()を使いましょう。

>>> import torch
>>> torch.isnan(torch.tensor([1,2,float('nan')]))
tensor([0, 0, 1], dtype=torch.uint8)

2.modelとoptimizerの値を保存しておこう

nanが出るケースは2パターンあります。

1.lossがnanになる
2.1つ前のパラメータのbackward時に一部パラメータがnanになる

現象としては結局どちらも同じですが、
一番最初にlossがnanになるのかパラメータがnanになるのか、という話ですね

1のケースが多いと思われがちですが、意外と精査すると2のケースもあります。

そのためうまくいっているデータを取っておきましょう。
optimizerも取っておくのは次で必要だからです。

if torch.isnan(loss):
    # ignore this section
else:
    prev_model = copy.deepcopy(model)
    prev_optimizer = copy.deepcopy(optimizer)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

一応ちゃんとdeepcopyしておきましょう。

3.nanが出たら代入しよう

そのまま代入すれば解決...と思いきや、
prev_modelをmodelに代入すると紐づいてるoptimizerの何やかんや(何かはわからない)が切れて更新されなくなります。

optimizerをもう一度定義し直しましょう。

if torch.isnan(loss):
    # section 3
    model = prev_model
    optimizer = torch.optim.Adam(model.parameters())
    optimizer.load_state_dict(prev_optimizer.state_dict())
else:
    prev_model = copy.deepcopy(model)
    prev_optimizer = copy.deepcopy(optimizer)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

これで一応解決です。

確認用コード

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
  # 
  def __init__(self):
    super(Model, self).__init__()
    self.l1 = nn.Linear(10,5)

  def forward(self, x):
    out = self.l1(x)
    return out

criterion = nn.CrossEntropyLoss()
model = Model()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)

x = torch.rand((1,10), requires_grad=True)
t = torch.tensor([3])

# save model and optimizer
prev_model = copy.deepcopy(model)
prev_optimizer = copy.deepcopy(optimizer)

optimizer.zero_grad()
print(model.l1.weight)
y = model(x)
loss = criterion(y,t)
loss.backward()
optimizer.step()

print(model.l1.weight)

optimizer.zero_grad()
y = model(x)
loss = criterion(y,t)
loss.backward()
optimizer.step()

print(model.l1.weight)

model = prev_model
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
optimizer.load_state_dict(prev_optimizer.state_dict())

optimizer.zero_grad()
print(model.l1.weight)
y = model(x)
loss = criterion(y,t)
loss.backward()
optimizer.step()

print(model.l1.weight)
24
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
14