45
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

GANを使わず画像を綺麗にしたい話(SRFlow)

Last updated at Posted at 2021-12-08

はじめに

ABEJA Advent Calendar 2021の8日目の記事です。

この記事では素晴らしい技術のはずなのになかなか日の目を浴びないFlowと呼ばれる技術を使った超解像について書こうと思います。
これを読んだ暁には「そうか、だから日の目を浴びないのか」となっていると思います。
そしてなぜこの人はこんなマニアックな記事を書いているんだろうと思うことでしょう。

超解像の概要

超解像とはざっくりいうと小さい画像を大きくする技術のことを指します。画素数の少ない低解像度な小さい画像を、画素数の多い高解像度の大きい画像にするということは、何かしらの方法で画素を補間してあげる必要があります。
非常にわかりやすいこちらの記事にもあるように、超解像とは不良設定問題です。
画像丸パクで大変恐縮ですが、1x3pixelの画像を2倍拡大して2x6pixelにする場合、以下のように様々なパターンが考えられます。入力画像から思い描く完璧な出力画像を再現する情報がないので不良設定問題となるわけです。
スクリーンショット 2021-12-06 22.08.01.png

入力画像情報をもとに適切な画素を生成するので、画像生成系のタスクと非常に相性がいいのが超解像の特徴です。そして画像生成タスクといえばVAEかGANとなり、Flowなんて言葉は出てくることすらないのが現実です。今日はこの可哀想な子に少しでも日の目を浴びさせた上で、やっぱりGANがいいよという結論に持っていければなと思っています。

超解像の歴史

遡ること・・・と言いたいところですが、ぐぐったらいろいろ出てくるので、超解像の歴史をだいぶまとめると以下の感じです。

  • 辞書ベースつえー
  • 画像認識でうまくいったDeep Learning適用したら超綺麗になったよ
  • Residual構造で深くしたらもっと綺麗になったよ
  • 画素の平均二乗誤差で学習する時代は終わった。これからはGANでよりリアルな画像を作るのだ
  • GANよりFlowのほうがすごいのに・・・ ⇦ 今日の話

Flowって一体何?

これを話す前に、Flowのことを見ている人はちゃんと見ているという話をしようと思います。今年のCVPR2021でのチャレンジ課題で以下GIFのようなお題がありました。
105862172-c7de7700-5fef-11eb-8f96-319e30b6846b.gif
https://github.com/andreas128/NTIRE21_Learning_SR_Space

冒頭に述べた不良設定問題そのものをどう解決するかというものです。そして参加者の使用した技術を見てみると・・・
スクリーンショット 2021-12-06 22.33.34.png

そうです、Flowベースの手法が一番多いのです。ちなみにベスト手法もFlowベースでした。

なぜFlowベースが強かったのか

シングルモデルから複数の画像の出力を得ることができるからです。今回のお題は単一の入力から複数の画像を出力することが求められており、そのニーズに答えられるのがFlowだったということです。

SRFlowのざっくり概要

ここまで引っ張りましたがFlowを使った超解像はSRFlowと呼ばれており、以下が公式の論文です。
http://de.arxiv.org/pdf/2006.14200?gitT

ざっくりいうと、こんな構造にしたら、
スクリーンショット 2021-12-06 22.54.34.png

こんなふうに綺麗なおじさまがたくさんできるわけです。
スクリーンショット 2021-12-06 22.57.01.png

さぁみなさんも綺麗なおじさまを生成できるように中身を噛み砕いていきましょう。

Normalizing Flow

FLowと何度も書いていますが、正しくはNormalizing Flowといいます。ここから急に話がややこしくなるので、細かい話はいいから早くおじさまを作りたい人は画像出力結果までスキップしてください。

Normalizing Flowとは2015年にGoogleから発表された以下の論文で有名(?)になりました。
https://arxiv.org/pdf/1505.05770.pdf

こちらに式も踏まえて非常にわかりやすく説明されていますので、必要な部分を引用させていただきつつ、この記事ではできる限り数式を使わずにふわっと説明します。また、必要な部分と言いながらかなり引用させてもらっています。この場をお借りしてお礼申し上げます。

