概要
[1] の numpy で書かれたコードを TensorFlow 2.0 に移植するため、FID の計算を TensorFlow のみでしようとしていたところ、scipy.linalg.sqrtm()
では計算可能な 2048x2048
の ndarray
の平方根を求めようとしても、TensorFlow 側だけ Nan
になってしまう現象にぶつかった。
- TensorFlow 1.0 でも同じ現象が起こるというイシューが GitHub に立っている (同じことをすれば対処できるはず)
対処法
対処には二段階必要です。
まず、tf.linalg.sqrtm()
の一個前の計算である、$\Sigma_r \Sigma_g$ の時点で、tf.cast
を噛ませて float64
に設定しないといけません。TensorFlow のデフォルトは float32
のため、計算が失敗する可能性があるらしい1です。
sigma1, sigma2 = tf.cast(sigma1, tf.float64), tf.cast(sigma2, tf.float64)
次に、[2] を参考に、微小の数値 eps
を対角行列として加え、その後行列の掛け算を行います。ここでは、全て型を tf.float64
に設定しないといけません。
eps = tf.constant(1e-6, dtype=tf.float64)
offset = tf.eye(2048, dtype=tf.float64) * eps
tdot = tf.tensordot(sigma1+offset, sigma2+offset, axes=1)
これで対処が終わりました。あとは、
covmean = tf.linalg.sqrtm(tdot)
で、一応計算が通り、ほぼ同じ計算結果になります。
問題点
計算は止まらなくなったものの、TensorFlow 側の計算が、Scipy の計算より数十倍遅いので、修正方法が分かり次第追記する。
参考
- [1] How to Implement the Frechet Inception Distance (FID) for Evaluating GANs: https://machinelearningmastery.com/how-to-implement-the-frechet-inception-distance-fid-from-scratch/
- [2] FID get "nan" or "complex number": https://github.com/bioinf-jku/TTUR/issues/4