LoginSignup
26
16

More than 3 years have passed since last update.

PyTorchでのNaN検出方法

Last updated at Posted at 2018-12-16

NaNとは

概要

NaNは、浮動小数点表記で表現できない領域の表記 (Not a Numberの略) である。IEEE754-1985で定義されている。機械学習の場合、不適切なハイパーパラメータの設定を行う場合に発生しやすい。12
演算結果が、NaNになっても演算は続けることが出来る。しかし、計算を続けたい場合は良いが、発生時に検知したい場合がある。ここでは、その検知方法について記載する。
なお、PyTorchは、torch (PythonとC++) とaten (C++) で記述されている。これは、処理の高速化を図るためである。このため、説明にC++のコードが入ってくる。

NaNの演算

NaNの演算は、以下の通りである。

  • NaNと別の値を演算しても、NaNのままである。
  • NaNの値は、通常の値とは異なり自身の値と比較するとTrueでは無くFalseとなる。

NaN検出のやり方

PyTorchでは、2つのNaN検出方法が提供されている。Tensorそのもの、およびBackward出力のTensorの検出である。ここでは、使い方を説明する。

Tensorの場合

TensorをNaNか否かをチェックする。

torch.isnan(torch.tensor([1, float('nan'), 2]))

上記の出力が以下となる。このため、すべての要素がゼロかそれ以外かで判断出来る。一つでもあったらNGなら、any()で確認すればよい。

tensor([0, 1, 0], dtype=torch.uint8)

Backward演算出力のTensorの場合

detect_anomalyのフラグを立てて、Backward演算でのTensorでNaNが出力されたか否かを確認する。

torch.autograd.set_detect_anomaly(True)
inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()

もしくは、以下のように用いる。

with torch.autograd.detect_anomaly()
   inp = torch.rand(10, 10, requires_grad=True)
   out = run_fn(inp)
   out.backward()

NaN検出の仕組み

2つのNaNの検出の仕組みについて、説明する。

Tensorの場合

PyTorchのPython層でtorch.isnanがTensor同士でNaNの条件である不等号(!=)であるかを確認する。(torch/fuctional.pyのdef isnan(tensor):)
さて、Pythonの仕様で、不等号(!=)は、オブジェクトで定義した__ne__で処理される。PyTorchの場合、__ne__は、torch/tensor.pyで定義している。そして、neの処理はC++層(PyTypeObject)THPVariableTypeで処理を行う。

Backward演算出力のTensorの場合

Backwardの計算は、torch/csrc/autograd/engine.cppで行う。
C++のAnomalyMode構造体が、NaNをチェックするか否かのフラグとなる。AnomalyModeがONの場合、BackwardのテンソルのNaNチェックを行う。具体的なチェックするコードは以下であるtorch/csrc/autograd/engine.cpp。ここで、output.ne(output).any().item<uint8_t>()にて、!=のチェックを行っている、以下は、具体的な箇所周りのコードを示している。

  if (AnomalyMode::is_enabled()) {
    AutoGradMode grad_mode(false);
    for (int i = 0; i < num_outputs; ++i) {
      auto& output = outputs[i];
      at::OptionalDeviceGuard guard(device_of(output));
      if (output.defined() && output.ne(output).any().item<uint8_t>()) {
        std::stringstream ss;
        ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
        throw std::runtime_error(ss.str());
      }
    }
  }

参考資料

26
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
26
16