21
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

2021年のディープラーニング論文を1人で読むAdvent CalendarAdvent Calendar 2021

Day 13

【GAN応用】Deformable Convを使ったフォント生成についての論文紹介【DG-Font】

Last updated at Posted at 2021-12-12

2021年のディープラーニング論文を1人で読むAdvent Calendar13日目の記事です。今回紹介するのはGANの応用で、フォントの合成です。

ひとえにフォントといっても対応すべき文字は大量にあります。中国語なら6万字以上、日本語なら5万字以上、韓国語なら11172文字もあります。さらにフォントごとのスタイルが大量にあるので、いちいち手作業では作ってられない。だからAIで作ってしまおうというのが、この論文の動機です。

実装的には「Deformable Convを使った幾何学変形可能なSkip Connection」「Content ImageとStyle Imageを直交化し、未知のフォントに対する生成も可能にする」というなかなか渋いことをやっています。著者は華東師範大学の方々で、論文はCVPR2021に採択されています。

フォントの文字はいっぱいあって作るの大変。AIで作っちゃおう

図のようなことをしたいのです。左の「Reference calligraphy」は実際に人間が作ったフォント、これを真似るような形で、既存のフォントから右の「Imitation result」を作りたいのです。

13_01.png

実践的な問題としては、Aというフォントは中国語、日本語、韓国語すべてに対応していたとします(Google Font等においてある著名なフォントがこれに当たるでしょう)。一方でBというフォントはスタイルが独特で、対応している文字が少ないとします。Bというフォントを使いたいが、実際に表示すると対応文字が少ないから「・」と欠けて表示されてしまう。これをどうにかしたい、という結構需要のありそうな問題です。

この問題でのニューラルネットワークの入出力は次のようになります(ただし本論文の実験では、フォントごとの対応文字数までは勘案していない点には注意が必要です)。

  • 入力:「フォントBでは対応していないが、フォントAでは対応している文字(Content Image)」+「フォントBで対応している文字(Style Image)」
  • 出力:「フォントBのスタイルでの、Content Imageと同じ文字(Immitation Result)」

これはUnpairedなImage to image translation(例:CycleGAN)や、スタイル変換の問題(例:AdaIN)として捉えられます。しかし、これらのやり方は、ぼやけたりストロークを無視したりとフォント生成には向いていません。なぜなら、これらは一般に、テクスチャや色を転移する傾向にあり1、フォント特有の幾何学的な形状を転移するのが難しいからです。

この問題を解決するために、この論文では**FDSC(Feature Deformation Skip Connection)**というニューラルネットワークのモジュールを導入しています。

関連:Kaggleのベンガル語コンペ

実はこの論文、Kaggleのベンガル語コンペの1位のSolutionと内容が近いのです。以下の記事に詳細にまとめられています。

【Kaggle】2020年に開催された画像分類コンペの1位の解法を紹介します

このコンペの1位のSolutionでは、CycleGANを使ってフォント生成を行っています。この論文ではCycleGANはある意味でベースラインとなっており、より高画質なフォント生成に成功しています。自分はKaggleやってないのでどうこう言うつもりはありませんが、もし似たようなコンペがあったら、この論文の手法は参考になるかもしれません。

タイトルのUnsupervisedについて

この論文のタイトルに「Unsupervised」がついています。なぜ「Unsupervised」を強調したかったのかというと、フォント生成の先行研究では、部首の分解や、書き順など補助的なアノテーションを使っていたためです。そういった補助アノテーションを使わずにできるよ、ということを示すためにUnsupervisedでを強調しています。

Deformable Convolution

この論文では「Deformable Convolution」という別の論文の手法が重要になっています。この理解が重要なので、元論文からいくつか引用してきます。DGFontではDeformable Conv V2を使っていますが、V1とV2の論文両方から図を引用します。

通常の畳み込みとDeformable Convの違い

