2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Tensorflow 2.x でもFIDが計算したい!

Last updated at Posted at 2020-07-30

Tensorflow で FID を計算したい!楽に

TensorFlow を使って FID を計算したいと思って色々調べてみたのですが、見つかった方法は以下の点があまり良くありませんでした。

  • 一旦画像データを全部メモリに乗せるのが大変
    FID スコアはデータが多ければ多いほど安定します。というより、データ数によってスコア自体が変わります。うっかり numpy で画像を全部読み込むと、すべてメモリに乗ってしまうので地獄です。
  • tensorflow データセットと上手く合わせたい
    せっかく tensorflow は色々前処理をさせてくれるんだし、利用できるといいなーと思いました。

コード

そのままファイルにして保存すればいいんじゃないっすかね。

#!/usr/bin/env python3
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm
from typing import Dict

# Scale of the identity matrix added to each covariance matrix before the
# matrix square root is taken in FID.calculate_fid.
EPS_VAL = 1e-6
# Length of the feature vector the FID class works with; used to size the
# identity offset in FID.calculate_fid.
LENGTH_FEATURE_VEC = 2048
# Passed to Dataset.prefetch so tf.data picks the prefetch buffer size itself.
AUTOTUNE = tf.data.experimental.AUTOTUNE


class FID:
    """Frechet Inception Distance (FID) between two image datasets.

    Images are streamed from ``tf.data.Dataset`` objects, pushed through an
    ImageNet-pretrained InceptionV3 (global-average-pooled, no classifier
    head) and only the resulting feature vectors are kept in memory, so the
    raw images never have to be materialised all at once.

    Args:
        scaling_func: Resize method forwarded to ``tf.image.resize`` when the
            images are scaled to 299x299.
        batch_size: Upper bound for the batch size. It is decreased until it
            evenly divides ``num_samples`` so every batch is full.
        num_samples: Number of images taken from each dataset. Must be at
            least 2048 (presumably so the 2048x2048 covariance matrices are
            not rank-deficient).

    Raises:
        AssertionError: If ``num_samples`` is smaller than 2048.
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self,
        scaling_func: tf.image.ResizeMethod = "nearest",
        batch_size: int = 128,
        num_samples: int = 100000,
    ):
        # Kept as an assert so callers see the same exception type as before.
        assert num_samples >= 2048, "invalid sample size"
        if batch_size < 1:
            # Fail early instead of a ZeroDivisionError below or an invalid
            # batch size reaching Dataset.batch() later.
            raise ValueError("batch_size must be a positive integer")
        self.scaling_func = scaling_func
        self.model = tf.keras.applications.InceptionV3(
            include_top=False, pooling="avg", input_shape=(299, 299, 3)
        )
        self.num_samples = num_samples
        # Shrink the batch size to the largest divisor of num_samples that
        # does not exceed the requested value.
        while num_samples % batch_size != 0:
            batch_size = batch_size - 1
        self.batch_size = batch_size

    def rescale_img_size(self, img: tf.Tensor):
        """Resize ``img`` to the 299x299 input size of InceptionV3."""
        return tf.image.resize(img, size=(299, 299), method=self.scaling_func)

    def gray_to_color(self, img):
        """Repeat a single-channel image (or batch of images) to 3 channels.

        Works for both ``[H, W, 1]`` and ``[B, H, W, 1]`` inputs; the last
        dimension must be 1.
        """
        # grayscale_to_rgb tiles the last axis regardless of rank, unlike a
        # hard-coded rank-4 ``multiples`` list which breaks on single images.
        return tf.image.grayscale_to_rgb(img)

    @tf.function
    def calculate_fid(self, feat1, feat2):
        """Compute the FID from two feature matrices.

        Args:
            feat1: ``[N1, LENGTH_FEATURE_VEC]`` features of the first set.
            feat2: ``[N2, LENGTH_FEATURE_VEC]`` features of the second set.

        Returns:
            Scalar ``tf.float64`` tensor:
            ``|mu1 - mu2|^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2))``.
        """
        # NOTE(review): tfp.stats.covariance appears to normalise by N rather
        # than N - 1; confirm before comparing with NumPy-based FID scores.
        mu1, sigma1 = tf.reduce_mean(feat1, axis=0), tfp.stats.covariance(feat1)
        mu2, sigma2 = tf.reduce_mean(feat2, axis=0), tfp.stats.covariance(feat2)
        ssdiff = tf.reduce_sum(tf.square(mu1 - mu2))

        # The matrix square root is evaluated in double precision.
        mu1 = tf.cast(mu1, tf.float64)
        mu2 = tf.cast(mu2, tf.float64)
        sigma1 = tf.cast(sigma1, tf.float64)
        sigma2 = tf.cast(sigma2, tf.float64)
        ssdiff = tf.cast(ssdiff, tf.float64)

        # A small multiple of the identity is added to both covariances
        # before the product whose square root is taken.
        eps = tf.constant(EPS_VAL, dtype=tf.float64)
        offset = tf.eye(LENGTH_FEATURE_VEC, dtype=tf.float64) * eps
        tdot = tf.tensordot(sigma1 + offset, sigma2 + offset, axes=1)
        covmean = tf.linalg.sqrtm(tdot)
        # Keep only the real part of the square root.
        covmean = tf.math.real(covmean)
        fid = ssdiff + tf.linalg.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid

    @tf.function
    def process(self, img: tf.Tensor):
        """Convert raw images into InceptionV3 inputs.

        Args:
            img (tf.Tensor): ``[H, W, C]`` or ``[B, H, W, C]`` with values in
                [0, 255]. ``calculate_fid_with_ds`` maps this function over
                the dataset *before* batching, so it receives single
                ``[H, W, C]`` images there.

        Returns:
            The image(s) resized to 299x299 with 3 channels, scaled by
            ``inception_v3.preprocess_input``.
        """
        if img.shape[-1] == 1:
            img = self.gray_to_color(img)
        img = tf.cast(img, tf.float64)
        img = self.rescale_img_size(img)
        img = tf.keras.applications.inception_v3.preprocess_input(img)
        return img

    @tf.function
    def get_feature(self, img: tf.Tensor):
        """Run a preprocessed batch through InceptionV3 and return its features."""
        return self.model(img)

    def calculate_fid_with_ds(self, ds1: tf.data.Dataset, ds2: tf.data.Dataset):
        """Compute the FID between two datasets of raw images.

        Args:
            ds1: Dataset yielding single images accepted by ``process``.
            ds2: Dataset yielding single images accepted by ``process``.

        Returns:
            Scalar ``tf.float64`` tensor with the FID score.

        Note:
            At most ``num_samples`` images are read from each dataset. If a
            dataset runs out earlier, iteration stops and the score is
            computed from the batches gathered so far.
        """
        num_batches = self.num_samples // self.batch_size

        ds1 = (
            ds1.map(self.process)
            .batch(self.batch_size)
            .take(num_batches)
            .prefetch(AUTOTUNE)
        )
        ds2 = (
            ds2.map(self.process)
            .batch(self.batch_size)
            .take(num_batches)
            .prefetch(AUTOTUNE)
        )

        # Collect per-batch features and concatenate once at the end;
        # concatenating inside the loop re-copies every earlier feature on
        # each step. The empty seed keeps tf.concat valid for empty datasets.
        empty = tf.zeros(shape=(0, LENGTH_FEATURE_VEC), dtype=tf.float32)
        features1 = [empty]
        features2 = [empty]
        # total= lets tqdm render a real progress bar; zip() has no length.
        for batch1, batch2 in tqdm(zip(ds1, ds2), total=num_batches):
            features1.append(self.get_feature(batch1))
            features2.append(self.get_feature(batch2))

        feature1 = tf.concat(features1, axis=0)
        feature2 = tf.concat(features2, axis=0)
        fid = self.calculate_fid(feature1, feature2)
        return fid

使い方

if __name__ == "__main__":
    # Example: FID between the CIFAR-10 train and test splits.
    # NOTE(review): pathlib and matplotlib are not used by this script; the
    # imports are kept in case code outside this view relies on them.
    import pathlib
    import tensorflow_datasets as tfds
    import matplotlib.pyplot as plt

    @tf.function
    def extract_image(sample: Dict):
        """Return the sample's image as float32, centre-cropped to a square."""
        img = tf.cast(sample["image"], tf.float32)
        shapes = tf.shape(img)
        h, w = shapes[-3], shapes[-2]
        # Crop to the shorter side so the result is square.
        small = tf.minimum(h, w)
        img = tf.image.resize_with_crop_or_pad(img, small, small)
        return img

    # Load the dataset once and pick both splits from the same result
    # instead of calling tfds.load() twice.
    cifar10 = tfds.load("cifar10")
    ds1 = cifar10["train"].map(extract_image)
    ds2 = cifar10["test"].map(extract_image)

    fid = FID(num_samples=10000)
    print(fid.calculate_fid_with_ds(ds1, ds2))

