YouTubeみたいなタイトルをつけたことは後悔してない(してる)。
さて今回はご紹介するのはこちら!
複雑なDeep Learningのネットワークで学習していたら出てくるnanさん〜(BGM)。
うざいですよね〜。しかも一回出るとやり直さざる得ない。
でも奴の出現する法則がよくわからない。
うまくいっている風でもなんか出ます。
検索しても「てめえの学習が悪いんじゃね?(意訳)」と言われます...(実際悪いんでしょうけど)。
とはいえ、それを直すのも泥沼なので対処療法で直しましょう!!
注意点として、nanの回避はできますが、学習がうまくいくかは保証できません。
環境としては
- PyTorch >= 0.4.1
- Python3.7
とします。
ChainerやTensorflow使いのかたは申し訳ありません...。
サクッと解決法
model = ...
optimizer = ...
loss = ...
if torch.isnan(loss):
model = prev_model
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(prev_optimizer.state_dict())
else:
prev_model = copy.deepcopy(model)
prev_optimizer = copy.deepcopy(optimizer)
optimizer.zero_grad()
loss.backward()
optimizer.step()
Step by step
1.nan検出をしよう
まずはこれが重要ですね。
torch.isnan()
を使いましょう。
>>> import torch
>>> torch.isnan(torch.tensor([1,2,float('nan')]))
tensor([0, 0, 1], dtype=torch.uint8)
2.modelとoptimizerの値を保存しておこう
nanが出るケースは2パターンあります。
1.lossがnanになる
2.1つ前のパラメータのbackward時に一部パラメータがnanになる
現象としては結局どちらも同じですが、
一番最初にlossがnanになるのかパラメータがnanになるのか、という話ですね
1のケースが多いと思われがちですが、意外と精査すると2のケースもあります。
そのためうまくいっているデータを取っておきましょう。
optimizerも取っておくのは次で必要だからです。
if torch.isnan(loss):
# ignore this section
else:
prev_model = copy.deepcopy(model)
prev_optimizer = copy.deepcopy(optimizer)
optimizer.zero_grad()
loss.backward()
optimizer.step()
一応ちゃんとdeepcopyしておきましょう。
3.nanが出たら代入しよう
そのまま代入すれば解決...と思いきや、
prev_modelをmodelに代入すると紐づいてるoptimizerの何やかんや(何かはわからない)が切れて更新されなくなります。
optimizerをもう一度定義し直しましょう。
if torch.isnan(loss):
# section 3
model = prev_model
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(prev_optimizer.state_dict())
else:
prev_model = copy.deepcopy(model)
prev_optimizer = copy.deepcopy(optimizer)
optimizer.zero_grad()
loss.backward()
optimizer.step()
これで一応解決です。
確認用コード
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
#
def __init__(self):
super(Model, self).__init__()
self.l1 = nn.Linear(10,5)
def forward(self, x):
out = self.l1(x)
return out
criterion = nn.CrossEntropyLoss()
model = Model()
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
x = torch.rand((1,10), requires_grad=True)
t = torch.tensor([3])
# save model and optimizer
prev_model = copy.deepcopy(model)
prev_optimizer = copy.deepcopy(optimizer)
optimizer.zero_grad()
print(model.l1.weight)
y = model(x)
loss = criterion(y,t)
loss.backward()
optimizer.step()
print(model.l1.weight)
optimizer.zero_grad()
y = model(x)
loss = criterion(y,t)
loss.backward()
optimizer.step()
print(model.l1.weight)
model = prev_model
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
optimizer.load_state_dict(prev_optimizer.state_dict())
optimizer.zero_grad()
print(model.l1.weight)
y = model(x)
loss = criterion(y,t)
loss.backward()
optimizer.step()
print(model.l1.weight)