6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ONNX.jsでConvTranspose2dを使って画像生成させる

Posted at

はじめに

ONNX.jsを使うとpytorchなどで学習したモデルをJavascriptで動かすことが出来ますが,ここに書いてあるとおり現時点で対応していない処理が結構あります。GANなどで画像生成する際によく使われるConvTranspose2dも対応していません。
しかし通常のConv2Dを使えば等価な処理を(すべてではないですが)行うことができます。
Pytorchでstride 1 or 2に対応したカスタムレイヤーを作りました。

ConvTranspose2d

詳細はA guide to convolution arithmetic for deep learningを見ればよいかと思いますが,Kernel: 4 Stride: 2 Padding: 1のConvTranspose2dは以下の図のようになります。

image.png

青マスが入力になっており,Strideのために空白をいれてPaddingを施したあと,2次元畳み込みを行います。よって空白を入れる処理とPaddingをしてしまえばあとはnn.Conv2Dで計算出来ます。
どうやって空白を入れるかですが,大きい空配列を用意し,ステップ2のSlice[::2]で指定して入力を入れてやる,みたいなことは出来ませんでした。ONNX.jsがsliceに対応していません。なのでConcatを使います。

image.png

図の通り空配列とtorch.catをうまいこと使って調整したあと,stride=1,padding=0のnn.Conv2Dで計算します。
nn.Conv2dを継承して作ったカスタムレイヤーのコードは以下のようになります。

ConvTranspose2D_ONNXJS.py
class ConvTranspose2D_ONNXJS(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, height, width, batch_size):
        super(ConvTranspose2D_ONNXJS, self).__init__(in_channels, out_channels, kernel_size, 1, 0)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride_t = stride
        if not (self.stride_t==1 or self.stride_t==2):
            raise Exception("available stride is only 1 or 2.")

        self.padding_t = padding
        self.h = height
        self.w = width
        self.b = batch_size
        self.outpad = self.kernel_size - 1 - self.padding_t
        if self.outpad < (self.stride_t-1):
            raise Exception("outside padding is too little.")
        

    def forward(self, x):
        if self.stride_t == 2:
            #2 stride by concat
            tmp  = x.view(-1,1)
            tmp  = torch.cat([tmp, torch.zeros((self.b * self.in_channels * self.h * self.w, 1))],1)
            tmp = tmp.view(-1, self.w * 2)
            x = torch.cat([tmp,torch.zeros((self.b * self.in_channels * self.h, self.w * 2))], 1).view(-1, self.in_channels, self.h * 2, self.w * 2)
            # outside padding by concat
            hs, ws, he, we = self.outpad, self.outpad, self.outpad, self.outpad
            x = torch.cat([torch.zeros((self.b, self.in_channels, hs, self.w * 2)), x, torch.zeros((self.b, self.in_channels, he-1, self.w * 2))], 2)
            x = torch.cat([torch.zeros((self.b, self.in_channels, self.h * 2+hs+he-1, ws)), x, torch.zeros((self.b, self.in_channels, self.h*2+hs+he-1, we-1))], 3)
        else:
            # outside padding by concat
            hs, ws, he, we = self.outpad, self.outpad, self.outpad, self.outpad
            x = torch.cat([torch.zeros((self.b, self.in_channels, hs, self.w)), x, torch.zeros((self.b, self.in_channels, he, self.w))], 2)
            x = torch.cat([torch.zeros((self.b, self.in_channels, self.h+hs+he, ws)), x, torch.zeros((self.b, self.in_channels, self.h+hs+he, we))], 3)

        x = super().forward(x)
        return x

ONNX.jsでshapeメソッドが使えないため,あらかじめバッチサイズ,高さ,幅を指定する必要があります。基本的に学習時は普通のnn.ConvTranspose2dを使って行い,レイヤーを差し替えたネットワークを作ったあと重みをコピーしてからonnxファイルにエクスポートするのがいいかと思います。
nn.ConvTranspose2dのweightとnn.Conv2dのweightは配置がひっくり返ってます。そこで以下のコードでnn.Conv2dのweightを変換してからコピーしてください。

convert_weight.py
def convert_weight(nn_weight):
    w = nn_weight[:]
    I, O, H, W = w.shape
    w = w.permute(1,0,2,3)
    w = w[:, :, torch.arange(H-1,-1,-1), :]
    w = w[:, :, :, torch.arange(W-1,-1,-1)]
    return w

GithubにシンプルなDCGANを使った例を上げました。

a2kiti/ConvTranspose2d_for_ONNX

変換の仕方はjupyter notebookを見てください。
作成したonnxファイルを使ったjavascriptでの画像生成のデモも作りました。
リポジトリをcloneしたあとターミナルから
python -m http.server
を打ち込んで
http://0.0.0.0:8000/
にアクセスするとブラウザ上での動作を確認できます。

image.png

おわり

オート般若心経
これを使うと上記のサイトのように画像生成処理をブラウザ上で実行することが出来ます。
とはいえこんなめんどくさいことをしなくても,そのうちONNX.jsが対応してくれることでしょう。

6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?