概要
-
PyTorch の deform_conv2d (deform_conv2d ? Torchvision main documentation)に詳細な使用方法が書かれていなかったので、使用方法を確認する。
-
ついでに、通常の畳み込みと Deformable Convolution(変形可能畳み込み)について、CIFAR-10 データセットに対する性能を比較する。
実装
次のリンクに作成した実装をアップロード済みである。
使用方法はリンク先を参照。
GitHub - hrmc2022/Conv_vs_DeformConvGitHub - hrmc2022/Conv_vs_DeformConv
Deformable Convolutionについて
2017年に Deformable Convolution Networks (以降、DeformConv と呼ぶ)、2018年に Deformable Convolution Networks v2 (以降、DeformConv v2 と呼ぶ)が提案された。通常の畳み込みに対して、前者は、カーネルのオフセット(カーネルの変位、つまり、位置と距離)も学習対象に、後者は、入力画素ごとの重み(modulation)も学習対象としている。
詳細は、次を参照。
-
スケールと形状を学習可能なConvolution: Modulated Deformable Convolution (Deformable ConvNets v2)を実装 - Qiita
-
[1811.11168] Deformable ConvNets v2: More Deformable, Better Results
deform_conv2d の使い方
mask を入力しない場合は DeformConv が適用され、mask を入力する場合は Defrom Conv v2 が適用される。
-
DeformConv 適用の場合
通常の畳み込みに加えて、オフセット用の畳み込みを行う。
オフセットの出力チャンネル数は、2×カーネルサイズ×カーネルサイズ。
import torch import torch.nn as nn from torchvision.ops import deform_conv2d in_ch = 3 # 入力チャンネル数 out_ch = 64 # 出力チャンネル数 regular_conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False) offset_conv = nn.Conv2d(in_ch, 2 * 3 * 3, kernel_size=3, stride=1, padding=1, bias=False) input = torch.rand(4, 3, 10, 10) # [バッチサイズ. チャンネル数, 高さ, 幅] offset = offset_conv(input) output = deform_conv2d(input, offset, weight=regular_conv.weight, bias=regular_conv.bias, padding=1, stride=1) print(output.shape) # torch.Size([4, 64, 10, 10])
-
DeformConv v2 の場合
DeformConv に対して、modulation も計算して入力する。
modulation の出力チャンネル数は、カーネルサイズ×カーネルサイズ。
import torch import torch.nn as nn from torchvision.ops import deform_conv2d in_ch = 3 # 入力チャンネル数 out_ch = 64 # 出力チャンネル数 regular_conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False) offset_conv = nn.Conv2d(in_ch, 2 * 3 * 3, kernel_size=3, stride=1, padding=1, bias=False) mod_conv = nn.Conv2d(in_ch, 3 * 3, kernel_size=3, stride=1, padding=1, bias=False) input = torch.rand(4, 3, 10, 10) # [バッチサイズ. チャンネル数, 高さ, 幅] offset = offset_conv(input) modulation = torch.sigmoid(mod_conv(input)) output = deform_conv2d(input, offset, weight=regular_conv.weight, mask=modulation, bias=regular_conv.bias, padding=1, stride=1) print(output.shape) # torch.Size([4, 64, 10, 10])
性能比較
CIFAR-10 の学習データセットに対して、同じ設定(チャンネル数、学習エポック数、等)で通常の畳み込みと DeformConv と DeformConv v2 でそれぞれ学習後、評価データセットに対して評価をした結果(正解率)は次の通り。
DeformConv v2 が最良となる。
Conv2d | DeformConv2d | DeformConv2d v2 |
---|---|---|
0.5538138747215271 | 0.6718250513076782 | 0.6987819671630859 |
おわりに
今回は、Deformable Convolution の使用方法を確認し、性能が良いことがわかった。発展手法に Self Attention と組み合わせた DAT[2201.00520] Vision Transformer with Deformable Attention 等も出てきているので、Deformable Convolution を積極的に使用していきたい。