5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【pytorch入門】torchvision.modelsを使ってfine-tuningで遊んでみた♬

Last updated at Posted at 2021-02-14

今回は、pytorchで事前学習されたモデルを利用してClassificationしてみます。
簡単な話ですが、ちゃんと出来るのかというレベルを目指すと奥は深いと思います。
モデルは以下の参考で公開されています。
使い方のtutorialが見つかりませんが、だいたい参考に情報は掲載されています。
ということで、今回は一番簡単な使い方、Classificationに焦点をあてて、これらの種々のモデルの性能比較をやろうと思います。
【参考】
TORCHVISION.MODELS
###やったこと
・一番簡単な転移学習とFine-Tuningをやってみる
・いろいろなモデルを使ってみる
・入力サイズ依存を測定
###・一番簡単な転移学習とFine-Tuningをやってみる
基本のコードは、以前まとめた以下の参考のおまけに掲載したコードのモデルを変更して、利用します。このnetwork関数を転移学習に利用する学習済のnetwork関数とします。
【参考】
【pytorch-lightning入門】初めてのLit♬
####転移学習のコード
以下のコードは、元々のImagenetで学習したモデルで1000カテゴリまで分類しているので、それをCifar10向けにさらに10カテゴリに分類するために全結合層を新たに追加しています。最初から元のモデルの最終層を10カテゴリ出力に変更するでも同様な効果です。また、この全結合層をもう少しリッチにするなどの工夫もできますが、今回は一番単純な方式としています。

class customize_model(nn.Module):
  def __init__(self):
    super(customize_model, self).__init__()
    
    self.model_ = models.resnet18(pretrained=True)
    for param in self.model_.parameters():
        param.requires_grad = True #False
    self.f_resnet = nn.Sequential(
        nn.Linear(in_features=np.int(1000), out_features=10, bias=True)
    )    
    
  def forward(self, input):

    midlevel_features = self.model_(input)
    output = self.f_resnet(midlevel_features)
    return output

上記では、全てのWeightsやBiasなどを学習するか・しないかのどちらかしか設定できません。
実際には、自分の趣味(学習の状況)に合わせて、層を選んで学習したい。
pytorchでも当然出来るはずで、以下の参考②にありました。
【参考】
【PyTorch】torchvisionから学習済みモデルを使用する際のTips
ということで、いろいろ試した結果以下のようにできました。

    for i, param in enumerate(self.model_.parameters()):
        param.requires_grad = True
        if i >= 20:
            param.requires_grad = False #True #False
        print(i, param.requires_grad)
    
    j =0
    for name,  param in self.model_.named_parameters():
        param.requires_grad = True
        if j >= 10:
            param.requires_grad = False
        print(j, name, param.requires_grad)
        j +=1

###・いろいろなモデルを使ってみる
上記の単独のモデルを拡張して、以下のようなnetwork関数を定義しました。
以下のclassを読みだすことにより、modelをそれぞれ生成している。

class customize_model(nn.Module):
  def __init__(self, input_size=128, sel_model='resnet18'):
    super(customize_model, self).__init__()
    
    ## Select model
    if sel_model == 'resnext50':
        self.model_ = models.resnext50_32x4d(pretrained=True)
    elif sel_model == 'vgg16':
        self.model_ = models.vgg16_bn(pretrained=True) #, num_classes=10) #365) 
    elif sel_model == 'wide50':
        self.model_ = models.wide_resnet50_2(pretrained=True) #, num_classes=10) #365) 
    elif sel_model == 'mobilev2':
        self.model_ = models.mobilenet_v2(pretrained=True)
    elif sel_model == 'densenet121':
        self.model_ = models.densenet121(pretrained=True)
    else:
        self.model_ = models.resnet18(pretrained=True)
    
    j =0
    for name,  param in self.model_.named_parameters():
        param.requires_grad = True
        if j >= 10:
            param.requires_grad = False
        print(j, name, param.requires_grad)
        j +=1
    
    self.f_resnet = nn.Sequential(
        nn.Linear(in_features=np.int(1000), out_features=10, bias=True)
    )
    
  def forward(self, input):

    midlevel_features = self.model_(input)
    output = self.f_resnet(midlevel_features)
    return output