Normalizing Flowは変分推論の技術の一つで、観測できる入力$x$(超解像では入力画像)から、未知の値$y$(超解像でいうところの新しい画素)を確率密度関数として表現する確率モデルです。なにかしらの$x$が入力された時の$y$の事後分布 $p(y\mid x)$ がわかれば、入力画像に対応する出力画像がわかるので、それを求めるために$x$と$y$の同時密度関数 $p(x, y)$ をモデルで頑張ろうということです。
これらの関係式は次のように書けます。
$$
p(y\mid x) = \frac{p(x, y)}{\int p(x, y) dy}
$$
当然ながら $p(x,y)$ は知らない画素を生成するので複雑なモデル化が必要です。しかし複雑な分布を表現しようとすると分母の積分がうまく計算できなくなり終わります。
これを解決するために、目的の $p(y\mid x)$ を別の分布 $q(y)$ で表現することで、これらの距離が最小にする問題に置き換えます。
細かい式変換をすっ飛ばすと解くべき問題は以下の式を $\theta$ で最大化することになります。
$$
\mathbb{E}_{y \sim q(y;\theta)}[\log p(x, y) - \log q(y)]
$$
ここで $q(y)$ は $\theta$ によってパラメータ化されているとしています。
Normalizing Flowでは $\theta$ を重みとして $q(y;\theta)$ をニューラルネットで表現する手法です。なぜならほとんどのケースで$q(y)$は具体的な数式で表現できるほど単純ではないからです。
もうここまでで結構な人はお腹いっぱいなはずですが、いま変分理論の説明をしただけでSRFlowはおろか、Normalizing Flowにも到達していません・・・

工夫点と制約

少しごちゃってきたのでもう少し省略しながら書きます。
Normalizing Flowをざっくりいうと、簡単な分布 $p_Z(z)$ をたくさん積み上げて複雑な $q(y;\theta)$ を表現しようというものです。簡単な分布とは例えばガウス分布などを指し、このアイディアがFlowベースの技術の根幹になります。以下の論文引用画像のように、とある分布$z$を$z' = f(z)$ という形で変換していき、$k$回変換した$z_k$が超解像でいうところの出力画素$y$の分布にできるという発想です。
図1.png

式は省略しますが、この非線形写像$f$は可逆計算ができなければいけないという制約があります。可逆変換可能な非線形関数$f$を積み重ねたニューラルネットを構築して、パラメータ$\theta$における微分を計算することになります。
以上の制約をざっくりまとめると以下の通りです。

  1. ある程度複雑なネットワークにしなければ初期分布から変換が十分に行えない
  2. 複雑なネットワークにすると微分計算も繰り返す必要がありメモリ的に厳しい
  3. 微分計算を繰り返すので微分そのものが軽くないと学習が終わらない
  4. 微分計算に工夫必要な非線形関数は逆変換ができなければいけない
    もはやこのめんどくささの時点で日の目を浴びない理由がわかってきたような気がします。

GAN/VAEとの違い

普通に考えれば $f$ は単純にCNNにすればいいのですが、$f$は逆変換が求められるので簡単にCNNポンというわけにはいかないのが辛いところです。なぜなら「逆変換ができる」=「次元変更ができない」ということを意味しており、CNNで特徴マップを増やした or 減らした時点で逆変換ができなくなるからです。
GANやVAEとの違いは以下の画像一枚でざっくりまとめることができます。
three-generative-models.png
引用元:https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html

超解像では$x$が入力画像、$x'$が出力画像になります。GANが完全に分離された潜在変数$z$から生成するのに対し、Flowでは変数変換と逆変換を次元を変えることなく処理していることを表しています。

SRFlow

本題です。少し詳しめに公式コードも添えながら中身の説明をします。コードはモジュールレベルでしか抽出しません。なぜならコード量がそこそこあるからです・・・ また、本コードは研究用途でしか使えないのでライセンスにご注意ください。LICENSE

全体像の再掲
スクリーンショット 2021-12-06 22.54.34.png

