Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
12
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

PyTorchでのNaN検出方法

NaNとは

概要

NaNは、浮動小数点表記で表現できない領域の表記 (Not a Numberの略) である。IEEE754-1985で定義されている。機械学習の場合、不適切なハイパーパラメータの設定を行う場合に発生しやすい。12
演算結果が、NaNになっても演算は続けることが出来る。しかし、計算を続けたい場合は良いが、発生時に検知したい場合がある。ここでは、その検知方法について記載する。
なお、PyTorchは、torch (PythonとC++) とaten (C++) で記述されている。これは、処理の高速化を図るためである。このため、説明にC++のコードが入ってくる。

NaNの演算

NaNの演算は、以下の通りである。

  • NaNと別の値を演算しても、NaNのままである。
  • NaNの値は、通常の値とは異なり自身の値と比較するとTrueでは無くFalseとなる。

NaN検出のやり方

PyTorchでは、2つのNaN検出方法が提供されている。Tensorそのもの、およびBackward出力のTensorの検出である。ここでは、使い方を説明する。

Tensorの場合

TensorをNaNか否かをチェックする。

torch.isnan(torch.tensor([1, float('nan'), 2]))

上記の出力が以下となる。このため、すべての要素がゼロかそれ以外かで判断出来る。一つでもあったらNGなら、any()で確認すればよい。

tensor([0, 1, 0], dtype=torch.uint8)

Backward演算出力のTensorの場合

detect_anomalyのフラグを立てて、Backward演算でのTensorでNaNが出力されたか否かを確認する。

torch.autograd.set_detect_anomaly(True)
inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()

もしくは、以下のように用いる。

with torch.autograd.detect_anomaly()
   inp = torch.rand(10, 10, requires_grad=True)
   out = run_fn(inp)
   out.backward()

NaN検出の仕組み

2つのNaNの検出の仕組みについて、説明する。

Tensorの場合

PyTorchのPython層でtorch.isnanがTensor同士でNaNの条件である不等号(!=)であるかを確認する。(torch/fuctional.pyのdef isnan(tensor):)
さて、Pythonの仕様で、不等号(!=)は、オブジェクトで定義した__ne__で処理される。PyTorchの場合、__ne__は、torch/tensor.pyで定義している。そして、neの処理はC++層(PyTypeObject)THPVariableTypeで処理を行う。

Backward演算出力のTensorの場合

Backwardの計算は、torch/csrc/autograd/engine.cppで行う。
C++のAnomalyMode構造体が、NaNをチェックするか否かのフラグとなる。AnomalyModeがONの場合、BackwardのテンソルのNaNチェックを行う。具体的なチェックするコードは以下であるtorch/csrc/autograd/engine.cpp。ここで、output.ne(output).any().item<uint8_t>()にて、!=のチェックを行っている、以下は、具体的な箇所周りのコードを示している。

  if (AnomalyMode::is_enabled()) {
    AutoGradMode grad_mode(false);
    for (int i = 0; i < num_outputs; ++i) {
      auto& output = outputs[i];
      at::OptionalDeviceGuard guard(device_of(output));
      if (output.defined() && output.ne(output).any().item<uint8_t>()) {
        std::stringstream ss;
        ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
        throw std::runtime_error(ss.str());
      }
    }
  }

参考資料

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
12
Help us understand the problem. What are the problem?