Edited at
PyTorchDay 17

PyTorchでのNaN検出方法


NaNとは


概要

NaNは、浮動小数点表記で表現できない領域の表記 (Not a Numberの略) である。IEEE754-1985で定義されている。機械学習の場合、不適切なハイパーパラメータの設定を行う場合に発生しやすい。12

演算結果が、NaNになっても演算は続けることが出来る。しかし、計算を続けたい場合は良いが、発生時に検知したい場合がある。ここでは、その検知方法について記載する。

なお、PyTorchは、torch (PythonとC++) とaten (C++) で記述されている。これは、処理の高速化を図るためである。このため、説明にC++のコードが入ってくる。


NaNの演算

NaNの演算は、以下の通りである。


  • NaNと別の値を演算しても、NaNのままである。

  • NaNの値は、通常の値とは異なり自身の値と比較するとTrueでは無くFalseとなる。


NaN検出のやり方

PyTorchでは、2つのNaN検出方法が提供されている。Tensorそのもの、およびBackward出力のTensorの検出である。ここでは、使い方を説明する。


Tensorの場合

TensorをNaNか否かをチェックする。

torch.isnan(torch.tensor([1, float('nan'), 2]))

上記の出力が以下となる。このため、すべての要素がゼロかそれ以外かで判断出来る。

tensor([0, 1, 0], dtype=torch.uint8)


Backward演算出力のTensorの場合

detect_anomalyのフラグを立てて、Backward演算でのTensorでNaNが出力されたか否かを確認する。

torch.autograd.set_detect_anomaly(True)

inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()

もしくは、以下のように用いる。

with torch.autograd.detect_anomaly()

inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()


NaN検出の仕組み

2つのNaNの検出の仕組みについて、説明する。


Tensorの場合

PyTorchのPython層でtorch.isnanがTensor同士でNaNの条件である不等号(!=)であるかを確認する。(torch/fuctional.pyのdef isnan(tensor):)

さて、Pythonの仕様で、不等号(!=)は、オブジェクトで定義した__ne__で処理される。PyTorchの場合、__ne__は、torch/tensor.pyで定義している。そして、neの処理はC++層(PyTypeObject)THPVariableTypeで処理を行う。


Backward演算出力のTensorの場合

Backwardの計算は、torch/csrc/autograd/engine.cppで行う。

C++のAnomalyMode構造体が、NaNをチェックするか否かのフラグとなる。AnomalyModeがONの場合、BackwardのテンソルのNaNチェックを行う。具体的なチェックするコードは以下であるtorch/csrc/autograd/engine.cpp。ここで、output.ne(output).any().item<uint8_t>()にて、!=のチェックを行っている、以下は、具体的な箇所周りのコードを示している。

  if (AnomalyMode::is_enabled()) {

AutoGradMode grad_mode(false);
for (int i = 0; i < num_outputs; ++i) {
auto& output = outputs[i];
at::OptionalDeviceGuard guard(device_of(output));
if (output.defined() && output.ne(output).any().item<uint8_t>()) {
std::stringstream ss;
ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
throw std::runtime_error(ss.str());
}
}
}


参考資料