SRFlowはNormalizing Flow $f_\theta$とLow Resolution Encoder $g_\theta$に分けられます。右下のTraining Inputが学習用の綺麗な画像で、ここを起点に左の矢印でFlow stepを積み上げていくことで学習をします。学習時は左上の低解像度画像からFlowを逆にたどって右上の超解像画像を得ます。
Flowの中身は以下のようなグレーのLevelと呼ばれるまとまりで分割されます。
スクリーンショット 2021-12-07 14.27.06.png
このLevelを変えつつマルチスケールでFlowを構築します。再掲した全体像では4Levelあることになります。学習プロセスをなぞっていきながら説明していきたいと思います。ちなみに可逆変換なレイヤーを積み上げるので、学習プロセスのコード ≒ 推論時のコードでもあります。

Squeeze

学習用画像はまずここに入ります。公式実装では学習用に160x160のクロップしたカラー画像を使用します。つまり入力次元は(160,160,3)となります。この画像を特徴量空間に並び替えます。やっていることはPixel shufflerのように画素の並び替えをしているだけです。名前を統一してほしいです。
1回のSqueezeでは(160,160,3)(80,80,12)の次元に変わります。次元が変わると言っていますが、全体の画素数は変わらないただの並び替えなので可逆変換可能です。
コードもシンプルに並び替えるだけです。SqueezeLayerクラスのforawrd関数内にnot reverseとありますが、学習時はこっちの処理が走ります。

