0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

PyTorch における、Deformable Convolution の使用方法と性能

Posted at

概要

  • PyTorch の deform_conv2d (deform_conv2d ? Torchvision main documentation)に詳細な使用方法が書かれていなかったので、使用方法を確認する。

  • ついでに、通常の畳み込みと Deformable Convolution(変形可能畳み込み)について、CIFAR-10 データセットに対する性能を比較する。

実装

次のリンクに作成した実装をアップロード済みである。

使用方法はリンク先を参照。

GitHub - hrmc2022/Conv_vs_DeformConvGitHub - hrmc2022/Conv_vs_DeformConv

Deformable Convolutionについて

2017年に Deformable Convolution Networks (以降、DeformConv と呼ぶ)、2018年に Deformable Convolution Networks v2 (以降、DeformConv v2 と呼ぶ)が提案された。通常の畳み込みに対して、前者は、カーネルのオフセット(カーネルの変位、つまり、位置と距離)も学習対象に、後者は、入力画素ごとの重み(modulation)も学習対象としている。

詳細は、次を参照。

deform_conv2d の使い方

mask を入力しない場合は DeformConv が適用され、mask を入力する場合は Defrom Conv v2 が適用される。

  • DeformConv 適用の場合

    通常の畳み込みに加えて、オフセット用の畳み込みを行う。

    オフセットの出力チャンネル数は、2×カーネルサイズ×カーネルサイズ。

    import torch
    import torch.nn as nn
    from torchvision.ops import deform_conv2d
    
    in_ch = 3 # 入力チャンネル数
    out_ch = 64 # 出力チャンネル数
    
    regular_conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
    offset_conv = nn.Conv2d(in_ch, 2 * 3 * 3, kernel_size=3, stride=1, padding=1, bias=False)
    input = torch.rand(4, 3, 10, 10) # [バッチサイズ. チャンネル数, 高さ, 幅]
    offset = offset_conv(input)
    output = deform_conv2d(input, offset, weight=regular_conv.weight, bias=regular_conv.bias, padding=1, stride=1)
    
    print(output.shape) # torch.Size([4, 64, 10, 10])
    
  • DeformConv v2 の場合

    DeformConv に対して、modulation も計算して入力する。

    modulation の出力チャンネル数は、カーネルサイズ×カーネルサイズ。

    import torch
    import torch.nn as nn
    from torchvision.ops import deform_conv2d
    
    in_ch = 3 # 入力チャンネル数
    out_ch = 64 # 出力チャンネル数
    
    regular_conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
    offset_conv = nn.Conv2d(in_ch, 2 * 3 * 3, kernel_size=3, stride=1, padding=1, bias=False)
    mod_conv = nn.Conv2d(in_ch, 3 * 3, kernel_size=3, stride=1, padding=1, bias=False)
    input = torch.rand(4, 3, 10, 10) # [バッチサイズ. チャンネル数, 高さ, 幅]
    offset = offset_conv(input)
    modulation = torch.sigmoid(mod_conv(input))
    output = deform_conv2d(input, offset, weight=regular_conv.weight, mask=modulation, bias=regular_conv.bias, padding=1, stride=1)
    
    print(output.shape) # torch.Size([4, 64, 10, 10])
    

性能比較

CIFAR-10 の学習データセットに対して、同じ設定(チャンネル数、学習エポック数、等)で通常の畳み込みと DeformConv と DeformConv v2 でそれぞれ学習後、評価データセットに対して評価をした結果(正解率)は次の通り。

DeformConv v2 が最良となる。

Test Result
Conv2d DeformConv2d DeformConv2d v2
0.5538138747215271 0.6718250513076782 0.6987819671630859

おわりに

今回は、Deformable Convolution の使用方法を確認し、性能が良いことがわかった。発展手法に Self Attention と組み合わせた DAT[2201.00520] Vision Transformer with Deformable Attention 等も出てきているので、Deformable Convolution を積極的に使用していきたい。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?