通常の畳み込みはカーネルの参照範囲が固定だが、Deformable Convはカーネルの参照範囲に学習可能なパラメーターを入れ、データに適応的な参照を可能にしています。図で見るのが一番早いです。

13_02.png

これはV1の論文からですが、左が通常の畳み込み、右がDeformable Convです。同じ3×3カーネルでも通常の畳み込みは、固定範囲を参照し、特徴量がピラミッド形式にマッピングされていきます。Deformable Convの場合は多種多彩な参照範囲になります。

多種多彩な範囲をどう実装しているのかは数式を見るとわかります。3×3のカーネルの場合、畳込みカーネルの対応する画素のインデックスを、

$$\mathcal{R}={(-1, -1), (-1, 0), \cdots, (0, 1), (1, 1)}$$

とします。通常の畳み込みでの位置$\mathbf{p}_0$における出力の特徴マップ$\mathbf{y}$は、

\mathbf{y}(\mathbf{p}_0)=\sum_{\mathbf{p}_n\in\mathcal{R}}\mathbf{w}(\mathbf{p}_n)\cdot\mathbf{x}(\mathbf{p}_0+\mathbf{p}_n)

これをDeformable Convでは、

\mathbf{y}(\mathbf{p}_0)=\sum_{\mathbf{p}_n\in\mathcal{R}}\mathbf{w}(\mathbf{p}_n)\cdot\mathbf{x}(\mathbf{p}_0+\mathbf{p}_n+\Delta\mathbf{p}_n)

という学習可能な${\Delta\mathbf{p}_n|n=1, \cdots, N}$を追加します。これによって適応的な参照が可能となっているわけです。

Deformable Convを入れたいわけ

幾何学的な変形を行いたいからです。通常の畳み込みだと固定範囲のみで、空間方向に動的な変換は難しいです。

13_03.png

このように適応的な参照範囲にして、幾何学的な(空間方向の)変形を実装しようという目論見です。チャンネル数が$2N$となっているのは、畳込みカーネルの値が$N$個、オフセットの値が$N$個あるからです。

DGFontのモデル全図

本論文(DGFont)のモデル全図は次の通りです。

13_04.png

やっていることはContent Imageと同じ文字を、Style Imageのフォントスタイルで生成するということです。ResNetベースのアーキテクチャーですが、随所にDeformable Convが使われています。詳細なアーキテクチャーは次のとおりです。

13_05.png

Style Encoderについて

Style Imageを128次元の特徴量$Z_s$にエンコードしています。これはよくある形なので問題ないでしょう。エンコードされた特徴量は、Decoder(Mixer)部分の随所にAdaINで組み込まれています。AdaINとは、線形変換のついたInstance Normalizationで、StyleGANやスタイル変換でよく使われるConditional Normalizationです。

Content Encoderについて

ConvがDeformable Convに置き換えられたという点を除けば、これもよくある形です。Content側のみでDeformable Convを導入しているのは、Content側の特徴量$Z_C$がスタイルに対して不変になるようにしたいからとのことです。

Feature Deformation Skip Connectionについて

理解が難しいのは本論文のコアな手法である「Feature Deformation Skip Connection(FDSC)」の実装です。公式実装のデコーダー部分を見てみましょう。

