LoginSignup
2
0

More than 3 years have passed since last update.

TensorFlow 2.0 の Frechet Inception Distance (FID) の計算の tf.linalg.sqrtm() で Nan が出るときの応急処置

Posted at

概要

[1] の numpy で書かれたコードを TensorFlow 2.0 に移植するため、FID の計算を TensorFlow のみでしようとしていたところ、scipy.linalg.sqrtm() では計算可能な 2048x2048ndarray の平方根を求めようとしても、TensorFlow 側だけ Nan になってしまう現象にぶつかった。

  • TensorFlow 1.0 でも同じ現象が起こるというイシューが GitHub に立っている (同じことをすれば対処できるはず)

対処法

対処には二段階必要です。

まず、tf.linalg.sqrtm() の一個前の計算である、$\Sigma_r \Sigma_g$ の時点で、tf.cast を噛ませて float64 に設定しないといけません。TensorFlow のデフォルトは float32 のため、計算が失敗する可能性があるらしい1です。

sigma1, sigma2 = tf.cast(sigma1, tf.float64), tf.cast(sigma2, tf.float64)

次に、[2] を参考に、微小の数値 eps を対角行列として加え、その後行列の掛け算を行います。ここでは、全て型を tf.float64 に設定しないといけません。

eps = tf.constant(1e-6, dtype=tf.float64)
offset = tf.eye(2048, dtype=tf.float64) * eps
tdot = tf.tensordot(sigma1+offset, sigma2+offset, axes=1)

これで対処が終わりました。あとは、

covmean = tf.linalg.sqrtm(tdot)

で、一応計算が通り、ほぼ同じ計算結果になります。

問題点

計算は止まらなくなったものの、TensorFlow 側の計算が、Scipy の計算より数十倍遅いので、修正方法が分かり次第追記する。

参考

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0