
Building an ABN by cutting and pasting a pretrained ResNet

Posted at 2021-08-31

By cutting up a pretrained model and pasting the pieces back together, you can build a new model with less effort.
Here we take a ResNet apart and assemble a ResNet50-based ABN (attention branch network) from the pieces.

The gist is here

Cut it apart, reconnect it, and restore it

First, cut the ResNet apart, reconnect it, and check that the result matches the original model (just to be safe).

Information about the original ResNet

First, the original ResNet. The dataset is CIFAR-100, whose small image size makes things easy to inspect.
Any model would do, but here we grab a pretrained CIFAR-100 model from
https://github.com/chenyaofo/pytorch-cifar-models
(note that the code below loads cifar100_resnet56, even though the title and class names say ResNet50).

Fetching the pretrained model
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_org = torch.hub.load("chenyaofo/pytorch-cifar-models", 
                           "cifar100_resnet56", 
                           pretrained=True)
model_org = model_org.to(device)
Check the summary
import torchinfo

torchinfo.summary(
    model_org,
    (64, 3, 32, 32),
    depth=2,
    col_names=["input_size",
               "output_size",
               "kernel_size"],
    row_settings=("var_names",)
    )
Summary of the original model
===================================================================================================================
Layer (type (var_name))                  Input Shape               Output Shape              Kernel Shape
===================================================================================================================
CifarResNet                              --                        --                        --
├─Conv2d (conv1)                         [64, 3, 32, 32]           [64, 16, 32, 32]          [3, 16, 3, 3]
├─BatchNorm2d (bn1)                      [64, 16, 32, 32]          [64, 16, 32, 32]          [16]
├─ReLU (relu)                            [64, 16, 32, 32]          [64, 16, 32, 32]          --
├─Sequential (layer1)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (0)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (1)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (2)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (3)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (4)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (5)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (6)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (7)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─BasicBlock (8)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
├─Sequential (layer2)                    [64, 16, 32, 32]          [64, 32, 16, 16]          --
│    └─BasicBlock (0)                    [64, 16, 32, 32]          [64, 32, 16, 16]          --
│    └─BasicBlock (1)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    └─BasicBlock (2)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    └─BasicBlock (3)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    └─BasicBlock (4)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    └─BasicBlock (5)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    └─BasicBlock (6)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    └─BasicBlock (7)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    └─BasicBlock (8)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
├─Sequential (layer3)                    [64, 32, 16, 16]          [64, 64, 8, 8]            --
│    └─BasicBlock (0)                    [64, 32, 16, 16]          [64, 64, 8, 8]            --
│    └─BasicBlock (1)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    └─BasicBlock (2)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    └─BasicBlock (3)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    └─BasicBlock (4)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    └─BasicBlock (5)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    └─BasicBlock (6)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    └─BasicBlock (7)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    └─BasicBlock (8)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
├─AdaptiveAvgPool2d (avgpool)            [64, 64, 8, 8]            [64, 64, 1, 1]            --
├─Linear (fc)                            [64, 64]                  [64, 100]                 [64, 100]
===================================================================================================================
Total params: 861,620
Trainable params: 861,620
Non-trainable params: 0
Total mult-adds (G): 8.05
===================================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 557.89
Params size (MB): 3.45
Estimated Total Size (MB): 562.13
===================================================================================================================

Check that cutting and reconnecting restores the original

Next, split the pretrained ResNet into a bottom half and a top half, reconnect them into a reconstructed model, and check that it produces the same output as the original.

Just cut and reconnect; here we cut right before the final pooling layer
import torch
import torch.nn as nn


class ReconstructResNet50(nn.Module):
    def __init__(self):
        super().__init__()
        model = torch.hub.load("chenyaofo/pytorch-cifar-models", 
                               "cifar100_resnet56", 
                               pretrained=True)

        self.resnet50_bottom_half = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.layer1,
            model.layer2,
            model.layer3,
        )

        self.resnet50_top_half = nn.Sequential(
            model.avgpool,
            nn.Flatten(),
            model.fc
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.resnet50_bottom_half(x)
        x = self.resnet50_top_half(x)
        return x
Reconnect and check the summary
model_new = ReconstructResNet50()
model_new = model_new.to(device)

torchinfo.summary(
    model_new,
    (64, 3, 32, 32),
    depth=3,
    col_names=["input_size",
               "output_size",
               "kernel_size"],
    row_settings=("var_names",)
    )

The contents of the reconstructed model. The tree is one level deeper, but the parameter count and everything else are identical.

========================================================================================================================
Layer (type (var_name))                       Input Shape               Output Shape              Kernel Shape
========================================================================================================================
ReconstructResNet50                           --                        --                        --
├─Sequential (resnet50_bottom_half)           [64, 3, 32, 32]           [64, 64, 8, 8]            --
│    └─Conv2d (0)                             [64, 3, 32, 32]           [64, 16, 32, 32]          [3, 16, 3, 3]
│    └─BatchNorm2d (1)                        [64, 16, 32, 32]          [64, 16, 32, 32]          [16]
│    └─ReLU (2)                               [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─Sequential (3)                         [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (0)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (1)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (2)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (3)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (4)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (5)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (6)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (7)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    │    └─BasicBlock (8)                    [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─Sequential (4)                         [64, 16, 32, 32]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (0)                    [64, 16, 32, 32]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (1)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (2)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (3)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (4)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (5)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (6)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (7)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    │    └─BasicBlock (8)                    [64, 32, 16, 16]          [64, 32, 16, 16]          --
│    └─Sequential (5)                         [64, 32, 16, 16]          [64, 64, 8, 8]            --
│    │    └─BasicBlock (0)                    [64, 32, 16, 16]          [64, 64, 8, 8]            --
│    │    └─BasicBlock (1)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    │    └─BasicBlock (2)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    │    └─BasicBlock (3)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    │    └─BasicBlock (4)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    │    └─BasicBlock (5)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    │    └─BasicBlock (6)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    │    └─BasicBlock (7)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    │    └─BasicBlock (8)                    [64, 64, 8, 8]            [64, 64, 8, 8]            --
├─Sequential (resnet50_top_half)              [64, 64, 8, 8]            [64, 100]                 --
│    └─AdaptiveAvgPool2d (0)                  [64, 64, 8, 8]            [64, 64, 1, 1]            --
│    └─Flatten (1)                            [64, 64, 1, 1]            [64, 64]                  --
│    └─Linear (2)                             [64, 64]                  [64, 100]                 [64, 100]
========================================================================================================================
Total params: 861,620
Trainable params: 861,620
Non-trainable params: 0
Total mult-adds (G): 8.05
========================================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 557.89
Params size (MB): 3.45
Estimated Total Size (MB): 562.13
========================================================================================================================

Feed in data and check

Now feed some data through.

Input data
data = torch.randn(64, 3, 32, 32).to(device)

Do the outputs match?

Output of the original model
model_org.eval()
print(model_org(data).max(axis=1))
print(model_org(data)[:10, :5])

torch.return_types.max(
values=tensor([ 6.6042,  9.0203,  7.3368,  7.8674,  6.7200,  7.2444, 10.0077, 10.5660,
        10.3509,  9.1652,  8.1699,  8.6301,  6.8983,  7.2104,  7.0237,  9.4096,
         7.9846,  8.4669,  8.8615,  7.1618,  6.6025,  7.1010,  7.3803,  9.5781,
         8.5849,  9.0294,  6.8827, 10.2920,  8.5496,  6.8068,  7.8123,  7.7742,
         9.2964,  8.1514,  6.8203,  7.3823, 10.8784,  7.1684,  7.2372,  8.7422,
         9.0883,  8.0774,  9.1850,  7.1326,  6.9103,  8.9576,  9.5753,  8.6561,
         8.2095,  8.5133,  7.8957,  7.3277,  7.9779,  7.0222,  6.9912,  9.9922,
         6.4064,  7.1179,  6.9993,  8.4908,  6.4606, 10.3621,  9.0333,  6.0951],
       device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([6, 2, 6, 2, 2, 6, 2, 6, 2, 2, 6, 2, 6, 6, 2, 2, 6, 2, 6, 2, 6, 6, 6, 6,
        6, 2, 6, 2, 2, 2, 6, 2, 2, 6, 6, 6, 6, 2, 2, 6, 2, 2, 2, 2, 6, 2, 6, 6,
        2, 2, 2, 2, 6, 6, 2, 6, 2, 2, 2, 2, 6, 6, 6, 6], device='cuda:0'))
tensor([[-2.8958, -0.3914,  5.9768, -0.3876, -0.5456],
        [-2.7208,  0.2338,  9.0203, -0.2278, -0.7832],
        [-3.1234,  0.2259,  6.8180,  0.0619, -1.2472],
        [-3.2154, -0.2761,  7.8674,  0.3466, -1.0217],
        [-2.6827, -0.5690,  6.7200,  0.1551, -0.4000],
        [-3.3118,  0.1258,  6.2590,  0.0972, -1.1040],
        [-2.4617, -0.3470, 10.0077, -0.2054, -0.0124],
        [-3.2467, -0.6021,  4.2468, -1.4212, -1.6072],
        [-2.3458, -0.4344, 10.3509,  0.7181, -0.1760],
        [-2.3557, -0.4685,  9.1652, -0.0789, -0.4057]], device='cuda:0',
       grad_fn=<SliceBackward>)
Output of the reconstructed model
model_new.eval()
print(model_new(data).max(axis=1))
print(model_new(data)[:10, :5])


torch.return_types.max(
values=tensor([ 6.6042,  9.0203,  7.3368,  7.8674,  6.7200,  7.2444, 10.0077, 10.5660,
        10.3509,  9.1652,  8.1699,  8.6301,  6.8983,  7.2104,  7.0237,  9.4096,
         7.9846,  8.4669,  8.8615,  7.1618,  6.6025,  7.1010,  7.3803,  9.5781,
         8.5849,  9.0294,  6.8827, 10.2920,  8.5496,  6.8068,  7.8123,  7.7742,
         9.2964,  8.1514,  6.8203,  7.3823, 10.8784,  7.1684,  7.2372,  8.7422,
         9.0883,  8.0774,  9.1850,  7.1326,  6.9103,  8.9576,  9.5753,  8.6561,
         8.2095,  8.5133,  7.8957,  7.3277,  7.9779,  7.0222,  6.9912,  9.9922,
         6.4064,  7.1179,  6.9993,  8.4908,  6.4606, 10.3621,  9.0333,  6.0951],
       device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([6, 2, 6, 2, 2, 6, 2, 6, 2, 2, 6, 2, 6, 6, 2, 2, 6, 2, 6, 2, 6, 6, 6, 6,
        6, 2, 6, 2, 2, 2, 6, 2, 2, 6, 6, 6, 6, 2, 2, 6, 2, 2, 2, 2, 6, 2, 6, 6,
        2, 2, 2, 2, 6, 6, 2, 6, 2, 2, 2, 2, 6, 6, 6, 6], device='cuda:0'))
tensor([[-2.8958, -0.3914,  5.9768, -0.3876, -0.5456],
        [-2.7208,  0.2338,  9.0203, -0.2278, -0.7832],
        [-3.1234,  0.2259,  6.8180,  0.0619, -1.2472],
        [-3.2154, -0.2761,  7.8674,  0.3466, -1.0217],
        [-2.6827, -0.5690,  6.7200,  0.1551, -0.4000],
        [-3.3118,  0.1258,  6.2590,  0.0972, -1.1040],
        [-2.4617, -0.3470, 10.0077, -0.2054, -0.0124],
        [-3.2467, -0.6021,  4.2468, -1.4212, -1.6072],
        [-2.3458, -0.4344, 10.3509,  0.7181, -0.1760],
        [-2.3557, -0.4685,  9.1652, -0.0789, -0.4057]], device='cuda:0',
       grad_fn=<SliceBackward>)

Both models produce identical outputs.
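The eyeball comparison above can also be done programmatically with torch.allclose. A minimal sketch using a tiny stand-in network (not the pretrained ResNet), showing that a model split into halves that reference the same modules reproduces the original output exactly:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in network (hypothetical; the article uses cifar100_resnet56)
base = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# The halves reference the same module objects, so weights are shared
bottom = nn.Sequential(base[0], base[1])
top = nn.Sequential(base[2], base[3], base[4])

x = torch.randn(4, 3, 32, 32)
base.eval()
print(torch.allclose(base(x), top(bottom(x))))  # True
```

The same one-liner check works on the real models, since bottom-then-top runs exactly the same modules in the same order.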

Does the loss match?

Now set up optimizers and check whether the losses match as well.

Preparing the optimizers and target labels
criterion = nn.CrossEntropyLoss()
optimizer_org = torch.optim.SGD(model_org.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
optimizer_new = torch.optim.SGD(model_new.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

num_classes = 100
target = torch.randint(num_classes, (64,)).to(device)
Original model
model_org.train()
optimizer_org.zero_grad()
output_org = model_org(data)
loss_org = criterion(output_org, target)
loss_org.backward()
loss_org

tensor(4.3274, device='cuda:0', grad_fn=<NllLossBackward>)
Reconstructed model
model_new.train()
optimizer_new.zero_grad()
output_new = model_new(data)
loss_new = criterion(output_new, target)
loss_new.backward()
loss_new

tensor(4.3274, device='cuda:0', grad_fn=<NllLossBackward>)

The losses match.

Do the gradients match too?

Now check whether the weights and their gradients match, restricting the check to the first conv layer.

Notebook input and output
model_new.resnet50_bottom_half[0].weight[0, 0, :, :]

tensor([[-0.0568, -0.0996,  0.0227],
        [-0.1853, -0.1841,  0.1732],
        [ 0.0093, -0.0274,  0.0995]], device='cuda:0', grad_fn=<SliceBackward>)

model_org.conv1.weight[0, 0, :, :]

tensor([[-0.0568, -0.0996,  0.0227],
        [-0.1853, -0.1841,  0.1732],
        [ 0.0093, -0.0274,  0.0995]], device='cuda:0', grad_fn=<SliceBackward>)

model_new.resnet50_bottom_half[0].weight.grad[0, 0, :, :]

tensor([[ 0.4496, -0.1967, -0.1387],
        [ 0.4462,  0.0295,  0.6632],
        [-0.4074, -0.0283,  0.2142]], device='cuda:0')

model_org.conv1.weight.grad[0, 0, :, :]

tensor([[ 0.4496, -0.1967, -0.1387],
        [ 0.4462,  0.0295,  0.6632],
        [-0.4074, -0.0283,  0.2142]], device='cuda:0')

This confirms that the weights and gradients match.
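Checking only the first conv is fine for a notebook, but the same comparison extends to every parameter. A small self-contained sketch of the idea with toy linear models (not the ResNets above): identical weights plus identical inputs give identical gradients, parameter by parameter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
a = nn.Linear(3, 2)
b = nn.Linear(3, 2)
b.load_state_dict(a.state_dict())  # give b the exact same weights

x = torch.randn(5, 3)
t = torch.randn(5, 2)
criterion = nn.MSELoss()
criterion(a(x), t).backward()
criterion(b(x), t).backward()

# Compare every weight and every gradient, not just the first layer
for pa, pb in zip(a.parameters(), b.parameters()):
    assert torch.equal(pa, pb)
    assert torch.allclose(pa.grad, pb.grad)
print("all parameters and gradients match")
```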

Cut and paste the ResNet to build an ABN

Now cut up the pretrained ResNet in various ways and assemble an ABN.

ResNet50-based ABN

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ABNResNet50(nn.Module):
    def __init__(self, num_classes=100, pretrained=True):
        super().__init__()
        model = torch.hub.load("chenyaofo/pytorch-cifar-models", 
                               "cifar100_resnet56", 
                               pretrained=pretrained)

        self.resnet50_bottom = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.layer1,
            model.layer2,
            model.layer3,
        )
        r50b_out_features = \
            list(self.resnet50_bottom.modules())[-1].num_features

        self.resnet50_top = nn.Sequential(
            model.avgpool,
            nn.Flatten(),
            model.fc
        )

        self.attention_branch1 = nn.Sequential(
            # reuse layer3[1:8] here, since its input and output sizes match
            # note: without deepcopy, the weights would be shared with the copy used above
            copy.deepcopy(model.layer3[1:8]),

            nn.BatchNorm2d(r50b_out_features),
            nn.Conv2d(r50b_out_features, num_classes, kernel_size=1),
            nn.ReLU(inplace=True)
        )
        self.attention_branch2 = nn.Sequential(
            nn.Conv2d(num_classes, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.attention_branch3 = nn.Sequential(
            nn.Conv2d(num_classes, num_classes, kernel_size=1),
        )

    def get_attn(self):
        return self.attn

    def forward(self, x):
        x = self.resnet50_bottom(x)

        ax = self.attention_branch1(x)
        attn = self.attention_branch2(ax)

        x = x * attn
        x = self.resnet50_top(x)

        ax = self.attention_branch3(ax)
        ax = F.avg_pool2d(ax, kernel_size=ax.shape[2]).squeeze()  # GAP (squeeze would also drop the batch dim for batch size 1)

        self.attn = attn
        return x, ax, attn
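The deepcopy in attention_branch1 matters: slicing model.layer3[1:8] returns the same module objects, so without copy.deepcopy the attention branch would share weights with resnet50_bottom. A minimal illustration of the difference:

```python
import copy

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
shared = layer                      # same object: parameters are shared
independent = copy.deepcopy(layer)  # fresh, independent parameters

with torch.no_grad():
    layer.weight.add_(1.0)          # mutate the original in place

print(torch.equal(shared.weight, layer.weight))       # True: change is visible
print(torch.equal(independent.weight, layer.weight))  # False: the copy kept its own weights
```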

Check the ABN's parameters (model here is the ABNResNet50 instance created in the training prep below).

Summary of the ABN
torchinfo.summary(
    model,
    (64, 3, 32, 32),
    depth=2,  # depth=3 makes the table long, so stop at 2
    col_names=["input_size",
               "output_size",
               "kernel_size"],
    row_settings=("var_names",)
    )
========================================================================================================================
Layer (type (var_name))                       Input Shape               Output Shape              Kernel Shape
========================================================================================================================
ABNResNet50                                   --                        --                        --
├─Sequential (resnet50_bottom)                [64, 3, 32, 32]           [64, 64, 8, 8]            --
│    └─Conv2d (0)                             [64, 3, 32, 32]           [64, 16, 32, 32]          [3, 16, 3, 3]
│    └─BatchNorm2d (1)                        [64, 16, 32, 32]          [64, 16, 32, 32]          [16]
│    └─ReLU (2)                               [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─Sequential (3)                         [64, 16, 32, 32]          [64, 16, 32, 32]          --
│    └─Sequential (4)                         [64, 16, 32, 32]          [64, 32, 16, 16]          --
│    └─Sequential (5)                         [64, 32, 16, 16]          [64, 64, 8, 8]            --
├─Sequential (attention_branch1)              [64, 64, 8, 8]            [64, 100, 8, 8]           --
│    └─Sequential (0)                         [64, 64, 8, 8]            [64, 64, 8, 8]            --
│    └─BatchNorm2d (1)                        [64, 64, 8, 8]            [64, 64, 8, 8]            [64]
│    └─Conv2d (2)                             [64, 64, 8, 8]            [64, 100, 8, 8]           [64, 100, 1, 1]
│    └─ReLU (3)                               [64, 100, 8, 8]           [64, 100, 8, 8]           --
├─Sequential (attention_branch2)              [64, 100, 8, 8]           [64, 1, 8, 8]             --
│    └─Conv2d (0)                             [64, 100, 8, 8]           [64, 1, 8, 8]             [100, 1, 1, 1]
│    └─BatchNorm2d (1)                        [64, 1, 8, 8]             [64, 1, 8, 8]             [1]
│    └─Sigmoid (2)                            [64, 1, 8, 8]             [64, 1, 8, 8]             --
├─Sequential (resnet50_top)                   [64, 64, 8, 8]            [64, 100]                 --
│    └─AdaptiveAvgPool2d (0)                  [64, 64, 8, 8]            [64, 64, 1, 1]            --
│    └─Flatten (1)                            [64, 64, 1, 1]            [64, 64]                  --
│    └─Linear (2)                             [64, 64]                  [64, 100]                 [64, 100]
├─Sequential (attention_branch3)              [64, 100, 8, 8]           [64, 100, 8, 8]           --
│    └─Conv2d (0)                             [64, 100, 8, 8]           [64, 100, 8, 8]           [100, 100, 1, 1]
========================================================================================================================
Total params: 1,396,339
Trainable params: 1,396,339
Non-trainable params: 0
Total mult-adds (G): 10.23
========================================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 625.33
Params size (MB): 5.59
Estimated Total Size (MB): 631.70
========================================================================================================================

Define some utility functions.

utilities
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L363-L380
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        if isinstance(val, torch.Tensor):
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L411
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
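The heavy lifting in accuracy() is Tensor.topk followed by eq against the broadcast targets. A tiny standalone trace of those two steps on hand-picked logits (values chosen here for illustration, with ties avoided since topk's tie order is not guaranteed):

```python
import torch

output = torch.tensor([[0.10, 0.90, 0.00],
                       [0.80, 0.15, 0.05],
                       [0.20, 0.30, 0.50]])
target = torch.tensor([1, 0, 1])

_, pred = output.topk(2, 1, True, True)  # top-2 class indices per sample
pred = pred.t()                          # shape (k, batch)
correct = pred.eq(target.view(1, -1).expand_as(pred))

top1 = correct[:1].reshape(-1).float().sum() * 100.0 / 3
top2 = correct[:2].reshape(-1).float().sum() * 100.0 / 3
print(top1.item(), top2.item())  # top-1: 2/3 correct, top-2: 3/3 correct
```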

Training the ABN

Now train it.

Setup
model = ABNResNet50(pretrained=True)
model.to(device)
model.train()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

criterion = nn.CrossEntropyLoss()

epoch_num = 10
Training and evaluation loop
def one_epoch_process(data_loader, is_train=True):
    """Train or evaluate the global model for one epoch.

    Args:
        data_loader (DataLoader): data loader of the train or val set
        is_train (bool, optional): train if True, evaluate if False. Defaults to True.
    """
    with tqdm(enumerate(data_loader),
                total=len(data_loader),
                leave=True) as pbar_loss:

        log_loss = AverageMeter()
        log_top1 = AverageMeter()
        log_top5 = AverageMeter()

        for batch_idx, (image, label) in pbar_loss:
            pbar_loss.set_description("[{}]".format('train' if is_train else 'val'))

            current_batch_size = image.size(0)

            image, label = image.to(device), label.to(device)

            y, ay, _ = model(image)

            loss = criterion(y, label)
            loss_ay = criterion(ay, label)
            loss_all = (loss + loss_ay) / 2
            log_loss.update(loss_all, current_batch_size)

            if is_train:
                optimizer.zero_grad()
                loss_all.backward()
                optimizer.step()
            
            acc1, acc5 = accuracy(y, label, topk=(1, 5))
            log_top1.update(acc1, current_batch_size)
            log_top5.update(acc5, current_batch_size)

            pbar_loss.set_postfix_str(
                ' | loss={:6.04f} acc top1={:6.04f} top5={:6.04f}'
                ' err top1={:6.04f} top5={:6.04f}'
                ''.format(
                log_loss.avg, 
                log_top1.avg,
                log_top5.avg,
                100 - log_top1.avg,
                100 - log_top5.avg,
            ))


def train():
    model.train()
    one_epoch_process(train_loader, is_train=True)

def validate():
    model.eval()
    with torch.no_grad():
        one_epoch_process(val_loader, is_train=False)


with tqdm(range(epoch_num)) as pbar_epoch:
    for epoch in pbar_epoch:
        pbar_epoch.set_description("[Epoch %d]" % (epoch))

        train()

        validate()

The results:

[train]: 100%|██████████| 782/782 [00:59<00:00, 13.10it/s,  | loss=0.7773 acc top1=78.2380 top5=97.6020 err top1=21.7620 top5=2.3980]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.31it/s,  | loss=0.5266 acc top1=82.5600 top5=99.2100 err top1=17.4400 top5=0.7900]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.99it/s,  | loss=0.4180 acc top1=85.9420 top5=99.4820 err top1=14.0580 top5=0.5180]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.71it/s,  | loss=0.4284 acc top1=85.2400 top5=99.4400 err top1=14.7600 top5=0.5600]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.82it/s,  | loss=0.3465 acc top1=88.2740 top5=99.6420 err top1=11.7260 top5=0.3580]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.58it/s,  | loss=0.4238 acc top1=85.8700 top5=99.4100 err top1=14.1300 top5=0.5900]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.98it/s,  | loss=0.3014 acc top1=89.8680 top5=99.6780 err top1=10.1320 top5=0.3220]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.73it/s,  | loss=0.3660 acc top1=87.6900 top5=99.6300 err top1=12.3100 top5=0.3700]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.99it/s,  | loss=0.2664 acc top1=90.9200 top5=99.7340 err top1=9.0800 top5=0.2660]
[val]: 100%|██████████| 157/157 [00:05<00:00, 31.18it/s,  | loss=0.3526 acc top1=88.3100 top5=99.6000 err top1=11.6900 top5=0.4000]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.88it/s,  | loss=0.2407 acc top1=91.9240 top5=99.8040 err top1=8.0760 top5=0.1960]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.45it/s,  | loss=0.3370 acc top1=89.0900 top5=99.6700 err top1=10.9100 top5=0.3300]
[train]: 100%|██████████| 782/782 [00:59<00:00, 13.20it/s,  | loss=0.2209 acc top1=92.5220 top5=99.8440 err top1=7.4780 top5=0.1560]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.61it/s,  | loss=0.2876 acc top1=90.2700 top5=99.7300 err top1=9.7300 top5=0.2700]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.88it/s,  | loss=0.2077 acc top1=92.9760 top5=99.8880 err top1=7.0240 top5=0.1120]
[val]: 100%|██████████| 157/157 [00:05<00:00, 31.10it/s,  | loss=0.3089 acc top1=90.0600 top5=99.7000 err top1=9.9400 top5=0.3000]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.97it/s,  | loss=0.1881 acc top1=93.5180 top5=99.8740 err top1=6.4820 top5=0.1260]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.59it/s,  | loss=0.3117 acc top1=90.1800 top5=99.7800 err top1=9.8200 top5=0.2200]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.86it/s,  | loss=0.1755 acc top1=93.9980 top5=99.9020 err top1=6.0020 top5=0.0980]
[val]: 100%|██████████| 157/157 [00:05<00:00, 31.12it/s,  | loss=0.3198 acc top1=89.7400 top5=99.7600 err top1=10.2600 top5=0.2400]
[Epoch 9]: 100%|██████████| 10/10 [10:57<00:00, 65.71s/it]
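Once trained, the attention map stored by forward() (or fetched via get_attn()) can be upsampled to the input resolution for visualization. A standalone sketch with a dummy 8x8 map standing in for the real attention output:

```python
import torch
import torch.nn.functional as F

# Dummy 8x8 attention map for one 32x32 image (stand-in for model.get_attn())
attn = torch.rand(1, 1, 8, 8)

# Upsample to the input resolution so it can be overlaid on the image
attn_up = F.interpolate(attn, size=(32, 32), mode="bilinear", align_corners=False)
print(attn_up.shape)  # torch.Size([1, 1, 32, 32])
```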