def squeeze2d(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
    x = input.view(B, C, H // factor, factor, W // factor, factor)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    x = x.view(B, C * factor * factor, H // factor, W // factor)
    return x


def unsqueeze2d(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    factor2 = factor ** 2
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert C % (factor2) == 0, "{}".format(C)
    x = input.view(B, C // factor2, factor, factor, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    x = x.view(B, C // (factor2), H * factor, W * factor)
    return x


class SqueezeLayer(nn.Module):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, logdet=None, reverse=False):
        if not reverse:
            output = squeeze2d(input, self.factor)  # Squeeze in forward
            return output, logdet
        else:
            output = unsqueeze2d(input, self.factor)
            return output, logdet

Transition Step

この部分です。ここはSqueezeで発生しがちな市松模様を吸収するために存在しているらしいです。ここではActnormと1x1 Convolutionの二つの処理がセットとなっています。
スクリーンショット 2021-12-07 15.13.01.png

Actnorm

基本的にはただの正規化レイヤーです。入力されたデータの平均位置とスケールをアフィン変換で調整しています。バッチ正規化と違うのは、平行移動のbiasとスケール変換のlogsが学習パラメータというところです。学習時の最初の入力の平均と分散で初期化して、あとは学習の中でいい感じの正規化をしてもらおうという魂胆です。Flowの学習はメモリが厳しくあまり大きなバッチサイズにできないので、バッチ正規化ではなくActnormで正規化をしているのだと思います。(僕の予想)
実装もシンプルです。学習時の移動平均はreturn input + biasとし、推論時はreturn input - biasとするだけで、スケール変換も学習時はinput = input * torch.exp(logs)とし、推論時はinput = input * torch.exp(-logs)とするだけです。

class _ActNorm(nn.Module):
    """
    Activation Normalization
    Initialize the bias and scale with a given minibatch,
    so that the output per-channel have zero mean and unit variance for that.

    After initialization, `bias` and `logs` will be trained as parameters.
    """

    def __init__(self, num_features, scale=1.):
        super().__init__()
        # register mean and scale
        size = [1, num_features, 1, 1]
        self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
        self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
        self.num_features = num_features
        self.scale = float(scale)
        self.inited = False

    def _check_input_dim(self, input):
        return NotImplemented

    def initialize_parameters(self, input):
        self._check_input_dim(input)
        if not self.training:
            return
        if (self.bias != 0).any():
            self.inited = True
            return
        assert input.device == self.bias.device, (input.device, self.bias.device)
        with torch.no_grad():
            bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
            vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
            logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.inited = True

    def _center(self, input, reverse=False, offset=None):
        bias = self.bias

        if offset is not None:
            bias = bias + offset

        if not reverse:
            return input + bias
        else:
            return input - bias

    def _scale(self, input, logdet=None, reverse=False, offset=None):
        logs = self.logs

        if offset is not None:
            logs = logs + offset

        if not reverse:
            input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1
            # input = input * torch.exp(logs+logs_offset)
        else:
            input = input * torch.exp(-logs)
        if logdet is not None:
            """
            logs is log_std of `mean of channels`
            so we need to multiply pixels
            """
            dlogdet = thops.sum(logs) * thops.pixels(input)
            if reverse:
                dlogdet *= -1
            logdet = logdet + dlogdet
        return input, logdet

    def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
        if not self.inited:
            self.initialize_parameters(input)
        self._check_input_dim(input)

        if offset_mask is not None:
            logs_offset *= offset_mask
            bias_offset *= offset_mask
        # no need to permute dims as old version
        if not reverse:
            # center and scale

            # self.input = input
            input = self._center(input, reverse, bias_offset)
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
        else:
            # scale and center
            input, logdet = self._scale(input, logdet, reverse, logs_offset)
            input = self._center(input, reverse, bias_offset)
        return input, logdet

1x1 convolution

カーネルが1x1である以上、ただのアフィン変換なのでFlowでも使えます。この手法(Actnormもですが)はGlowで提案されました。本家ではLU分解で計算を軽くする工夫をしていますが、SRFlowでは使っていません。これは論文中にも使わっていないと記載があります。コード上にはself.LU = LU_decomposedという使おうかなーという意図は見えるので使わなかった本当の理由はよくわからんです。

class InvertibleConv1x1(nn.Module):
    def __init__(self, num_channels, LU_decomposed=False):
        super().__init__()
        w_shape = [num_channels, num_channels]
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
        self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        self.w_shape = w_shape
        self.LU = LU_decomposed

    def get_weight(self, input, reverse):
        w_shape = self.w_shape
        pixels = thops.pixels(input)
        dlogdet = torch.slogdet(self.weight)[1] * pixels
        if not reverse:
            weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
        else:
            weight = torch.inverse(self.weight.double()).float() \
                .view(w_shape[0], w_shape[1], 1, 1)
        return weight, dlogdet
        
    def forward(self, input, logdet=None, reverse=False):
        """
        log-det = log|abs(|W|)| * pixels
        """
        weight, dlogdet = self.get_weight(input, reverse)
        if not reverse:
            z = F.conv2d(input, weight)
            if logdet is not None:
                logdet = logdet + dlogdet
            return z, logdet
        else:
            z = F.conv2d(input, weight)
            if logdet is not None:
                logdet = logdet - dlogdet
            return z, logdet

Conditional Flow Step

ここです。この部分が1番の肝です。
スクリーンショット 2021-12-07 21.53.18.png
青の塊が一つのセットでこのセットを何個も積み上げていくことになります。公式パラメータでは16セット積み上げています。1Levelあたり16セットなので、4Levelなら合計64セット積み上げることになります。
Actnormと1x1 Convolutionは同じなので割愛します。

Affine Injector

SRFlowのタスクは低解像度の画像を高解像度にすることです。今まで説明してきた中では推論時の低解像度画像の情報が入っていません。これをFlowの中に直接注入するのがAffine Injectorです。低解像度の画像はLow Resolution Encoder $g_\theta$によって特徴量を抽出されます。$g_\theta$のencode結果**$u$**でスケール方向とバイアス方向にアフィン変換します。このときencode結果は$g_\theta$の中間層もconcatして大きめの特徴量を抽出しています。(各Levelでの画像サイズx320次元の特徴マップ)
コードは次のConditional Affine Couplingと合わせて紹介します。

Conditional Affine Coupling

ここが1番のメイン変換です。SRFlowではcouplingと呼ばれる手法で計算コストの削減をしつつ、複雑なネットワークを構築します。これはNICEで提案され、Real NVPで改善された手法です。この手法では入力ベクトル$y$を前半の$y_1$と後半の$y_2$に分割し、変換後のベクトル$z$も前半の$z_1$と後半の$z_2$に分割します。そして$y_2$を変換するためのパラメータを$y_1$から求めます。

$$
z_1 = y_1 \
z_2 = exp(s(y_1))⋅y_2 + t(y_1)
$$

このとき関数$s$と$t$はどんなに複雑になってもよいうのが自由度が高く、複雑な変換ができるようになった要因の一つです。SRFlowでは上式の$z_2$を求める時に$y_1$にLow Resolution Encoderの結果をconcatして畳み込むことで低解像度画像の特徴量をうまく活用しています。

ちなみにLow Resolution Encoderから伸びている矢印は入力画像の特徴量を注入していることを意味しています。
スクリーンショット 2021-12-07 22.38.54.png

class CondAffineSeparatedAndCond(nn.Module):
    def __init__(self, in_channels, opt):
        super().__init__()
        self.need_features = True
        self.in_channels = in_channels
        self.in_channels_rrdb = 320
        self.kernel_hidden = 1
        self.affine_eps = 0.0001
        self.n_hidden_layers = 1
        hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels'])
        self.hidden_channels = 64 if hidden_channels is None else hidden_channels

        self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'],  0.0001)

        self.channels_for_nn = self.in_channels // 2
        self.channels_for_co = self.in_channels - self.channels_for_nn

        if self.channels_for_nn is None:
            self.channels_for_nn = self.in_channels // 2

        self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
                              out_channels=self.channels_for_co * 2,
                              hidden_channels=self.hidden_channels,
                              kernel_hidden=self.kernel_hidden,
                              n_hidden_layers=self.n_hidden_layers)

        self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                                out_channels=self.in_channels * 2,
                                hidden_channels=self.hidden_channels,
                                kernel_hidden=self.kernel_hidden,
                                n_hidden_layers=self.n_hidden_layers)

    def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
        if not reverse:
            z = input
            assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)

            # Feature Conditional
            scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
            z = z + shiftFt
            z = z * scaleFt
            logdet = logdet + self.get_logdet(scaleFt)

            # Self Conditional
            z1, z2 = self.split(z)
            scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
            self.asserts(scale, shift, z1, z2)
            z2 = z2 + shift
            z2 = z2 * scale

            logdet = logdet + self.get_logdet(scale)
            z = thops.cat_feature(z1, z2)
            output = z
        else:
            z = input

            # Self Conditional
            z1, z2 = self.split(z)
            scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
            self.asserts(scale, shift, z1, z2)
            z2 = z2 / scale
            z2 = z2 - shift
            z = thops.cat_feature(z1, z2)
            logdet = logdet - self.get_logdet(scale)

            # Feature Conditional
            scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
            z = z / scaleFt
            z = z - shiftFt
            logdet = logdet - self.get_logdet(scaleFt)

            output = z
        return output, logdet

    def asserts(self, scale, shift, z1, z2):
        assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn)
        assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co)
        assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1])
        assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1])

    def get_logdet(self, scale):
        return thops.sum(torch.log(scale), dim=[1, 2, 3])

    def feature_extract(self, z, f):
        h = f(z)
        shift, scale = thops.split_feature(h, "cross")
        scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
        return scale, shift

    def feature_extract_aff(self, z1, ft, f):
        z = torch.cat([z1, ft], dim=1)
        h = f(z)
        shift, scale = thops.split_feature(h, "cross")
        scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
        return scale, shift

    def split(self, z):
        z1 = z[:, :self.channels_for_nn]
        z2 = z[:, self.channels_for_nn:]
        assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1])
        return z1, z2

    def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1):
        layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)]

        for _ in range(n_hidden_layers):
            layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]))
            layers.append(nn.ReLU(inplace=False))
        layers.append(Conv2dZeros(hidden_channels, out_channels))

        return nn.Sequential(*layers)