これらの性能は上記のリンクからImagenetに対して以下の通りだそうです。
今回は、これにCifar10の測定結果を追記しています。
測定条件は、128□に拡大し、規格化実施したが、それ以外は読込通りで何ら手を加えていないが、ばらつきがあるもののtop-1 errorが10%以下で比較的よい。
これは、imagenetでの学習パラメータを初期値としてFine-Tuningを実施しているためかもしれないが、いわゆる学習を10カテゴリへのfc層に限った場合は、右の数字になっているので、Fine-Tuning(再学習)が成功していると言える。

モデル名 top-1 error top-10 error training_time/epoch;Fine-Tuning Cifar10 top-1 error/10epoch training_time/epoch;transfer Cifar10 top-1 error weights size
VGG-16 with batch normalization 26.63 8.50 179.1a/124.1b 9.37a/9.57b 185.5 18.4 35.6MBa(138MB)/21.0MBb
ResNet-18 30.24 10.92 120.0 9.54 61.5 20.3 11.7MB
Densenet-121 25.35 7.83 253.1 6.35 112.8 17.8 8.0MB
MobileNet V2 28.12 9.71 124.7 8.04 65.3 21.1 3.5MB
ResNeXt-50-32x4d 22.38 6.30 425.2 7.57 149.5 19.23 25.0MB
Wide ResNet-50-2 21.49 5.91 810.1(batch8)/258.5e/419f 30.9(batch8)/12.4e/11.1f 173.1c/124.0d 21.55/17.48d 68.9MB/29.3MBd/67.3MBe
a;AdaptiveAvgPool2d-45のパラメータ固定
b;AdaptivAvgPool2d-45を削除してFlatten()に置き換え
c;計算中停止時間を引く1931sec-200sec/10epoch
d;j <= 140 False (64,64)
e; j <= 44: False (1.6MB) (64,64) batch32
f; j <= 44: False (1.6MB) (128,128) batch32
####測定条件
Cifar10/128x128画像をbatch=32読み込み時に以下の変換を実施しているが、それ以外のAugumentationは実施していない。
self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((self.width, self.height)),
            transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25))
        ])

###VGG16モデルのAdaptiveAvgPool2d-45のパラメータ固定と削除
上記の計算において、VGG16_bnでは、全結合層への入り口でAdaptiveAvgPool2dを利用して、出力を51277=4096統一しているようだ。
この方法は、他のモデルでも利用されているが、AdaptiveAvgPool2d(output_size=(1, 1)) と出力サイズが(1,1)となっており、Weightsもそれなりのサイズに収まっている。
ところが、VGG16では、AdaptiveAvgPool2d(output_size=(7, 7))とサイズが大きく、特にWeightsのサイズはほぼ102MBと大きいので目立っている。
固定で利用する場合は、どうにか学習できるが、学習しようとするとメモリー不足で計算できない。
そこで、今回は上の計算で利用した、固定と削除の方法を記載しておく。
####固定について
VGG16_bnでは、以下のように52番目のweightsをFalseにすれば固定される。

    j =0
    for name,  param in self.model_.named_parameters():
        param.requires_grad = True
        if j == 52:
            param.requires_grad = False
        print(j, name, param.requires_grad)
        j +=1

一方、削除は以下のようにモデルの定義で以下のように削除する。

    elif sel_model == 'vgg16':
        model_0 = models.vgg16_bn(pretrained=True) 
        self.model_ = nn.Sequential(*list(model_0.children())[0])  

これは、model0.children()を以下のように出力すると分かるが、AdaptiveAvgPool2d(output_size=(7, 7))は、以下のように別のレイヤーとして定義されており、最初の(features): Sequential()の次の(avgpool):となっている。
そこで、上では*list(model_0.children())[0]と最初の塊を取得している。

