NaNとは
##概要
NaNは、浮動小数点表記で表現できない領域の表記 (Not a Numberの略) である。IEEE754-1985で定義されている。機械学習の場合、不適切なハイパーパラメータの設定を行う場合に発生しやすい。12
演算結果が、NaNになっても演算は続けることが出来る。しかし、計算を続けたい場合は良いが、発生時に検知したい場合がある。ここでは、その検知方法について記載する。
なお、PyTorchは、torch (PythonとC++) とaten (C++) で記述されている。これは、処理の高速化を図るためである。このため、説明にC++のコードが入ってくる。
##NaNの演算
NaNの演算は、以下の通りである。
- NaNと別の値を演算しても、NaNのままである。
- NaNの値は、通常の値とは異なり自身の値と比較すると
True
では無くFalse
となる。
NaN検出のやり方
PyTorchでは、2つのNaN検出方法が提供されている。Tensorそのもの、およびBackward出力のTensorの検出である。ここでは、使い方を説明する。
Tensorの場合
TensorをNaNか否かをチェックする。
torch.isnan(torch.tensor([1, float('nan'), 2]))
上記の出力が以下となる。このため、すべての要素がゼロかそれ以外かで判断出来る。一つでもあったらNGなら、any()
で確認すればよい。
tensor([0, 1, 0], dtype=torch.uint8)
Backward演算出力のTensorの場合
detect_anomalyのフラグを立てて、Backward演算でのTensorでNaNが出力されたか否かを確認する。
torch.autograd.set_detect_anomaly(True)
inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()
もしくは、以下のように用いる。
with torch.autograd.detect_anomaly()
inp = torch.rand(10, 10, requires_grad=True)
out = run_fn(inp)
out.backward()
NaN検出の仕組み
2つのNaNの検出の仕組みについて、説明する。
Tensorの場合
PyTorchのPython層でtorch.isnanがTensor同士でNaNの条件である不等号(!=)であるかを確認する。(torch/fuctional.pyのdef isnan(tensor):)
さて、Pythonの仕様で、不等号(!=)は、オブジェクトで定義した__ne__
で処理される。PyTorchの場合、__ne__
は、torch/tensor.pyで定義している。そして、neの処理はC++層(PyTypeObject)THPVariableTypeで処理を行う。
Backward演算出力のTensorの場合
Backwardの計算は、torch/csrc/autograd/engine.cppで行う。
C++のAnomalyMode構造体が、NaNをチェックするか否かのフラグとなる。AnomalyModeがONの場合、BackwardのテンソルのNaNチェックを行う。具体的なチェックするコードは以下であるtorch/csrc/autograd/engine.cpp。ここで、output.ne(output).any().item<uint8_t>()
にて、!=のチェックを行っている、以下は、具体的な箇所周りのコードを示している。
if (AnomalyMode::is_enabled()) {
AutoGradMode grad_mode(false);
for (int i = 0; i < num_outputs; ++i) {
auto& output = outputs[i];
at::OptionalDeviceGuard guard(device_of(output));
if (output.defined() && output.ne(output).any().item<uint8_t>()) {
std::stringstream ss;
ss << "Function '" << fn.name() << "' returned nan values in its " << i << "th output.";
throw std::runtime_error(ss.str());
}
}
}
#参考資料
-
PyTorchコミュニティ
-
Pythonライブラリ
-
3.3.3. Customizing class creation
-
__ne__
について
-
-
9.2. math — Mathematical functions
-
math.nan
について
-
-
PEP 485 -- A Function for testing approximate equality
-
nan
の導入
-
-
3.3.3. Customizing class creation
-
そのほか
- (stackoverflow)Pytorch Operation to detect NaNs
-
2018年12月以降の更新
-
Move isnan to C++ #15722
- isnan関数がpython実装からC++実装に移行
-
Move isnan to C++ #15722