self.fFeaturesがAffine Injector、self.fAffineがConditional Affine Couplingで使用されます。もっとわかりやすい名前つけてほしかった。

Split

基本的にLevelの中でFlowの計算は完結します。Splitレイヤーでは頑張って計算した$z$を特徴量方向に半分に分割して半分だけを次のLevelに渡します。モデルの軽量化と各Levelでの異なった解像度に対する汎化性能を上げているのだというのが個人的見解です。dropoutに近い感覚です。

class Split2d(nn.Module):
    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
        super().__init__()

        self.num_channels_consume = int(round(num_channels * consume_ratio))
        self.num_channels_pass = num_channels - self.num_channels_consume

        self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
                                out_channels=self.num_channels_consume * 2)
        self.logs_eps = logs_eps
        self.position = position
        self.opt = opt

    def split2d_prior(self, z, ft):
        if ft is not None:
            z = torch.cat([z, ft], dim=1)
        h = self.conv(z)
        return thops.split_feature(h, "cross")

    def exp_eps(self, logs):
        return torch.exp(logs) + self.logs_eps

    def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
        if not reverse:
            # self.input = input
            z1, z2 = self.split_ratio(input)
            mean, logs = self.split2d_prior(z1, ft)
            
            eps = (z2 - mean) / self.exp_eps(logs)

            logdet = logdet + self.get_logdet(logs, mean, z2)

            # print(logs.shape, mean.shape, z2.shape)
            # self.eps = eps
            # print('split, enc eps:', eps)
            return z1, logdet, eps
        else:
            z1 = input
            mean, logs = self.split2d_prior(z1, ft)

            if eps is None:
                #print("WARNING: eps is None, generating eps untested functionality!")
                eps = GaussianDiag.sample_eps(mean.shape, eps_std)

            eps = eps.to(mean.device)
            z2 = mean + self.exp_eps(logs) * eps

            z = thops.cat_feature(z1, z2)
            logdet = logdet - self.get_logdet(logs, mean, z2)

            return z, logdet
            # return z, logdet, eps

    def get_logdet(self, logs, mean, z2):
        logdet_diff = GaussianDiag.logp(mean, logs, z2)
        # print("Split2D: logdet diff", logdet_diff.item())
        return logdet_diff

    def split_ratio(self, input):
        z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...]
        return z1, z2

