
Verifying RepVGG's re-parameterization by computing it in PyTorch

Posted at 2022-10-09

Introduction

I used PyTorch to actually compute the re-parameterization used in RepVGG and check that it holds.
The implementation is based on "Implementing RepVGG in PyTorch".

RepVGG

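At inference time, RepVGG has a VGG-like body made only of 3×3 convs and ReLUs, while at training time it has a multi-branch topology.
[Figure: RepVGG's multi-branch training-time architecture and plain inference-time architecture]
This decoupling of the training-time and inference-time architectures is achieved with a re-parameterization technique.
[Figure: re-parameterization of a RepVGG block]
We will look at each re-parameterization in the following order: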

  1. (three 3x3 convs + adding their outputs) → re-parameterization into one 3x3 conv
  2. (3x3 conv + BN) → re-parameterization into a 3x3 conv
  3. (1x1 conv + BN) → re-parameterization into a 3x3 conv
  4. (identity + BN) → re-parameterization into a 3x3 conv
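As a compact end-to-end check of the four steps above, here is a minimal sketch that builds one training-time block (3×3 conv + BN, 1×1 conv + BN, identity + BN, summed and passed through ReLU), turns every branch into a 3×3 kernel plus bias, sums them, and compares the outputs. The helper name `fuse_kernel_bn`, the channel count, and the random BN statistics are illustrative assumptions. As in the RepVGG block, the convs themselves have no bias, and the identity branch assumes in_channels == out_channels with stride 1.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)


def fuse_kernel_bn(kernel: torch.Tensor, bn: nn.BatchNorm2d):
    """Hypothetical helper: fold an eval-mode BN into a bias-free kernel, W' = (γ/σ)W, b' = β − μγ/σ."""
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()  # γ / σ per output channel
    return kernel * scale.reshape(-1, 1, 1, 1), bn.bias - bn.running_mean * scale


C = 4  # identity branch assumes in_channels == out_channels and stride 1
conv3 = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False)
bn3 = nn.BatchNorm2d(C)
conv1 = nn.Conv2d(C, C, kernel_size=1, padding=0, bias=False)
bn1 = nn.BatchNorm2d(C)
bn_id = nn.BatchNorm2d(C)

with torch.no_grad():
    # give every BN non-trivial running statistics and affine parameters
    for bn in (bn3, bn1, bn_id):
        bn.running_mean.uniform_(-1, 1)
        bn.running_var.uniform_(0.5, 2.0)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
        bn.eval()  # inference mode: use running statistics

    k3 = conv3.weight                        # already a 3x3 kernel
    k1 = F.pad(conv1.weight, [1, 1, 1, 1])   # step 3: 1x1 value at the centre, zeros around it
    k_id = torch.zeros(C, C, 3, 3)           # step 4: identity written as a 3x3 kernel
    k_id[:, :, 1, 1] = torch.eye(C)

    # step 2 for each branch, then step 1: add the kernels and the biases
    w3, b3 = fuse_kernel_bn(k3, bn3)
    w1, b1 = fuse_kernel_bn(k1, bn1)
    wi, bi = fuse_kernel_bn(k_id, bn_id)
    fused = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(w3 + w1 + wi)
    fused.bias.copy_(b3 + b1 + bi)

    x = torch.randn(2, C, 8, 8)
    y_train = F.relu(bn3(conv3(x)) + bn1(conv1(x)) + bn_id(x))  # training-time topology
    y_fused = F.relu(fused(x))                                   # inference-time single conv
    print(torch.allclose(y_train, y_fused, atol=1e-5))
```

The sections below walk through each of these steps separately with hand-picked values.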

1. (three 3x3 convs + adding their outputs) → re-parameterization into one 3x3 conv

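This is the part shown below:
[Figure: the branch outputs of a RepVGG block being added together]
Two branches work exactly like three, so we check (two 3x3 convs + adding their outputs) → re-parameterization into one 3x3 conv.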
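To back up the claim that the number of branches does not matter, here is a minimal sketch with three branches and random weights, compared with `torch.allclose`. It assumes all branches share the same kernel size, stride, and padding, which is what makes their kernels addable.

```python
import torch
from torch import nn

torch.manual_seed(0)

with torch.no_grad():
    x = torch.randn(1, 2, 7, 7)
    # three 3x3 branches with random weights and biases, all with the same padding
    branches = [nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True) for _ in range(3)]

    # convolution is linear in its weights, so kernels and biases can simply be summed
    fused = nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(sum(conv.weight for conv in branches))
    fused.bias.copy_(sum(conv.bias for conv in branches))

    y_multi = sum(conv(x) for conv in branches)  # multi-branch output
    y_fused = fused(x)                           # single-conv output
    print(torch.allclose(y_multi, y_fused, atol=1e-5))
```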

Full code

import torch
from torch import nn, Tensor


with torch.no_grad():
    # input image
    feature_map = torch.ones((1, 2, 7, 7)) # BCHW
    feature_map[:, 1, :, :] = 2
    print('feature_map')
    print(feature_map)

    conv1 = nn.Conv2d(2,2,kernel_size=3, padding=1, bias=True)
    conv1.weight.fill_(1)
    conv1.bias.fill_(10)
    print('conv1.weight')
    print(conv1.weight)
    print('conv1.bias')
    print(conv1.bias)

    conv2 = nn.Conv2d(2,2,kernel_size=3, padding=1, bias=True)
    conv2.weight.fill_(2)
    conv2.bias.fill_(20)
    print('conv2.weight')
    print(conv2.weight)
    print('conv2.bias')
    print(conv2.bias)

    conv_fused = nn.Conv2d(conv1.in_channels, conv1.out_channels, kernel_size=conv1.kernel_size, padding=1)
    conv_fused.weight = nn.Parameter(conv1.weight + conv2.weight)
    conv_fused.bias =  nn.Parameter(conv1.bias + conv2.bias)
    print('conv_fused.weight')
    print(conv_fused.weight)
    print('conv_fused.bias')
    print(conv_fused.bias)

    print('conv1(feature_map)+conv2(feature_map)')
    print(conv1(feature_map)+conv2(feature_map))
    print('conv_fused(feature_map)')
    print(conv_fused(feature_map))

Creating the input feature map

We use C = 2.

with torch.no_grad():
    # input image
    feature_map = torch.ones((1, 2, 7, 7)) # BCHW
    feature_map[:, 1, :, :] = 2
    print('feature_map')
    print(feature_map)

feature_map
tensor([[[[1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.]],

         [[2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.]]]])

Defining the two 3x3 convs

    conv1 = nn.Conv2d(2,2,kernel_size=3, padding=1, bias=True)
    conv1.weight.fill_(1)
    conv1.bias.fill_(10)
    print('conv1.weight')
    print(conv1.weight)
    print('conv1.bias')
    print(conv1.bias)

    conv2 = nn.Conv2d(2,2,kernel_size=3, padding=1, bias=True)
    conv2.weight.fill_(2)
    conv2.bias.fill_(20)
    print('conv2.weight')
    print(conv2.weight)
    print('conv2.bias')
    print(conv2.bias)

conv1.weight
Parameter containing:
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], requires_grad=True)
conv1.bias
Parameter containing:
tensor([10., 10.], requires_grad=True)
conv2.weight
Parameter containing:
tensor([[[[2., 2., 2.],
          [2., 2., 2.],
          [2., 2., 2.]],

         [[2., 2., 2.],
          [2., 2., 2.],
          [2., 2., 2.]]],


        [[[2., 2., 2.],
          [2., 2., 2.],
          [2., 2., 2.]],

         [[2., 2., 2.],
          [2., 2., 2.],
          [2., 2., 2.]]]], requires_grad=True)
conv2.bias
Parameter containing:
tensor([20., 20.], requires_grad=True)

Converting them into a single 3×3 conv

    conv_fused = nn.Conv2d(conv1.in_channels, conv1.out_channels, kernel_size=conv1.kernel_size, padding=1)
    conv_fused.weight = nn.Parameter(conv1.weight + conv2.weight)
    conv_fused.bias =  nn.Parameter(conv1.bias + conv2.bias)
    print('conv_fused.weight')
    print(conv_fused.weight)
    print('conv_fused.bias')
    print(conv_fused.bias)

conv_fused.weight
Parameter containing:
tensor([[[[3., 3., 3.],
          [3., 3., 3.],
          [3., 3., 3.]],

         [[3., 3., 3.],
          [3., 3., 3.],
          [3., 3., 3.]]],


        [[[3., 3., 3.],
          [3., 3., 3.],
          [3., 3., 3.]],

         [[3., 3., 3.],
          [3., 3., 3.],
          [3., 3., 3.]]]], requires_grad=True)
conv_fused.bias
Parameter containing:
tensor([30., 30.], requires_grad=True)

Checking the results

    print('conv1(feature_map)+conv2(feature_map)')
    print(conv1(feature_map)+conv2(feature_map))
    print('conv_fused(feature_map)')
    print(conv_fused(feature_map))

conv1(feature_map)+conv2(feature_map)
tensor([[[[ 66.,  84.,  84.,  84.,  84.,  84.,  66.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 66.,  84.,  84.,  84.,  84.,  84.,  66.]],

         [[ 66.,  84.,  84.,  84.,  84.,  84.,  66.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 66.,  84.,  84.,  84.,  84.,  84.,  66.]]]])
conv_fused(feature_map)
tensor([[[[ 66.,  84.,  84.,  84.,  84.,  84.,  66.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 66.,  84.,  84.,  84.,  84.,  84.,  66.]],

         [[ 66.,  84.,  84.,  84.,  84.,  84.,  66.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 84., 111., 111., 111., 111., 111.,  84.],
          [ 66.,  84.,  84.,  84.,  84.,  84.,  66.]]]])

Both give exactly the same result.
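The printed values can also be checked by hand. With padding=1, an interior output pixel sees all 9 kernel positions of both input channels (values 1 and 2), an edge pixel sees 6, and a corner pixel sees 4; the fused kernel weight is 3 everywhere and the fused bias is 30:

$$
\begin{aligned}
\text{corner:}\quad & 3\,(4 \cdot 1 + 4 \cdot 2) + 30 = 66 \\
\text{edge:}\quad & 3\,(6 \cdot 1 + 6 \cdot 2) + 30 = 84 \\
\text{interior:}\quad & 3\,(9 \cdot 1 + 9 \cdot 2) + 30 = 111
\end{aligned}
$$

Branch by branch, the interior value is (1 · 27 + 10) + (2 · 27 + 20) = 37 + 74 = 111, the same number.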

2. (3x3 conv + BN) → re-parameterization into a 3x3 conv

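Let M be the feature map and μ, σ, γ, β the BN parameters (running mean, standard deviation σ = sqrt(running_var + eps), scale, and shift). In eval mode, BN then computes, for each channel i,
$$\mathrm{bn}(M, \mu, \sigma, \gamma, \beta)_{:,i,:,:} = (M_{:,i,:,:} - \mu_i)\,\frac{\gamma_i}{\sigma_i} + \beta_i$$
The implementation probably makes this easier to follow.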
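Before fusing anything, it may help to confirm that an eval-mode `nn.BatchNorm2d` really computes (M − μ)γ/σ + β per channel, with σ = sqrt(running_var + eps). The per-channel statistics below are arbitrary illustrative values, not taken from the article.

```python
import torch
from torch import nn

torch.manual_seed(0)

with torch.no_grad():
    bn = nn.BatchNorm2d(2)
    bn.running_mean.copy_(torch.tensor([5.0, -3.0]))
    bn.running_var.copy_(torch.tensor([100.0, 4.0]))
    bn.weight.copy_(torch.tensor([5.0, 0.5]))   # γ
    bn.bias.copy_(torch.tensor([20.0, 1.0]))    # β
    bn.eval()  # use running statistics instead of batch statistics

    M = torch.randn(1, 2, 7, 7) * 10 + 50
    # reshape per-channel parameters so they broadcast over (B, C, H, W)
    mu = bn.running_mean.reshape(1, -1, 1, 1)
    gamma = bn.weight.reshape(1, -1, 1, 1)
    beta = bn.bias.reshape(1, -1, 1, 1)
    sigma = (bn.running_var + bn.eps).sqrt().reshape(1, -1, 1, 1)  # σ includes eps
    manual = (M - mu) * gamma / sigma + beta
    print(torch.allclose(bn(M), manual, atol=1e-4))
```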

Full code

import torch
from torch import nn, Tensor


with torch.no_grad():
    feature_map = torch.empty((1, 2, 7, 7))
    feature_map[:, 0, :, :] = 50
    feature_map[:, 1, :, :] = torch.arange(50, 99).reshape(1, 7, 7)
    print('feature_map')
    print(feature_map)

    bn = nn.BatchNorm2d(2)
    bn.running_mean.fill_(5)
    bn.running_var.fill_(100)
    bn.weight.fill_(5)
    bn.bias.fill_(20)
    bn.eval()

    print('bn.running_mean(mean)')
    print(bn.running_mean)
    print('bn.running_var(std^2)')
    print(bn.running_var)
    print('bn.weight(gamma)')
    print(bn.weight)
    print('bn.bias(beta)')
    print(bn.bias)

    # (50-5)*(5/10) + 20 = 45 * (1/2) + 20 = 42.5
    print('bn(feature_map)')
    print(bn(feature_map))

    conv = nn.Conv2d(2, 2,kernel_size=3, padding=1, bias=True)
    print('bn(conv(feature_map))')
    print(bn(conv(feature_map)))
    
    bn_std = (bn.running_var + bn.eps).sqrt()
    conv_weight = nn.Parameter((bn.weight / bn_std).reshape(-1, 1, 1, 1) * conv.weight) # conv.weight has shape (C_out, C_in, H, W)
    conv_bias = nn.Parameter(bn.bias - bn.running_mean * bn.weight / bn_std) # note: conv.bias is not folded in here
    conv_fused = nn.Conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size, padding=1)
    conv_fused.load_state_dict({'weight': conv_weight, 'bias':conv_bias})

    print('conv_fused(feature_map)')
    print(conv_fused(feature_map))

Creating the input feature map

The values are arbitrary.

import torch
from torch import nn, Tensor


with torch.no_grad():
    feature_map = torch.empty((1, 2, 7, 7))
    feature_map[:, 0, :, :] = 50
    feature_map[:, 1, :, :] = torch.arange(50, 99).reshape(1, 7, 7)
    print('feature_map')
    print(feature_map)

feature_map
tensor([[[[50., 50., 50., 50., 50., 50., 50.],
          [50., 50., 50., 50., 50., 50., 50.],
          [50., 50., 50., 50., 50., 50., 50.],
          [50., 50., 50., 50., 50., 50., 50.],
          [50., 50., 50., 50., 50., 50., 50.],
          [50., 50., 50., 50., 50., 50., 50.],
          [50., 50., 50., 50., 50., 50., 50.]],

         [[50., 51., 52., 53., 54., 55., 56.],
          [57., 58., 59., 60., 61., 62., 63.],
          [64., 65., 66., 67., 68., 69., 70.],
          [71., 72., 73., 74., 75., 76., 77.],
          [78., 79., 80., 81., 82., 83., 84.],
          [85., 86., 87., 88., 89., 90., 91.],
          [92., 93., 94., 95., 96., 97., 98.]]]])

Defining the BN

The BN parameters are arbitrary as well.

    bn = nn.BatchNorm2d(2)
    bn.running_mean.fill_(5)
    bn.running_var.fill_(100)
    bn.weight.fill_(5)
    bn.bias.fill_(20)
    bn.eval()

    print('bn.running_mean(mean)')
    print(bn.running_mean)
    print('bn.running_var(std^2)')
    print(bn.running_var)
    print('bn.weight(gamma)')
    print(bn.weight)
    print('bn.bias(beta)')
    print(bn.bias)

bn.running_mean(mean)
tensor([5., 5.])
bn.running_var(std^2)
tensor([100., 100.])
bn.weight(gamma)
Parameter containing:
tensor([5., 5.], requires_grad=True)
bn.bias(beta)
Parameter containing:
tensor([20., 20.], requires_grad=True)

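Defining the 3×3 conv and re-parameterizing (3x3 conv + BN) → 3x3 conv

Substituting the conv output into the BN formula gives a single conv whose kernel W′ and bias b′ are, for a conv kernel W without its own bias,
$$W'_{i,:,:,:} = \frac{\gamma_i}{\sigma_i}\,W_{i,:,:,:}, \qquad b'_i = \beta_i - \frac{\mu_i \gamma_i}{\sigma_i}$$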

    conv = nn.Conv2d(2, 2,kernel_size=3, padding=1, bias=True)
    bn_std = (bn.running_var + bn.eps).sqrt()
    conv_weight = nn.Parameter((bn.weight / bn_std).reshape(-1, 1, 1, 1) * conv.weight) # conv.weight has shape (C_out, C_in, H, W)
    conv_bias = nn.Parameter(bn.bias - bn.running_mean * bn.weight / bn_std) # note: conv.bias is not folded in here
    conv_fused = nn.Conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size, padding=1)
    conv_fused.load_state_dict({'weight': conv_weight, 'bias':conv_bias})

Checking the results

    print('bn(conv(feature_map))')
    print(bn(conv(feature_map)))
    print('conv_fused(feature_map)')
    print(conv_fused(feature_map))

bn(conv(feature_map))
tensor([[[[16.5714, 11.5451, 11.5955, 11.6459, 11.6962, 11.7466,  0.4305],
          [22.4148, 25.0118, 25.1873, 25.3628, 25.5383, 25.7137,  7.3858],
          [23.3898, 26.2402, 26.4156, 26.5911, 26.7666, 26.9421,  7.1664],
          [24.3649, 27.4685, 27.6440, 27.8194, 27.9949, 28.1704,  6.9471],
          [25.3400, 28.6968, 28.8723, 29.0478, 29.2233, 29.3987,  6.7278],
          [26.3150, 29.9252, 30.1006, 30.2761, 30.4516, 30.6271,  6.5084],
          [37.5233, 48.5330, 48.7859, 49.0389, 49.2918, 49.5448, 30.5362]],

         [[13.3822,  3.0390,  2.9962,  2.9534,  2.9106,  2.8678,  7.3584],
          [17.7679,  6.5199,  6.5464,  6.5729,  6.5994,  6.6260,  4.8348],
          [18.6928,  6.7056,  6.7321,  6.7586,  6.7851,  6.8117,  4.6277],
          [19.6176,  6.8913,  6.9178,  6.9443,  6.9709,  6.9974,  4.4205],
          [20.5425,  7.0770,  7.1035,  7.1300,  7.1566,  7.1831,  4.2134],
          [21.4673,  7.2627,  7.2892,  7.3157,  7.3423,  7.3688,  4.0063],
          [35.3246, 30.1079, 30.2486, 30.3892, 30.5299, 30.6705, 15.5543]]]])
conv_fused(feature_map)
tensor([[[[16.6741, 11.6478, 11.6982, 11.7485, 11.7989, 11.8492,  0.5331],
          [22.5174, 25.1145, 25.2900, 25.4654, 25.6409, 25.8164,  7.4884],
          [23.4925, 26.3428, 26.5183, 26.6938, 26.8692, 27.0447,  7.2691],
          [24.4675, 27.5711, 27.7466, 27.9221, 28.0976, 28.2730,  7.0497],
          [25.4426, 28.7995, 28.9749, 29.1504, 29.3259, 29.5014,  6.8304],
          [26.4177, 30.0278, 30.2033, 30.3788, 30.5542, 30.7297,  6.6111],
          [37.6260, 48.6356, 48.8886, 49.1415, 49.3945, 49.6474, 30.6388]],

         [[13.4420,  3.0987,  3.0559,  3.0131,  2.9703,  2.9275,  7.4181],
          [17.8277,  6.5796,  6.6061,  6.6326,  6.6592,  6.6857,  4.8945],
          [18.7525,  6.7653,  6.7918,  6.8184,  6.8449,  6.8714,  4.6874],
          [19.6774,  6.9510,  6.9775,  7.0041,  7.0306,  7.0571,  4.4803],
          [20.6022,  7.1367,  7.1632,  7.1898,  7.2163,  7.2428,  4.2732],
          [21.5270,  7.3224,  7.3489,  7.3755,  7.4020,  7.4285,  4.0661],
          [35.3844, 30.1676, 30.3083, 30.4490, 30.5896, 30.7303, 15.6140]]]])

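The two outputs are close but not identical, and the gap is not floating-point error: it is a constant per channel (about 0.103 in channel 0 and 0.060 in channel 1). That is exactly what happens when the conv's own bias b is left out of the fused bias (the conv above was created with bias=True). Since bn(conv(M)) = (γ/σ)(W ∗ M + b − μ) + β, the difference bn(conv(feature_map)) − conv_fused(feature_map) equals (γ/σ)·b, here about 0.5 × conv.bias per channel. Using b′ = β + (b − μ)γ/σ makes the outputs agree up to float32 rounding; alternatively, create the conv with bias=False, as the conv before BN in a RepVGG block has no bias.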
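A minimal sketch of the corrected fusion that also folds in the conv bias, reusing the BN parameters and feature map from above. The helper name `fuse_conv_bn` is hypothetical (not a PyTorch API), and it assumes groups=1 and dilation=1.

```python
import torch
from torch import nn


@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Hypothetical helper: fold an eval-mode BatchNorm2d into the conv before it."""
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()  # γ / σ per output channel
    b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, kernel_size=conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))  # W' = (γ/σ) W
    fused.bias.copy_(bn.bias + (b - bn.running_mean) * scale)     # b' = β + (b − μ) γ/σ
    return fused


torch.manual_seed(0)
with torch.no_grad():
    # same BN parameters as in the example above
    bn = nn.BatchNorm2d(2)
    bn.running_mean.fill_(5)
    bn.running_var.fill_(100)
    bn.weight.fill_(5)
    bn.bias.fill_(20)
    bn.eval()

    conv = nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True)  # has its own bias b
    feature_map = torch.empty((1, 2, 7, 7))
    feature_map[:, 0, :, :] = 50
    feature_map[:, 1, :, :] = torch.arange(50, 99).reshape(1, 7, 7)

    conv_fused = fuse_conv_bn(conv, bn)
    print(torch.allclose(bn(conv(feature_map)), conv_fused(feature_map), atol=1e-4))
```

When conv.bias is None, b is zero and the bias reduces to β − μγ/σ, the formula used above.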

3. (1x1 conv + BN) → re-parameterization into a 3x3 conv

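Once the 1×1 conv is converted into a 3×3 conv, the rest is the same as in 2.
This is the orange part of the figure below.
[Figure: RepVGG block re-parameterization, with the 1×1 branch in orange and the identity branch in yellow]
The conversion works by copying the 1x1 filter value into the center of the 3x3 kernel and setting the other 8 weights to 0. Because the 3x3 conv uses padding=1 and the 1x1 conv uses padding=0, both produce feature maps of the same size.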
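A minimal sketch of that conversion, using `torch.nn.functional.pad` to surround the 1×1 kernel with zeros. It assumes stride 1, padding=0 for the 1×1 conv, and padding=1 for the 3×3 conv.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

with torch.no_grad():
    conv1x1 = nn.Conv2d(2, 2, kernel_size=1, padding=0, bias=True)

    # pad the (C_out, C_in, 1, 1) kernel by one pixel on each side of H and W:
    # the original value lands at the centre (1, 1), the other 8 positions are 0
    conv3x3 = nn.Conv2d(2, 2, kernel_size=3, padding=1, bias=True)
    conv3x3.weight.copy_(F.pad(conv1x1.weight, [1, 1, 1, 1]))
    conv3x3.bias.copy_(conv1x1.bias)

    x = torch.randn(1, 2, 7, 7)
    print(torch.allclose(conv1x1(x), conv3x3(x), atol=1e-5))
```

The padded kernel can then go through the BN fusion of step 2 unchanged.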

4. (identity + BN) → re-parameterization into a 3x3 conv

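Likewise, once the identity is converted into a 3×3 identity conv, the rest is the same as in 2.
This is the yellow part of the earlier figure.
Below we check the 3×3 identity conv, i.e. a convolution that leaves the feature map unchanged. Such a kernel requires in_channels == out_channels (and stride 1), which is the only case in which RepVGG uses the identity branch.
The identity conv implementation is based on
FrancescoSaverioZuppichini/repvgg-8.py.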

Full code

import torch
from torch import nn, Tensor


with torch.no_grad():
    feature_map = torch.ones((1, 2, 7, 7)) # BCHW
    feature_map[:, 1, :, :] = 2
    print('feature_map')
    print(feature_map)

    identity_conv = nn.Conv2d(2,2,kernel_size=3, padding=1, bias=False)
    identity_conv.weight.zero_()

    in_channels = identity_conv.in_channels
    for i in range(in_channels):
        # 1 at the kernel centre connecting input channel i to output channel i, 0 elsewhere
        identity_conv.weight[i, i % in_channels, 1, 1] = 1

    print('identity_conv.weight')
    print(identity_conv.weight)

    out = identity_conv(feature_map)
    print('identity_conv(feature_map)')
    print(out)

feature_map
tensor([[[[1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.]],

         [[2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.]]]])
identity_conv.weight
Parameter containing:
tensor([[[[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]]]], requires_grad=True)
identity_conv(feature_map)
tensor([[[[1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1.]],

         [[2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.],
          [2., 2., 2., 2., 2., 2., 2.]]]])
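With the identity written as a 3×3 kernel, the BN fusion from step 2 applies unchanged. A minimal sketch, assuming C = 2 channels and random illustrative BN statistics; the identity kernel is built with `torch.eye`, which gives the same weights as the loop above.

```python
import torch
from torch import nn

torch.manual_seed(0)

C = 2  # the identity branch needs in_channels == out_channels
with torch.no_grad():
    bn = nn.BatchNorm2d(C)
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.eval()

    # 3x3 identity kernel: 1 at the centre for channel i -> i, 0 elsewhere
    k_id = torch.zeros(C, C, 3, 3)
    k_id[:, :, 1, 1] = torch.eye(C)

    # fold BN as in step 2: W' = (γ/σ)W, b' = β − μγ/σ
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()
    fused = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(k_id * scale.reshape(-1, 1, 1, 1))
    fused.bias.copy_(bn.bias - bn.running_mean * scale)

    x = torch.randn(1, C, 7, 7)
    print(torch.allclose(bn(x), fused(x), atol=1e-5))
```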