vgg16 Net(
  (net): customize_model(
    (model_): VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
...
        (42): ReLU(inplace=True)
        (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    (f_resnet): Sequential(
      (0): Linear(in_features=1000, out_features=10, bias=True)
    )
  )

そして、上記の(classifier)の代わりに以下を追加した。

    if sel_model =='vgg16':
        self.f_resnet = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=np.int(512*2*2), out_features=2048, bias=True), 
            nn.ReLU(),
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(in_features=np.int(512*2*2), out_features=1000, bias=True), 
            nn.Linear(in_features=np.int(1000), out_features=10, bias=True)
            )

###入力サイズ依存性
####resnet18、densenet121とmobilev2の入力画像サイズ依存性
測定条件は、上記と同一で、上の表でWeightsサイズが小さく比較的扱いやすい三つのモデルを選んでサイズ依存性をやってみました。
Cifar10の入力サイズを32,64, 96, 128, 192そして224に変更して10epoch Fine-Tuningした結果は以下の表の通りになりました。

resnet18 val_acc Train_time model_size memory_size
32 0.8242 633 11.7MB 137MB
64 0.8665 732
96 0.887 927
128 0.9077 1186
192 0.9182 1957
224 0.9136 2697
densenet121 val_acc Train_time model_size memory_size
32 0.7705 1249 8.0MB 94MB
64 0.8990 1404
96 0.9171 1726
128 0.9362 2522
192 - -
mobilev2 val_acc Train_time model_size memory_size
32 0.5443 543 3.5MB 41MB
64 0.8546 682
96 0.9024 839
128 0.9255 1252
192 0.9222 2454

以下の依存性があるようだ。
accvssize.png
###まとめ
・pytorch掲載のmodelsで遊んでみた
・Imagenetで学習済パラメータの転移学習をcifar10に適用したが、val_accはあまり向上しない
・Fine-tuningではtop-1 errorで10%以下の値を得た
・top-1 errorは、入力画像サイズ依存性があり、ほぼ単調に減少する傾向が見られた
・層毎の利用が出来るようになった
・mobile v2とdensenet121辺りの性能及び学習時間が良好であった

・マシンのメモリー不足で大きなレンジで測定できなかった

###おまけ
一番短いVGG16です。