ロス計算

それぞれのレイヤーで計算したlogdetを最後にnegative log-likelihoodに変換してこれをロスとします。

objective = logdet.clone()
objective = objective + flow.GaussianDiag.logp(None, None, z)
nll = (-objective) / float(np.log(2.) * pixels)

推論(超解像)

冒頭に超解像とは少ない画素から画素を補間すると記載しましたが、Normalizing Flowの制約である「次元を変更できない」時点でもう普通には構築できないということにお気づきでしょうか・・・
次元が変更できないということは、ネットワークのinputからoutputまでの総画素数(width x height x ch)が変わることはないので、画像の拡大どころではありません。これを解決するのがLow Resolution Encoder $g_\theta$です。
スクリーンショット 2021-12-07 1.59.10.png
Low Resolution Encoder $g_\theta$はFlow $f_\theta$の外側にいるため制約の対象外となります。低解像度画像は$g_\theta$に入力されて次元を拡張します。
例えば入力画像のshapeが(100,100,3)で、4倍拡大用のFlowを構築していた場合、Flowの中の総次元数は拡大後の400x400x3になります。結局拡大そのものの処理はFlowの外で行うということです。$g_\theta$は拡大さえできればなんでもOKです。公式実装ではESRGANで提案されたResidual-in-Residual Dense Block(RRDB)のPNSR(≒最小二乗誤差)でpre-trainされたモデルを使用しています。

実際の拡大画像

今回はPNSR代表のRRDBとESRGANとSRFlowで4倍拡大をして結果を見てみます。定量評価はめんどくさいのでしません。
違いがはっきりわかる8倍拡大にしたかったのですが、RRDBとESRGANの事前学習モデルが4倍しかなかったので諦めました。自分で学習するのがめんどくさかった
使った画像はとりあえずでいつも使われるset5set14です。適当に何個か貼ります。
後述するSRFlowのパラメータであるheatは0.9に設定しています。

RRDB、ESRGANとの比較結果

左から順に元画像、RRDB、ESRGAN、SRFlowです。サムネだとわかりにくいのでぜひクリックして等倍で見ていただきたいです。
baboon_concat.png
baby_concat.png
barbara_concat.png
butterfly_concat.png
coastguard_concat.png
head_concat.png
lenna_concat.png
man_concat.png
pepper_concat.png
ppt3_concat.png

