Tensorflow で FID を計算したい!楽に
Tensorflow を使って FID を計算したいなーって思ったけど色々調べていたけど、以下の点があんまり良くなかったです
- 一旦画像データを全部メモリに乗せるのが大変
FID スコアは沢山データがあればあるほど安定します。というかデータ数によってスコアが変わります。うっかり numpy で画像を全部持ってくると、メモリに乗せてしまうので、地獄です。 - tensorflow データセットと上手く合わせたい
せっかく tensorflow は色々前処理をさせてくれるんだし、利用できるといいなーと思いました。
コード
そのままファイルにして保存すればいいんじゃないっすかね。
# !/usr/bin/env python3
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm
from typing import Dict
EPS_VAL = 1e-6
LENGTH_FEATURE_VEC = 2048
AUTOTUNE = tf.data.experimental.AUTOTUNE
class FID:
def __init__(
self,
scaling_func: tf.image.ResizeMethod = "nearest",
batch_size: int = 128,
num_samples: int = 100000,
):
assert num_samples >= 2048, "invalid sample size"
self.scaling_func = scaling_func
self.model = tf.keras.applications.InceptionV3(
include_top=False, pooling="avg", input_shape=(299, 299, 3)
)
self.num_samples = num_samples
while num_samples % batch_size != 0:
batch_size = batch_size - 1
self.batch_size = batch_size
def rescale_img_size(self, img: tf.Tensor):
return tf.image.resize(img, size=(299, 299), method=self.scaling_func)
def gray_to_color(self, img):
return tf.tile(img, multiples=[1, 1, 1, 3])
@tf.function
def calculate_fid(self, feat1, feat2):
mu1, sigma1 = tf.reduce_mean(feat1, axis=0), tfp.stats.covariance(feat1)
mu2, sigma2 = tf.reduce_mean(feat2, axis=0), tfp.stats.covariance(feat2)
ssdiff = tf.reduce_sum(tf.square(mu1 - mu2))
mu1 = tf.cast(mu1, tf.float64)
mu2 = tf.cast(mu2, tf.float64)
sigma1 = tf.cast(sigma1, tf.float64)
sigma2 = tf.cast(sigma2, tf.float64)
ssdiff = tf.cast(ssdiff, tf.float64)
eps = tf.constant(EPS_VAL, dtype=tf.float64)
offset = tf.eye(LENGTH_FEATURE_VEC, dtype=tf.float64) * eps
tdot = tf.tensordot(sigma1 + offset, sigma2 + offset, axes=1)
covmean = tf.linalg.sqrtm(tdot)
covmean = tf.math.real(covmean)
fid = ssdiff + tf.linalg.trace(sigma1 + sigma2 - 2.0 * covmean)
return fid
@tf.function
def process(self, img: tf.Tensor):
"""
Args:
img (tf.Tensor): [B, H, W, C] where each element R[0, 255]
"""
if img.shape[-1] == 1:
img = self.gray_to_color(img)
img = tf.cast(img, tf.float64)
img = self.rescale_img_size(img)
img = tf.keras.applications.inception_v3.preprocess_input(img)
return img
@tf.function
def get_feature(self, img: tf.Tensor):
return self.model(img)
def calculate_fid_with_ds(self, ds1: tf.data.Dataset, ds2: tf.data.Dataset):
ds1 = ds1.map(self.process).batch(self.batch_size).prefetch(AUTOTUNE)
ds2 = ds2.map(self.process).batch(self.batch_size).prefetch(AUTOTUNE)
feature1 = tf.zeros(shape=(0, 2048), dtype=tf.float32)
feature2 = tf.zeros(shape=(0, 2048), dtype=tf.float32)
for idx, batch1, batch2 in tqdm(
zip(range(self.num_samples // self.batch_size), ds1, ds2)
):
feat1 = self.get_feature(batch1)
feat2 = self.get_feature(batch2)
feature1 = tf.concat([feature1, feat1], axis=0)
feature2 = tf.concat([feature2, feat2], axis=0)
fid = self.calculate_fid(feature1, feature2)
return fid
使い方
if __name__ == "__main__":
import pathlib
import tensorflow_datasets as tfds
@tf.function
def extract_image(sample: Dict):
img = tf.cast(sample["image"], tf.float32)
shapes = tf.shape(img)
h, w = shapes[-3], shapes[-2]
small = tf.minimum(h, w)
img = tf.image.resize_with_crop_or_pad(img, small, small)
return img
import matplotlib.pyplot as plt
ds1 = tfds.load("cifar10")["train"].map(extract_image)
ds2 = tfds.load("cifar10")["test"].map(extract_image)
fid = FID(num_samples=10000)
print(fid.calculate_fid_with_ds(ds1, ds2))
output
tf.Tensor(5.46545230432653, shape=(), dtype=float64)
ちなみに、 num_samples は、FID 計算に用いる画像の数です。データセットに含まれる画像数全部を入れるといいんじゃないっすかね。
なお、サンプルの数とFIDスコアの関係は次のとおりです。
使ったコードは以下です。
if __name__ == "__main__":
import pathlib
import tensorflow_datasets as tfds
@tf.function
def extract_image(sample: Dict):
img = tf.cast(sample["image"], tf.float32)
shapes = tf.shape(img)
h, w = shapes[-3], shapes[-2]
small = tf.minimum(h, w)
img = tf.image.resize_with_crop_or_pad(img, small, small)
return img
import matplotlib.pyplot as plt
ds1 = tfds.load("cifar10")["train"].map(extract_image)
ds2 = tfds.load("cifar10")["test"].map(extract_image)
idxs = []
fids = []
for i in range(3, 11):
fid = FID(num_samples=i * 1000, batch_size=100)
score = fid.calculate_fid_with_ds(ds1, ds2)
print("samples {}: {}".format(i * 1000, score))
fids.append(score.numpy())
idxs.append(i * 1000)
plt.plot(idxs, fids, marker="o")
plt.title("FID Score vs Sample Size")
plt.ylabel("FID Score")
plt.xlabel("Sample Size")
plt.show()