output

tf.Tensor(5.46545230432653, shape=(), dtype=float64)

ちなみに、num_samples は FID 計算に用いる画像の数です(上のコードでは 2048 未満を指定するとエラーになります)。データセットに含まれる画像数全部を入れるといいんじゃないっすかね。

なお、サンプルの数とFIDスコアの関係は次のとおりです。

FIDvsSample.png

使ったコードは以下です。

if __name__ == "__main__":
    # Example: plot how the FID score between the CIFAR-10 train and test
    # splits changes with the number of samples used.
    # NOTE(review): pathlib is not used by this script; the import is kept
    # in case code outside this view relies on it.
    import pathlib
    import tensorflow_datasets as tfds
    import matplotlib.pyplot as plt

    @tf.function
    def extract_image(sample: Dict):
        """Return the sample's image as float32, centre-cropped to a square."""
        img = tf.cast(sample["image"], tf.float32)
        shapes = tf.shape(img)
        h, w = shapes[-3], shapes[-2]
        # Crop to the shorter side so the result is square.
        small = tf.minimum(h, w)
        img = tf.image.resize_with_crop_or_pad(img, small, small)
        return img

    # Load the dataset once and pick both splits from the same result
    # instead of calling tfds.load() twice.
    cifar10 = tfds.load("cifar10")
    ds1 = cifar10["train"].map(extract_image)
    ds2 = cifar10["test"].map(extract_image)

    idxs = []
    fids = []
    for i in range(3, 11):
        num_samples = i * 1000
        # A new FID object is needed per sample size because num_samples is
        # fixed at construction time.
        fid = FID(num_samples=num_samples, batch_size=100)
        score = fid.calculate_fid_with_ds(ds1, ds2)
        print(f"samples {num_samples}: {score}")
        fids.append(score.numpy())
        idxs.append(num_samples)
    plt.plot(idxs, fids, marker="o")
    plt.title("FID Score vs Sample Size")
    plt.ylabel("FID Score")
    plt.xlabel("Sample Size")
    plt.show()
2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?