どうでしょうか?こんなに頑張ったのに全部一緒に見えませんか?
等倍でちゃんと見るとRRDBは全体的に塗りつぶしたような画像が生成されており、ESRGANとSRFlowは細部を丁寧に再現できています。ESRGANで色合いが変わっているものがあるのは謎です。定量評価は載せていませんが、SRflowの方がいい結果を残しています。
全部一緒に見えるというのは純粋に4倍拡大程度では違いがわからないくらい超解像の精度が上がっているとも言えます。いいことですね。

SRFlowの8倍拡大

左が元画像、右がSRFlowで生成したものです。画像はDIV2Kから持ってきました。(学習に使っていないデータです)
0803_concat.png
0804_concat.png

2枚目の服のチェック柄はかなり潰れてしまいましたが、8倍なら十分すぎる綺麗さじゃないでしょうか。上のオオカミが下のように拡大されるわけです。8倍拡大となると1pixelから16pixelを求めるわけですからここまで鮮明に再現されるのは素晴らしい性能です。
0805x8.png
0805x8_srflow.png

SRFlowで単一モデルから複数の画像を生成

SRFlowでは初期分布次第で様々な画像を生成することができます。これがCVPRのチャレンジ課題でよく使われた理由です。モデル図にも$z$というノイズっぽい画像があります。
スクリーンショット 2021-12-08 1.16.51.png
これはガウス分布にしたがってランダム作成した初期分布になります。超解像の時はこのガウス分布+Low Resolution Encoderの情報をInputとしてFlowが逆方向に進んでいって超解像画像を生成します。つまり毎回画像を生成するたびに違う画像ができるということです。ほんとか試してみます。
0869x8_concat.png
全く同じとしか思えない3枚の猫ができました。かわいいです。

ガウス分布の広がりに制限をかける

実はガウス分布を作るとき(torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)))に標準偏差を0~1で振ってあげると違う画像ができます。論文中ではtemperature, コード上ではheatなんて呼ばれています。1に近づくほど様々な画素を生成しようとします。つまり細部を細かく再現しようとします。0に近づくとPSNRで学習したときのように画像がベタ塗りになっていきます。

左がstd=0、右がstd=1の画像です。確かにハイパーパラメータなどを何も変えずに同じモデルから違う画像を生成することができました。
0869x8_concat1.png

# SRFlowの学習
実際に使う時は何かしらの処理を加えたりfine tuningしたりすると思いますが、SRFlowはとにかく学習にかかるGPUメモリ消費がすごいです。160x160の学習用画像でバッチサイズ16にしてもデフォルトの設定で学習する時は常に15GB弱のGPUメモリを使用します。

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| 52%   54C    P2   309W / 350W |  14499MiB / 24268MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

この学習をやりたかっとこともあり、私の個人オンプレ環境はGTX1080tiからRTX3090に格上げされました。
学習コストが高いこともFLowが流行らない大きな理由の一つだと思います。

まとめ

SRFlowは技術的に非常に面白いですし、精度も素晴らしいですが以下の欠点があります。

  • モデル構築の制約が多い
  • 学習コストが高い(推論コストも同様に高い)
  • なんかむずかしいしめんどくさい実装も多い
  • 別にGANでよくない?

ただ、SRFlowは初期分布で様々な画像を生成できるので不良検出タスクのデータ拡張に使えそうです。

おまけ

最後におじさまを超解像してみました。
ojisama.png ojisama_srflow.png

できませんでした:innocent:
低解像度の元画像が見つからなかったのでスクリーンショットで画像を取得したのですが、これだとうまくいかないみたいです。超解像全ての課題ですが、結局縮小方法が一致していないとうまく画像生成できません。低解像度画像と高解像度画像のペアで学習することの弊害です。

お知らせ

現在ABEJAでは一緒にAIの社会実装を進める仲間を募集しています。
ABEJA Advent Calendar 2021を読んで少しでもいいねとおもったら、まずはお話を聞きに来てください。
【現在募集中の職種】 はこちらから確認できます。ご応募をお待ちしております。

45
20
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
45
20

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?