はじめに
ONNX.jsを使うとpytorchなどで学習したモデルをJavascriptで動かすことが出来ますが,ここに書いてあるとおり現時点で対応していない処理が結構あります。GANなどで画像生成する際によく使われるConvTranspose2dも対応していません。
しかし通常のConv2Dを使えば等価な処理を(すべてではないですが)行うことができます。
Pytorchでstride 1 or 2に対応したカスタムレイヤーを作りました。
ConvTranspose2d
詳細はA guide to convolution arithmetic for deep learningを見ればよいかと思いますが,Kernel: 4 Stride: 2 Padding: 1のConvTranspose2dは以下の図のようになります。
青マスが入力になっており,Strideのために空白をいれてPaddingを施したあと,2次元畳み込みを行います。よって空白を入れる処理とPaddingをしてしまえばあとはnn.Conv2Dで計算出来ます。
どうやって空白を入れるかですが,大きい空配列を用意し,ステップ2のSlice[::2]で指定して入力を入れてやる,みたいなことは出来ませんでした。ONNX.jsがsliceに対応していません。なのでConcatを使います。
図の通り空配列とtorch.catをうまいこと使って調整したあと,stride=1,padding=0のnn.Conv2Dで計算します。
nn.Conv2dを継承して作ったカスタムレイヤーのコードは以下のようになります。
class ConvTranspose2D_ONNXJS(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, height, width, batch_size):
super(ConvTranspose2D_ONNXJS, self).__init__(in_channels, out_channels, kernel_size, 1, 0)
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride_t = stride
if not (self.stride_t==1 or self.stride_t==2):
raise Exception("available stride is only 1 or 2.")
self.padding_t = padding
self.h = height
self.w = width
self.b = batch_size
self.outpad = self.kernel_size - 1 - self.padding_t
if self.outpad < (self.stride_t-1):
raise Exception("outside padding is too little.")
def forward(self, x):
if self.stride_t == 2:
#2 stride by concat
tmp = x.view(-1,1)
tmp = torch.cat([tmp, torch.zeros((self.b * self.in_channels * self.h * self.w, 1))],1)
tmp = tmp.view(-1, self.w * 2)
x = torch.cat([tmp,torch.zeros((self.b * self.in_channels * self.h, self.w * 2))], 1).view(-1, self.in_channels, self.h * 2, self.w * 2)
# outside padding by concat
hs, ws, he, we = self.outpad, self.outpad, self.outpad, self.outpad
x = torch.cat([torch.zeros((self.b, self.in_channels, hs, self.w * 2)), x, torch.zeros((self.b, self.in_channels, he-1, self.w * 2))], 2)
x = torch.cat([torch.zeros((self.b, self.in_channels, self.h * 2+hs+he-1, ws)), x, torch.zeros((self.b, self.in_channels, self.h*2+hs+he-1, we-1))], 3)
else:
# outside padding by concat
hs, ws, he, we = self.outpad, self.outpad, self.outpad, self.outpad
x = torch.cat([torch.zeros((self.b, self.in_channels, hs, self.w)), x, torch.zeros((self.b, self.in_channels, he, self.w))], 2)
x = torch.cat([torch.zeros((self.b, self.in_channels, self.h+hs+he, ws)), x, torch.zeros((self.b, self.in_channels, self.h+hs+he, we))], 3)
x = super().forward(x)
return x
ONNX.jsでshapeメソッドが使えないため,あらかじめバッチサイズ,高さ,幅を指定する必要があります。基本的に学習時は普通のnn.ConvTranspose2dを使って行い,レイヤーを差し替えたネットワークを作ったあと重みをコピーしてからonnxファイルにエクスポートするのがいいかと思います。
nn.ConvTranspose2dのweightとnn.Conv2dのweightは配置がひっくり返ってます。そこで以下のコードでnn.Conv2dのweightを変換してからコピーしてください。
def convert_weight(nn_weight):
w = nn_weight[:]
I, O, H, W = w.shape
w = w.permute(1,0,2,3)
w = w[:, :, torch.arange(H-1,-1,-1), :]
w = w[:, :, :, torch.arange(W-1,-1,-1)]
return w
GithubにシンプルなDCGANを使った例を上げました。
a2kiti/ConvTranspose2d_for_ONNX
変換の仕方はjupyter notebookを見てください。
作成したonnxファイルを使ったjavascriptでの画像生成のデモも作りました。
リポジトリをcloneしたあとターミナルから
python -m http.server
を打ち込んで
http://0.0.0.0:8000/
にアクセスするとブラウザ上での動作を確認できます。
おわり
オート般若心経
これを使うと上記のサイトのように画像生成処理をブラウザ上で実行することが出来ます。
とはいえこんなめんどくさいことをしなくても,そのうちONNX.jsが対応してくれることでしょう。