VGG16_bn.
0 True
1 True
2 True
3 True
4 True
5 True
6 True
7 True
8 True
9 True
10 True
11 True
12 True
13 True
14 True
15 True
16 True
17 True
18 True
19 True
20 False
21 False
22 False
23 False
24 False
25 False
26 False
27 False
28 False
29 False
30 False
31 False
32 False
33 False
34 False
35 False
36 False
37 False
38 False
39 False
40 False
41 False
42 False
43 False
44 False
45 False
46 False
47 False
48 False
49 False
50 False
51 False
52 False
53 False
54 False
55 False
56 False
57 False
0 features.0.weight True
1 features.0.bias True
2 features.1.weight True
3 features.1.bias True
4 features.3.weight True
5 features.3.bias True
6 features.4.weight True
7 features.4.bias True
8 features.7.weight True
9 features.7.bias True
10 features.8.weight False
11 features.8.bias False
12 features.10.weight False
13 features.10.bias False
14 features.11.weight False
15 features.11.bias False
16 features.14.weight False
17 features.14.bias False
18 features.15.weight False
19 features.15.bias False
20 features.17.weight False
21 features.17.bias False
22 features.18.weight False
23 features.18.bias False
24 features.20.weight False
25 features.20.bias False
26 features.21.weight False
27 features.21.bias False
28 features.24.weight False
29 features.24.bias False
30 features.25.weight False
31 features.25.bias False
32 features.27.weight False
33 features.27.bias False
34 features.28.weight False
35 features.28.bias False
36 features.30.weight False
37 features.30.bias False
38 features.31.weight False
39 features.31.bias False
40 features.34.weight False
41 features.34.bias False
42 features.35.weight False
43 features.35.bias False
44 features.37.weight False
45 features.37.bias False
46 features.38.weight False
47 features.38.bias False
48 features.40.weight False
49 features.40.bias False
50 features.41.weight False
51 features.41.bias False
52 classifier.0.weight False
53 classifier.0.bias False
54 classifier.3.weight False
55 classifier.3.bias False
56 classifier.6.weight False
57 classifier.6.bias False
model_customize vgg16=  customize_model(
  (model_): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace=True)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ReLU(inplace=True)
      (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (16): ReLU(inplace=True)
      (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (19): ReLU(inplace=True)
      (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (22): ReLU(inplace=True)
      (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (26): ReLU(inplace=True)
      (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (29): ReLU(inplace=True)
      (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (32): ReLU(inplace=True)
      (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (36): ReLU(inplace=True)
      (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (39): ReLU(inplace=True)
      (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (42): ReLU(inplace=True)
      (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
    (classifier): Sequential(
      (0): Linear(in_features=25088, out_features=4096, bias=True)
      (1): ReLU(inplace=True)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=4096, out_features=4096, bias=True)
      (4): ReLU(inplace=True)
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=4096, out_features=1000, bias=True)
    )
  )
  (f_resnet): Sequential(
    (0): Linear(in_features=1000, out_features=10, bias=True)
  )
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 128, 128]           1,792
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
            Conv2d-4         [-1, 64, 128, 128]          36,928
       BatchNorm2d-5         [-1, 64, 128, 128]             128
              ReLU-6         [-1, 64, 128, 128]               0
         MaxPool2d-7           [-1, 64, 64, 64]               0
            Conv2d-8          [-1, 128, 64, 64]          73,856
       BatchNorm2d-9          [-1, 128, 64, 64]             256
             ReLU-10          [-1, 128, 64, 64]               0
           Conv2d-11          [-1, 128, 64, 64]         147,584
      BatchNorm2d-12          [-1, 128, 64, 64]             256
             ReLU-13          [-1, 128, 64, 64]               0
        MaxPool2d-14          [-1, 128, 32, 32]               0
           Conv2d-15          [-1, 256, 32, 32]         295,168
      BatchNorm2d-16          [-1, 256, 32, 32]             512
             ReLU-17          [-1, 256, 32, 32]               0
           Conv2d-18          [-1, 256, 32, 32]         590,080
      BatchNorm2d-19          [-1, 256, 32, 32]             512
             ReLU-20          [-1, 256, 32, 32]               0
           Conv2d-21          [-1, 256, 32, 32]         590,080
      BatchNorm2d-22          [-1, 256, 32, 32]             512
             ReLU-23          [-1, 256, 32, 32]               0
        MaxPool2d-24          [-1, 256, 16, 16]               0
           Conv2d-25          [-1, 512, 16, 16]       1,180,160
      BatchNorm2d-26          [-1, 512, 16, 16]           1,024
             ReLU-27          [-1, 512, 16, 16]               0
           Conv2d-28          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-29          [-1, 512, 16, 16]           1,024
             ReLU-30          [-1, 512, 16, 16]               0
           Conv2d-31          [-1, 512, 16, 16]       2,359,808
      BatchNorm2d-32          [-1, 512, 16, 16]           1,024
             ReLU-33          [-1, 512, 16, 16]               0
        MaxPool2d-34            [-1, 512, 8, 8]               0
           Conv2d-35            [-1, 512, 8, 8]       2,359,808
      BatchNorm2d-36            [-1, 512, 8, 8]           1,024
             ReLU-37            [-1, 512, 8, 8]               0
           Conv2d-38            [-1, 512, 8, 8]       2,359,808
      BatchNorm2d-39            [-1, 512, 8, 8]           1,024
             ReLU-40            [-1, 512, 8, 8]               0
           Conv2d-41            [-1, 512, 8, 8]       2,359,808
      BatchNorm2d-42            [-1, 512, 8, 8]           1,024
             ReLU-43            [-1, 512, 8, 8]               0
        MaxPool2d-44            [-1, 512, 4, 4]               0
AdaptiveAvgPool2d-45            [-1, 512, 7, 7]               0
           Linear-46                 [-1, 4096]     102,764,544
             ReLU-47                 [-1, 4096]               0
          Dropout-48                 [-1, 4096]               0
           Linear-49                 [-1, 4096]      16,781,312
             ReLU-50                 [-1, 4096]               0
          Dropout-51                 [-1, 4096]               0
           Linear-52                 [-1, 1000]       4,097,000
              VGG-53                 [-1, 1000]               0
           Linear-54                   [-1, 10]          10,010
================================================================
Total params: 138,376,002
Trainable params: 122,842
Non-trainable params: 138,253,160
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 105.46
Params size (MB): 527.86
Estimated Total Size (MB): 633.51
----------------------------------------------------------------

以下が二番目に短いResnet18です。

Resnet18.
0 True
1 True
2 True
3 True
4 True
5 True
6 True
7 True
8 True
9 True
10 True
11 True
12 True
13 True
14 True
15 True
16 True
17 True
18 True
19 True
20 False
21 False
22 False
23 False
24 False
25 False
26 False
27 False
28 False
29 False
30 False
31 False
32 False
33 False
34 False
35 False
36 False
37 False
38 False
39 False
40 False
41 False
42 False
43 False
44 False
45 False
46 False
47 False
48 False
49 False
50 False
51 False
52 False
53 False
54 False
55 False
56 False
57 False
58 False
59 False
60 False
61 False
0 conv1.weight True
1 bn1.weight True
2 bn1.bias True
3 layer1.0.conv1.weight True
4 layer1.0.bn1.weight True
5 layer1.0.bn1.bias True
6 layer1.0.conv2.weight True
7 layer1.0.bn2.weight True
8 layer1.0.bn2.bias True
9 layer1.1.conv1.weight True
10 layer1.1.bn1.weight False
11 layer1.1.bn1.bias False
12 layer1.1.conv2.weight False
13 layer1.1.bn2.weight False
14 layer1.1.bn2.bias False
15 layer2.0.conv1.weight False
16 layer2.0.bn1.weight False
17 layer2.0.bn1.bias False
18 layer2.0.conv2.weight False
19 layer2.0.bn2.weight False
20 layer2.0.bn2.bias False
21 layer2.0.downsample.0.weight False
22 layer2.0.downsample.1.weight False
23 layer2.0.downsample.1.bias False
24 layer2.1.conv1.weight False
25 layer2.1.bn1.weight False
26 layer2.1.bn1.bias False
27 layer2.1.conv2.weight False
28 layer2.1.bn2.weight False
29 layer2.1.bn2.bias False
30 layer3.0.conv1.weight False
31 layer3.0.bn1.weight False
32 layer3.0.bn1.bias False
33 layer3.0.conv2.weight False
34 layer3.0.bn2.weight False
35 layer3.0.bn2.bias False
36 layer3.0.downsample.0.weight False
37 layer3.0.downsample.1.weight False
38 layer3.0.downsample.1.bias False
39 layer3.1.conv1.weight False
40 layer3.1.bn1.weight False
41 layer3.1.bn1.bias False
42 layer3.1.conv2.weight False
43 layer3.1.bn2.weight False
44 layer3.1.bn2.bias False
45 layer4.0.conv1.weight False
46 layer4.0.bn1.weight False
47 layer4.0.bn1.bias False
48 layer4.0.conv2.weight False
49 layer4.0.bn2.weight False
50 layer4.0.bn2.bias False
51 layer4.0.downsample.0.weight False
52 layer4.0.downsample.1.weight False
53 layer4.0.downsample.1.bias False
54 layer4.1.conv1.weight False
55 layer4.1.bn1.weight False
56 layer4.1.bn1.bias False
57 layer4.1.conv2.weight False
58 layer4.1.bn2.weight False
59 layer4.1.bn2.bias False
60 fc.weight False
61 fc.bias False
model_customize resnet18=  customize_model(
  (model_): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer2): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer3): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (layer4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): Linear(in_features=512, out_features=1000, bias=True)
  )
  (f_resnet): Sequential(
    (0): Linear(in_features=1000, out_features=10, bias=True)
  )
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
              ReLU-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
             ReLU-10           [-1, 64, 32, 32]               0
       BasicBlock-11           [-1, 64, 32, 32]               0
           Conv2d-12           [-1, 64, 32, 32]          36,864
      BatchNorm2d-13           [-1, 64, 32, 32]             128
             ReLU-14           [-1, 64, 32, 32]               0
           Conv2d-15           [-1, 64, 32, 32]          36,864
      BatchNorm2d-16           [-1, 64, 32, 32]             128
             ReLU-17           [-1, 64, 32, 32]               0
       BasicBlock-18           [-1, 64, 32, 32]               0
           Conv2d-19          [-1, 128, 16, 16]          73,728
      BatchNorm2d-20          [-1, 128, 16, 16]             256
             ReLU-21          [-1, 128, 16, 16]               0
           Conv2d-22          [-1, 128, 16, 16]         147,456
      BatchNorm2d-23          [-1, 128, 16, 16]             256
           Conv2d-24          [-1, 128, 16, 16]           8,192
      BatchNorm2d-25          [-1, 128, 16, 16]             256
             ReLU-26          [-1, 128, 16, 16]               0
       BasicBlock-27          [-1, 128, 16, 16]               0
           Conv2d-28          [-1, 128, 16, 16]         147,456
      BatchNorm2d-29          [-1, 128, 16, 16]             256
             ReLU-30          [-1, 128, 16, 16]               0
           Conv2d-31          [-1, 128, 16, 16]         147,456
      BatchNorm2d-32          [-1, 128, 16, 16]             256
             ReLU-33          [-1, 128, 16, 16]               0
       BasicBlock-34          [-1, 128, 16, 16]               0
           Conv2d-35            [-1, 256, 8, 8]         294,912
      BatchNorm2d-36            [-1, 256, 8, 8]             512
             ReLU-37            [-1, 256, 8, 8]               0
           Conv2d-38            [-1, 256, 8, 8]         589,824
      BatchNorm2d-39            [-1, 256, 8, 8]             512
           Conv2d-40            [-1, 256, 8, 8]          32,768
      BatchNorm2d-41            [-1, 256, 8, 8]             512
             ReLU-42            [-1, 256, 8, 8]               0
       BasicBlock-43            [-1, 256, 8, 8]               0
           Conv2d-44            [-1, 256, 8, 8]         589,824
      BatchNorm2d-45            [-1, 256, 8, 8]             512
             ReLU-46            [-1, 256, 8, 8]               0
           Conv2d-47            [-1, 256, 8, 8]         589,824
      BatchNorm2d-48            [-1, 256, 8, 8]             512
             ReLU-49            [-1, 256, 8, 8]               0
       BasicBlock-50            [-1, 256, 8, 8]               0
           Conv2d-51            [-1, 512, 4, 4]       1,179,648
      BatchNorm2d-52            [-1, 512, 4, 4]           1,024
             ReLU-53            [-1, 512, 4, 4]               0
           Conv2d-54            [-1, 512, 4, 4]       2,359,296
      BatchNorm2d-55            [-1, 512, 4, 4]           1,024
           Conv2d-56            [-1, 512, 4, 4]         131,072
      BatchNorm2d-57            [-1, 512, 4, 4]           1,024
             ReLU-58            [-1, 512, 4, 4]               0
       BasicBlock-59            [-1, 512, 4, 4]               0
           Conv2d-60            [-1, 512, 4, 4]       2,359,296
      BatchNorm2d-61            [-1, 512, 4, 4]           1,024
             ReLU-62            [-1, 512, 4, 4]               0
           Conv2d-63            [-1, 512, 4, 4]       2,359,296
      BatchNorm2d-64            [-1, 512, 4, 4]           1,024
             ReLU-65            [-1, 512, 4, 4]               0
       BasicBlock-66            [-1, 512, 4, 4]               0
AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
           Linear-68                 [-1, 1000]         513,000
           ResNet-69                 [-1, 1000]               0
           Linear-70                   [-1, 10]          10,010
================================================================
Total params: 11,699,522
Trainable params: 130,394
Non-trainable params: 11,569,128
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 20.52
Params size (MB): 44.63
Estimated Total Size (MB): 65.34
----------------------------------------------------------------
5
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?