class Decoder(nn.Module):
    def __init__(self, nf_dec, sty_dim, n_downs, n_res, res_norm, dec_norm, act, pad, use_sn=False):
        super(Decoder, self).__init__()
        print("Init Decoder")

        nf = nf_dec
        self.model = nn.ModuleList()
        self.model.append(ResBlocks(n_res, nf, res_norm, act, pad, use_sn=use_sn))

        self.model.append(nn.Upsample(scale_factor=2))
        self.model.append(Conv2dBlock(nf, nf//2, 5, 1, 2, norm=dec_norm, act=act, pad_type=pad, use_sn=use_sn))
        nf //= 2

        self.model.append(nn.Upsample(scale_factor=2))
        self.model.append(Conv2dBlock(2*nf, nf//2, 5, 1, 2, norm=dec_norm, act=act, pad_type=pad, use_sn=use_sn))
        nf //= 2

        self.model.append(Conv2dBlock(2*nf, 3, 7, 1, 3, norm='none', act='tanh', pad_type=pad, use_sn=use_sn))
        self.model = nn.Sequential(*self.model)
        self.dcn = modulated_deform_conv.ModulatedDeformConvPack(64, 64, kernel_size=(3, 3), stride=1, padding=1, groups=1, deformable_groups=1, double=True).cuda()
        self.dcn_2 = modulated_deform_conv.ModulatedDeformConvPack(128, 128, kernel_size=(3, 3), stride=1, padding=1, groups=1, deformable_groups=1, double=True).cuda()

    def forward(self, x, skip1, skip2):
        output = x
        for i in range(len(self.model)):
            output = self.model[i](output)

            if i == 2: 
                deformable_concat = torch.cat((output,skip2), dim=1)
                concat_pre, offset2 = self.dcn_2(deformable_concat, skip2)
                output = torch.cat((concat_pre,output), dim=1)

            if i == 4:
                deformable_concat = torch.cat((output,skip1), dim=1)
                concat_pre, offset1 = self.dcn(deformable_concat, skip1)
                output = torch.cat((concat_pre,output), dim=1)
            
        offset_sum1 = torch.mean(torch.abs(offset1))
        offset_sum2 = torch.mean(torch.abs(offset2))
        offset_sum = (offset_sum1+offset_sum2)/2
        return output, offset_sum

このコードはMixerとFDSCがセットで実装されています。FDSCの部分はforwardのi==2i==4の部分です。U-NetのSkip Connectionに近い実装のように見えます。

Concatしたあとself.dcn1, self.dcn_2という2つのModulatedDeformConvPackレイヤーを通しています。このレイヤーがどんな入出力をしているのかが気になります。ModulatedDeformConvPackの実装を部分抜粋すると次のとおりです。

class ModulatedDeformConvPack(ModulatedDeformConv):

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding,
                 dilation=1, groups=1, deformable_groups=1, double=False, im2col_step=64, bias=True, lr_mult=0.1):
        super(ModulatedDeformConvPack, self).__init__(in_channels, out_channels,
                                  kernel_size, stride, padding, dilation, groups, deformable_groups, im2col_step, bias)

        out_channels = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        if double == False:
            self.conv_offset_mask = nn.Conv2d(self.in_channels,
                                            out_channels,
                                            kernel_size=self.kernel_size,
                                            stride=self.stride,
                                            padding=self.padding,
                                            bias=True)
        else:
            # 省略
        self.conv_offset_mask.lr_mult = lr_mult
        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input_offset, input_real):
        out = self.conv_offset_mask(input_offset)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return ModulatedDeformConvFunction.apply(input_real, offset, mask, 
                                                self.weight, 
                                                self.bias, 
                                                self.stride, 
                                                self.padding, 
                                                self.dilation, 
                                                self.groups,
                                                self.deformable_groups,
                                                self.im2col_step), offset

forwardは第1引数input_offsetにオフセット、第2引数input_realに特徴マップが与えられます。input_offsetself.conv_offset_mask層を通じて、カーネル、カーネルオフセット、マスクの3種の値にマッピングされます。

このときself.conv_offset_mask層の入力チャンネル、出力チャンネル数がいくつなのか理解するのがFDSCを理解する早道です。入力チャンネル数は特徴マップのチャンネル数(畳み込み層のチャンネル数)と一致します。出力チャンネル数は縦横のカーネル数の積×3。例えば3×3カーネルなら27になります。

直感的にはself.conv_offset_maskのマッピングは、Squeeze and Excitationに近いかもしれません。SE-Netの場合は出力チャンネルが1ですが、これをDeformable Convに合わせてもう少しチャンネル数を伴ったマッピングをしています。

Decoderの実装に戻ると、FDSCの部分では、

            if i == 4:
                deformable_concat = torch.cat((output,skip1), dim=1)
                concat_pre, offset1 = self.dcn(deformable_concat, skip1)
                output = torch.cat((concat_pre,output), dim=1)

と、skip1が2回出てきます。これはまさにSE-Netチックな実装で、論文の図に戻ると、

13_06.png

とまさに今コードで確認したことが描かれています。「U-NetのSkip Connectionなんだけれども、Deformable Convが入ってSE-Netチックなことをやっているんだな」と理解することで自分は腑に落ちました。論文を読んだだけでは、なかなか理解が追いつきませんでした。

FDSCでDeformable Convを入れる理由

単なるSkip Connectionではなく、Deformable ConvによるSkip Connectionを入れたい理由は、フォントのストロークに対して幾何学的な変形をしたいからです。この図を見れば一目瞭然でしょう。

13_07.png

フォント間の変換は、ストロークの幾何学的な変形でおおよそ示せます。幾何学的な変形にはDeformable Convが向いているというわけです。

Multi-task discriminator

DGFontはGANで訓練するため、Discriminatorがありますが、この出力は単なるReal/Fakeの分類ではなく、フォントの種類ごとのReal/Fakeを分類します。これは訓練データにいくつかのフォントが含まれているため、種類込みで二値分類をします。

DのアーキテクチャーにGlobal Average Poolingがなかったので、PatchGANをしているのかなと思いましたが、その記述もなかったので、単に画像全体の真偽を見ているのだと思われます。

損失関数

DGFontは4つの損失関数からなります。

\mathcal{L}=\mathcal{L}_{adv}+\lambda_{img}\mathcal{L}_{img}+\lambda_{cnt}\mathcal{L}_{cnt}+\lambda_{offset}\mathcal{L}_{offset}

それぞれ、

  • $\mathcal{L}_{adv}$がAdversarial loss
  • $\mathcal{L}_{cnt}$がContent consitent loss
  • $\mathcal{L}_{img}$がImage Reconstruction loss
  • $\mathcal{L}_{offset}$がDeformation offset normalization

を示します。順に見ていきましょう。

Adversarial loss

これは普通のGANの敵対的なロスです。公式実装ではHingeロスを使っていました。

def calc_adv_loss(logit, mode):
    assert mode in ['d_real', 'd_fake', 'g']
    if mode == 'd_real':
        loss = F.relu(1.0 - logit).mean()
    elif mode == 'd_fake':
        loss = F.relu(1.0 + logit).mean()
    else:
        loss = -logit.mean()

    return loss

Content consitent loss

Cycle Consistency Lossと似たような名前ですが、発想はそれに近いです。Style Imageを$I_s$, Content Imageを$I_c$としたときに、Gで合成された画像を$G(I_s, I_c)$とします。この画像をもう一度Content Encoder$f_c$に戻したときの値と、$Z_c$との差がContent consitent lossです。このロスは式のほうがわかりやすいです。

\mathcal{L}_{cnt}=\mathbb{E}_{I_s\in P_s, I_c\in P_c}\|Z_c-f_c(G(I_s, I_c))\|_1

出力画像をContent Encoderに戻して、$Z_c$とのL1ロスを取ればいいです。このロス項はContent Encoderがスタイルに依存しないことを促します

Image Reconstruction loss

これは単なるピクセルベースのL1ロスです。入力画像$I_c$に対してドメイン不変であることを保つためのものです。

\mathcal{L}_{img}=\mathbb{E}_{I_c\in P_c}\|I_c-G(I_s, I_c)\|_1

Deformation offset normalization

これはDeformable Conv特有の正則化です。文字生成の場合大半が白か黒かになってしまうため、オフセットのカーネル$\Delta P$が一意の解を獲得することが難しくなってしまいます。そこで、

$$\mathcal{L}_{offset}=\frac{1}{|\mathcal{R}|}|\Delta p|_1$$

ここで$|R|$は畳み込みカーネルの大きさを示します。やっていることは$\Delta p$のノルムを制約するような正則化をかけているので、より近傍に着目するようになります。フォントのストロークや太さをより効率的に学習することを目的としています。

ハイパラ設定

損失関数のハイパラは、$\lambda_{img}=0.1, \lambda_{cnt}=0.1, \lambda_{offset}=0.5$としています。Adversarial Lossが中心で、Deformation offset normalizationもある程度重要度が高いのがわかります。L1系は添えるだけという感じでしょうか。

この他に$\gamma=10$のR1 Regularizationを使用しています。これはGradient Penaltyの一種です。

データセットについて

410個のフォントを集めてきて、

  • 訓練データ:400個のフォントを使用し、各フォント800文字使う。
  • テストデータ1:400個中190個のフォントを使用する(既知のフォント
  • テストデータ2:訓練データに含まれない10個のフォントを使用する(未知のフォント

各文字は80×80のピクセルとします。

結果

13_08.png

既知のフォントでは、特にFIDで提案手法の有効性が目立ち、未知のフォントに対しては一貫して提案手法が良かったことがわかります。実際の生成結果を見ると、

13_09.png

上が簡単なケース、下が難しいケースです。定量評価ではGAN-imorphが部分的に良くても、実際の出力結果ではとぼやけてたり欠けてたりするため、定量評価はあくまで参考です。FUNITは「性」や「会」のような例で、一部の構造を失っています。FIDがかなり人間の直感に近い指標をしているように見えます。

提案された各手法の有効性

13_10.png

本論文に出てきた各手法を見ています。FIDに着目すると、(a)(b)がよく効いています。(a)はContent Encoderのレイヤーを通常の畳み込みからDeformable Convに変えたこと、(b)はFDSC-1モジュールを(Noramlizationなしで)入れたことです。(c)はNormalizationあり、(d)はFDSC-2モジュールを入れたことなので、提案された各手法が有効であることが示せています。

仮に「FDSCをU-Netのような単純なSkip Connection(SC)に変えるとどうなるのか?」という点ですが、全指標で悪化します。

13_11.png

Deformable Convを入れたSkip Connection(FDSC)はフォントの合成において有効であるということです。

13_12.png

FDSC-1のカーネルサイズを1×1にして$\Delta p$を可視化したものがこの図です。$\Delta p$は文字の周辺のみ有効になっており、背景では0になっています。これは「Deformation offset normalization」の効果によるものです。

まとめと感想

この論文では、フォントの合成というかなり需要のありそうなテーマに対し、実用に耐えうる程度のフォント合成を可能にしています。この論文の面白いところは2点あります。1つはDeformable Convを導入し、CNNが比較的不得意な幾何学変形を可能にしていること。空間方向に移動したいのなら現在ならTransformerを使うことが考えられますが、安易にそっち方面に逃げていないのが良いですね。もう1つは、未知のフォントにも対応できるように、Conten ImageとStyle Imageを直交化できていること。これを支えるロスがContent consitent lossやImage Reconstruction lossなのでしょう。

Deformable Convの論文が出たのは、V1が2017年、V2が2018年で、当時は物体検出での利用を想定していました。しかし、2021年の最先端の論文で温故知新的に引用されているのは、なかなか面白いものがあるなーと思いました。

ただ、Deformable Convを1から実装すると相当面倒くさそうですし、Deformation offset normalizationのようにカーネル値を正則化項として損失計算に入れようとすると、特にKerasで実装が大変になりそうな雰囲気があります。「公式コードでポン!」ならいいでしょうが、自分であれこれいじって1から実装しようとすると思わぬハマりどころがあるかもしれません。

告知

このアドベントカレンダーが本になりました!
https://koshian2.booth.pm/items/3595424
Amazonでも扱いあります詳しくは👉 https://shikoan.com

  1. CNN自体が形状ではなくテクスチャに注目する傾向にある、というのは以前からよく言われていました

21
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
13

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?