By cutting and pasting a pretrained model, you can save yourself some work when building a new one.
Here we take a pretrained ResNet apart and build a ResNet-based ABN (attention branch network) from the pieces. (The class names below say ResNet50, but the checkpoint actually loaded is cifar100_resnet56, i.e. a ResNet-56; the procedure is the same either way.)
Cutting, reconnecting, and putting it back together
First, cut the ResNet apart, reconnect it, and check that the result is identical to the original model (just to be safe).
The original ResNet
Start with the original ResNet. The dataset is CIFAR-100, whose size makes results easy to inspect.
Any model would do, but for now we grab a CIFAR-100 pretrained model from
https://github.com/chenyaofo/pytorch-cifar-models
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_org = torch.hub.load("chenyaofo/pytorch-cifar-models",
                           "cifar100_resnet56",
                           pretrained=True)
model_org = model_org.to(device)
torchinfo.summary(
    model_org,
    (64, 3, 32, 32),
    depth=2,
    col_names=["input_size",
               "output_size",
               "kernel_size"],
    row_settings=("var_names",)
)
===================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Kernel Shape
===================================================================================================================
CifarResNet -- -- --
├─Conv2d (conv1) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]
├─BatchNorm2d (bn1) [64, 16, 32, 32] [64, 16, 32, 32] [16]
├─ReLU (relu) [64, 16, 32, 32] [64, 16, 32, 32] --
├─Sequential (layer1) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (0) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (1) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (2) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (3) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (4) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (5) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (6) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (7) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─BasicBlock (8) [64, 16, 32, 32] [64, 16, 32, 32] --
├─Sequential (layer2) [64, 16, 32, 32] [64, 32, 16, 16] --
│ └─BasicBlock (0) [64, 16, 32, 32] [64, 32, 16, 16] --
│ └─BasicBlock (1) [64, 32, 16, 16] [64, 32, 16, 16] --
│ └─BasicBlock (2) [64, 32, 16, 16] [64, 32, 16, 16] --
│ └─BasicBlock (3) [64, 32, 16, 16] [64, 32, 16, 16] --
│ └─BasicBlock (4) [64, 32, 16, 16] [64, 32, 16, 16] --
│ └─BasicBlock (5) [64, 32, 16, 16] [64, 32, 16, 16] --
│ └─BasicBlock (6) [64, 32, 16, 16] [64, 32, 16, 16] --
│ └─BasicBlock (7) [64, 32, 16, 16] [64, 32, 16, 16] --
│ └─BasicBlock (8) [64, 32, 16, 16] [64, 32, 16, 16] --
├─Sequential (layer3) [64, 32, 16, 16] [64, 64, 8, 8] --
│ └─BasicBlock (0) [64, 32, 16, 16] [64, 64, 8, 8] --
│ └─BasicBlock (1) [64, 64, 8, 8] [64, 64, 8, 8] --
│ └─BasicBlock (2) [64, 64, 8, 8] [64, 64, 8, 8] --
│ └─BasicBlock (3) [64, 64, 8, 8] [64, 64, 8, 8] --
│ └─BasicBlock (4) [64, 64, 8, 8] [64, 64, 8, 8] --
│ └─BasicBlock (5) [64, 64, 8, 8] [64, 64, 8, 8] --
│ └─BasicBlock (6) [64, 64, 8, 8] [64, 64, 8, 8] --
│ └─BasicBlock (7) [64, 64, 8, 8] [64, 64, 8, 8] --
│ └─BasicBlock (8) [64, 64, 8, 8] [64, 64, 8, 8] --
├─AdaptiveAvgPool2d (avgpool) [64, 64, 8, 8] [64, 64, 1, 1] --
├─Linear (fc) [64, 64] [64, 100] [64, 100]
===================================================================================================================
Total params: 861,620
Trainable params: 861,620
Non-trainable params: 0
Total mult-adds (G): 8.05
===================================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 557.89
Params size (MB): 3.45
Estimated Total Size (MB): 562.13
===================================================================================================================
Checking that cutting and reconnecting restores the original
Next, split the pretrained ResNet into a bottom half and a top half, join them into a reconstructed model, and check that it produces the same output as the original.
class ReconstructResNet50(nn.Module):
    def __init__(self):
        super().__init__()
        model = torch.hub.load("chenyaofo/pytorch-cifar-models",
                               "cifar100_resnet56",
                               pretrained=True)
        self.resnet50_bottom_half = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.layer1,
            model.layer2,
            model.layer3,
        )
        self.resnet50_top_half = nn.Sequential(
            model.avgpool,
            nn.Flatten(),
            model.fc
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.resnet50_bottom_half(x)
        x = self.resnet50_top_half(x)
        return x
model_new = ReconstructResNet50()
model_new = model_new.to(device)
torchinfo.summary(
    model_new,
    (64, 3, 32, 32),
    depth=3,
    col_names=["input_size",
               "output_size",
               "kernel_size"],
    row_settings=("var_names",)
)
The contents of the reconstructed model. The hierarchy is one level deeper, but the parameter count and everything else are identical.
========================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Kernel Shape
========================================================================================================================
ReconstructResNet50 -- -- --
├─Sequential (resnet50_bottom_half) [64, 3, 32, 32] [64, 64, 8, 8] --
│ └─Conv2d (0) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]
│ └─BatchNorm2d (1) [64, 16, 32, 32] [64, 16, 32, 32] [16]
│ └─ReLU (2) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─Sequential (3) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (0) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (1) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (2) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (3) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (4) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (5) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (6) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (7) [64, 16, 32, 32] [64, 16, 32, 32] --
│ │ └─BasicBlock (8) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─Sequential (4) [64, 16, 32, 32] [64, 32, 16, 16] --
│ │ └─BasicBlock (0) [64, 16, 32, 32] [64, 32, 16, 16] --
│ │ └─BasicBlock (1) [64, 32, 16, 16] [64, 32, 16, 16] --
│ │ └─BasicBlock (2) [64, 32, 16, 16] [64, 32, 16, 16] --
│ │ └─BasicBlock (3) [64, 32, 16, 16] [64, 32, 16, 16] --
│ │ └─BasicBlock (4) [64, 32, 16, 16] [64, 32, 16, 16] --
│ │ └─BasicBlock (5) [64, 32, 16, 16] [64, 32, 16, 16] --
│ │ └─BasicBlock (6) [64, 32, 16, 16] [64, 32, 16, 16] --
│ │ └─BasicBlock (7) [64, 32, 16, 16] [64, 32, 16, 16] --
│ │ └─BasicBlock (8) [64, 32, 16, 16] [64, 32, 16, 16] --
│ └─Sequential (5) [64, 32, 16, 16] [64, 64, 8, 8] --
│ │ └─BasicBlock (0) [64, 32, 16, 16] [64, 64, 8, 8] --
│ │ └─BasicBlock (1) [64, 64, 8, 8] [64, 64, 8, 8] --
│ │ └─BasicBlock (2) [64, 64, 8, 8] [64, 64, 8, 8] --
│ │ └─BasicBlock (3) [64, 64, 8, 8] [64, 64, 8, 8] --
│ │ └─BasicBlock (4) [64, 64, 8, 8] [64, 64, 8, 8] --
│ │ └─BasicBlock (5) [64, 64, 8, 8] [64, 64, 8, 8] --
│ │ └─BasicBlock (6) [64, 64, 8, 8] [64, 64, 8, 8] --
│ │ └─BasicBlock (7) [64, 64, 8, 8] [64, 64, 8, 8] --
│ │ └─BasicBlock (8) [64, 64, 8, 8] [64, 64, 8, 8] --
├─Sequential (resnet50_top_half) [64, 64, 8, 8] [64, 100] --
│ └─AdaptiveAvgPool2d (0) [64, 64, 8, 8] [64, 64, 1, 1] --
│ └─Flatten (1) [64, 64, 1, 1] [64, 64] --
│ └─Linear (2) [64, 64] [64, 100] [64, 100]
========================================================================================================================
Total params: 861,620
Trainable params: 861,620
Non-trainable params: 0
Total mult-adds (G): 8.05
========================================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 557.89
Params size (MB): 3.45
Estimated Total Size (MB): 562.13
========================================================================================================================
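Beyond eyeballing the summaries, we can check value for value that the two models hold identical weights. A minimal sketch, not part of the original flow: the key names differ (e.g. conv1 vs. resnet50_bottom_half.0), but both models register their submodules in the same order, so zipping the two state dicts lines up corresponding tensors.
# Pairwise comparison of all weights and buffers; assumes matching
# registration order, which holds for this reconstruction.
for (k_org, v_org), (k_new, v_new) in zip(model_org.state_dict().items(),
                                          model_new.state_dict().items()):
    assert torch.equal(v_org, v_new), f"{k_org} != {k_new}"
print("all parameters and buffers match")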
Feeding data through
Now let's feed some data through.
data = torch.randn(64, 3, 32, 32).to(device)
Do the outputs match?
model_org.eval()
print(model_org(data).max(axis=1))
print(model_org(data)[:10, :5])
torch.return_types.max(
values=tensor([ 6.6042, 9.0203, 7.3368, 7.8674, 6.7200, 7.2444, 10.0077, 10.5660,
10.3509, 9.1652, 8.1699, 8.6301, 6.8983, 7.2104, 7.0237, 9.4096,
7.9846, 8.4669, 8.8615, 7.1618, 6.6025, 7.1010, 7.3803, 9.5781,
8.5849, 9.0294, 6.8827, 10.2920, 8.5496, 6.8068, 7.8123, 7.7742,
9.2964, 8.1514, 6.8203, 7.3823, 10.8784, 7.1684, 7.2372, 8.7422,
9.0883, 8.0774, 9.1850, 7.1326, 6.9103, 8.9576, 9.5753, 8.6561,
8.2095, 8.5133, 7.8957, 7.3277, 7.9779, 7.0222, 6.9912, 9.9922,
6.4064, 7.1179, 6.9993, 8.4908, 6.4606, 10.3621, 9.0333, 6.0951],
device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([6, 2, 6, 2, 2, 6, 2, 6, 2, 2, 6, 2, 6, 6, 2, 2, 6, 2, 6, 2, 6, 6, 6, 6,
6, 2, 6, 2, 2, 2, 6, 2, 2, 6, 6, 6, 6, 2, 2, 6, 2, 2, 2, 2, 6, 2, 6, 6,
2, 2, 2, 2, 6, 6, 2, 6, 2, 2, 2, 2, 6, 6, 6, 6], device='cuda:0'))
tensor([[-2.8958, -0.3914, 5.9768, -0.3876, -0.5456],
[-2.7208, 0.2338, 9.0203, -0.2278, -0.7832],
[-3.1234, 0.2259, 6.8180, 0.0619, -1.2472],
[-3.2154, -0.2761, 7.8674, 0.3466, -1.0217],
[-2.6827, -0.5690, 6.7200, 0.1551, -0.4000],
[-3.3118, 0.1258, 6.2590, 0.0972, -1.1040],
[-2.4617, -0.3470, 10.0077, -0.2054, -0.0124],
[-3.2467, -0.6021, 4.2468, -1.4212, -1.6072],
[-2.3458, -0.4344, 10.3509, 0.7181, -0.1760],
[-2.3557, -0.4685, 9.1652, -0.0789, -0.4057]], device='cuda:0',
grad_fn=<SliceBackward>)
model_new.eval()
print(model_new(data).max(axis=1))
print(model_new(data)[:10, :5])
torch.return_types.max(
values=tensor([ 6.6042, 9.0203, 7.3368, 7.8674, 6.7200, 7.2444, 10.0077, 10.5660,
10.3509, 9.1652, 8.1699, 8.6301, 6.8983, 7.2104, 7.0237, 9.4096,
7.9846, 8.4669, 8.8615, 7.1618, 6.6025, 7.1010, 7.3803, 9.5781,
8.5849, 9.0294, 6.8827, 10.2920, 8.5496, 6.8068, 7.8123, 7.7742,
9.2964, 8.1514, 6.8203, 7.3823, 10.8784, 7.1684, 7.2372, 8.7422,
9.0883, 8.0774, 9.1850, 7.1326, 6.9103, 8.9576, 9.5753, 8.6561,
8.2095, 8.5133, 7.8957, 7.3277, 7.9779, 7.0222, 6.9912, 9.9922,
6.4064, 7.1179, 6.9993, 8.4908, 6.4606, 10.3621, 9.0333, 6.0951],
device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([6, 2, 6, 2, 2, 6, 2, 6, 2, 2, 6, 2, 6, 6, 2, 2, 6, 2, 6, 2, 6, 6, 6, 6,
6, 2, 6, 2, 2, 2, 6, 2, 2, 6, 6, 6, 6, 2, 2, 6, 2, 2, 2, 2, 6, 2, 6, 6,
2, 2, 2, 2, 6, 6, 2, 6, 2, 2, 2, 2, 6, 6, 6, 6], device='cuda:0'))
tensor([[-2.8958, -0.3914, 5.9768, -0.3876, -0.5456],
[-2.7208, 0.2338, 9.0203, -0.2278, -0.7832],
[-3.1234, 0.2259, 6.8180, 0.0619, -1.2472],
[-3.2154, -0.2761, 7.8674, 0.3466, -1.0217],
[-2.6827, -0.5690, 6.7200, 0.1551, -0.4000],
[-3.3118, 0.1258, 6.2590, 0.0972, -1.1040],
[-2.4617, -0.3470, 10.0077, -0.2054, -0.0124],
[-3.2467, -0.6021, 4.2468, -1.4212, -1.6072],
[-2.3458, -0.4344, 10.3509, 0.7181, -0.1760],
[-2.3557, -0.4685, 9.1652, -0.0789, -0.4057]], device='cuda:0',
grad_fn=<SliceBackward>)
Both produce exactly the same output.
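Instead of comparing the printed tensors by eye, the check can also be made programmatic. A small sketch, not in the original flow:
with torch.no_grad():
    out_org = model_org(data)
    out_new = model_new(data)
print(torch.allclose(out_org, out_new))  # True: identical outputs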
Do the losses match?
Next, set up optimizers and check whether the losses also match.
criterion = nn.CrossEntropyLoss()
optimizer_org = torch.optim.SGD(model_org.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
optimizer_new = torch.optim.SGD(model_new.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
num_classes = 100  # CIFAR-100
target = torch.randint(num_classes, (64,)).to(device)
model_org.train()
optimizer_org.zero_grad()
output_org = model_org(data)
loss_org = criterion(output_org, target)
loss_org.backward()
loss_org
tensor(4.3274, device='cuda:0', grad_fn=<NllLossBackward>)
model_new.train()
optimizer_new.zero_grad()
output_new = model_new(data)
loss_new = criterion(output_new, target)
loss_new.backward()
loss_new
tensor(4.3274, device='cuda:0', grad_fn=<NllLossBackward>)
The losses match.
Do the gradients match?
Next, check whether the weights and their gradients match. To keep it short, only the first conv layer is inspected.
model_new.resnet50_bottom_half[0].weight[0, 0, :, :]
tensor([[-0.0568, -0.0996, 0.0227],
[-0.1853, -0.1841, 0.1732],
[ 0.0093, -0.0274, 0.0995]], device='cuda:0', grad_fn=<SliceBackward>)
model_org.conv1.weight[0, 0, :, :]
tensor([[-0.0568, -0.0996, 0.0227],
[-0.1853, -0.1841, 0.1732],
[ 0.0093, -0.0274, 0.0995]], device='cuda:0', grad_fn=<SliceBackward>)
model_new.resnet50_bottom_half[0].weight.grad[0, 0, :, :]
tensor([[ 0.4496, -0.1967, -0.1387],
[ 0.4462, 0.0295, 0.6632],
[-0.4074, -0.0283, 0.2142]], device='cuda:0')
model_org.conv1.weight.grad[0, 0, :, :]
tensor([[ 0.4496, -0.1967, -0.1387],
[ 0.4462, 0.0295, 0.6632],
[-0.4074, -0.0283, 0.2142]], device='cuda:0')
With this, the match is confirmed.
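The same check can be extended to every parameter rather than just the first conv. A sketch, relying again on the matching registration order:
# After the backward passes above, compare all gradients pairwise.
for (n_org, p_org), (n_new, p_new) in zip(model_org.named_parameters(),
                                          model_new.named_parameters()):
    assert torch.allclose(p_org.grad, p_new.grad), f"{n_org} / {n_new}"
print("all gradients match")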
Cutting and pasting ResNet into an ABN
Now let's cut and paste the pretrained ResNet in various ways to build an ABN.
class ABNResNet50(nn.Module):
    def __init__(self, num_classes=100, pretrained=True):
        super().__init__()
        model = torch.hub.load("chenyaofo/pytorch-cifar-models",
                               "cifar100_resnet56",
                               pretrained=pretrained)
        self.resnet50_bottom = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.layer1,
            model.layer2,
            model.layer3,
        )
        # number of channels coming out of the bottom half
        r50b_out_features = \
            list(self.resnet50_bottom.modules())[-1].num_features
        self.resnet50_top = nn.Sequential(
            model.avgpool,
            nn.Flatten(),
            model.fc
        )
        self.attention_branch1 = nn.Sequential(
            # Reuse layer3[1:8], whose input and output sizes are identical.
            # Note: without deepcopy the weights would be shared with the
            # copy used above.
            copy.deepcopy(model.layer3[1:8]),
            nn.BatchNorm2d(r50b_out_features),
            nn.Conv2d(r50b_out_features, num_classes, kernel_size=1),
            nn.ReLU(inplace=True)
        )
        self.attention_branch2 = nn.Sequential(
            nn.Conv2d(num_classes, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.attention_branch3 = nn.Sequential(
            nn.Conv2d(num_classes, num_classes, kernel_size=1),
        )

    def get_attn(self):
        # attention map of the most recent forward pass
        return self.attn

    def forward(self, x):
        x = self.resnet50_bottom(x)
        ax = self.attention_branch1(x)
        attn = self.attention_branch2(ax)
        x = x * attn                # apply the attention map
        x = self.resnet50_top(x)    # perception branch output
        ax = self.attention_branch3(ax)
        ax = F.avg_pool2d(ax, kernel_size=ax.shape[2]).squeeze()  # GAP
        self.attn = attn
        return x, ax, attn
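A quick smoke test of the three outputs (perception-branch logits, attention-branch logits after GAP, and the attention map itself). This is not in the original flow, but it also gives us the model instance that torchinfo inspects below.
model = ABNResNet50(pretrained=True).to(device)
model.eval()
with torch.no_grad():
    y, ay, attn = model(torch.randn(64, 3, 32, 32).to(device))
print(y.shape, ay.shape, attn.shape)
# torch.Size([64, 100]) torch.Size([64, 100]) torch.Size([64, 1, 8, 8])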
Check the ABN's parameters.
torchinfo.summary(
    model,
    (64, 3, 32, 32),
    depth=2,  # depth=3 gets long, so stop at 2
    col_names=["input_size",
               "output_size",
               "kernel_size"],
    row_settings=("var_names",)
)
========================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Kernel Shape
========================================================================================================================
ABNResNet50 -- -- --
├─Sequential (resnet50_bottom) [64, 3, 32, 32] [64, 64, 8, 8] --
│ └─Conv2d (0) [64, 3, 32, 32] [64, 16, 32, 32] [3, 16, 3, 3]
│ └─BatchNorm2d (1) [64, 16, 32, 32] [64, 16, 32, 32] [16]
│ └─ReLU (2) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─Sequential (3) [64, 16, 32, 32] [64, 16, 32, 32] --
│ └─Sequential (4) [64, 16, 32, 32] [64, 32, 16, 16] --
│ └─Sequential (5) [64, 32, 16, 16] [64, 64, 8, 8] --
├─Sequential (attention_branch1) [64, 64, 8, 8] [64, 100, 8, 8] --
│ └─Sequential (0) [64, 64, 8, 8] [64, 64, 8, 8] --
│ └─BatchNorm2d (1) [64, 64, 8, 8] [64, 64, 8, 8] [64]
│ └─Conv2d (2) [64, 64, 8, 8] [64, 100, 8, 8] [64, 100, 1, 1]
│ └─ReLU (3) [64, 100, 8, 8] [64, 100, 8, 8] --
├─Sequential (attention_branch2) [64, 100, 8, 8] [64, 1, 8, 8] --
│ └─Conv2d (0) [64, 100, 8, 8] [64, 1, 8, 8] [100, 1, 1, 1]
│ └─BatchNorm2d (1) [64, 1, 8, 8] [64, 1, 8, 8] [1]
│ └─Sigmoid (2) [64, 1, 8, 8] [64, 1, 8, 8] --
├─Sequential (resnet50_top) [64, 64, 8, 8] [64, 100] --
│ └─AdaptiveAvgPool2d (0) [64, 64, 8, 8] [64, 64, 1, 1] --
│ └─Flatten (1) [64, 64, 1, 1] [64, 64] --
│ └─Linear (2) [64, 64] [64, 100] [64, 100]
├─Sequential (attention_branch3) [64, 100, 8, 8] [64, 100, 8, 8] --
│ └─Conv2d (0) [64, 100, 8, 8] [64, 100, 8, 8] [100, 100, 1, 1]
========================================================================================================================
Total params: 1,396,339
Trainable params: 1,396,339
Non-trainable params: 0
Total mult-adds (G): 10.23
========================================================================================================================
Input size (MB): 0.79
Forward/backward pass size (MB): 625.33
Params size (MB): 5.59
Estimated Total Size (MB): 631.70
========================================================================================================================
Prepare a couple of utility helpers.
class AverageMeter(object):
    """Computes and stores the average and current value

    Imported from
    https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L363-L380
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        if isinstance(val, torch.Tensor):
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    https://github.com/pytorch/examples/blob/cedca7729fef11c91e28099a0e45d7e98d03b66d/imagenet/main.py#L411
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
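A tiny check of accuracy() on hand-made logits, just to show the interface (not in the original article):
logits = torch.tensor([[0.1, 0.9, 0.0],   # top-1 = class 1 (correct)
                       [0.5, 0.1, 0.4]])  # top-1 = class 0 (wrong), top-2 includes class 2
labels = torch.tensor([1, 2])
top1, top2 = accuracy(logits, labels, topk=(1, 2))
print(top1.item(), top2.item())  # 50.0 100.0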
Training the ABN
Now, training.
model = ABNResNet50(pretrained=True)
model.to(device)
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
epoch_num = 10
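The loop below assumes train_loader and val_loader already exist; their definitions are not shown in this article. A minimal sketch using torchvision: batch size 64 is an assumption consistent with the 782/157 iterations per epoch in the logs, and the normalization statistics are the commonly used CIFAR-100 values (check them against whatever the pretrained model was trained with).
import torchvision
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=(0.5071, 0.4865, 0.4409),
                                 std=(0.2673, 0.2564, 0.2762))
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
val_transform = transforms.Compose([transforms.ToTensor(), normalize])

train_set = torchvision.datasets.CIFAR100(root="./data", train=True,
                                          download=True, transform=train_transform)
val_set = torchvision.datasets.CIFAR100(root="./data", train=False,
                                        download=True, transform=val_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64,
                                           shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64,
                                         shuffle=False, num_workers=2)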
def one_epoch_process(data_loader, is_train=True):
    """Train or evaluate for one epoch.

    Args:
        data_loader (DataLoader): data loader of train or val set
        is_train (bool, optional): flag of train or val. Defaults to True.
    """
    with tqdm(enumerate(data_loader),
              total=len(data_loader),
              leave=True) as pbar_loss:
        log_loss = AverageMeter()
        log_top1 = AverageMeter()
        log_top5 = AverageMeter()

        for batch_idx, (image, label) in pbar_loss:
            pbar_loss.set_description("[{}]".format('train' if is_train else 'val'))

            current_batch_size = image.size(0)
            image, label = image.to(device), label.to(device)

            # perception branch output y, attention branch output ay
            y, ay, _ = model(image)
            loss = criterion(y, label)
            loss_ay = criterion(ay, label)
            loss_all = (loss + loss_ay) / 2
            log_loss.update(loss_all, current_batch_size)

            if is_train:
                optimizer.zero_grad()
                loss_all.backward()
                optimizer.step()

            acc1, acc5 = accuracy(y, label, topk=(1, 5))
            log_top1.update(acc1, current_batch_size)
            log_top5.update(acc5, current_batch_size)

            pbar_loss.set_postfix_str(
                ' | loss={:6.04f} acc top1={:6.04f} top5={:6.04f}'
                ' err top1={:6.04f} top5={:6.04f}'
                ''.format(
                    log_loss.avg,
                    log_top1.avg,
                    log_top5.avg,
                    100 - log_top1.avg,
                    100 - log_top5.avg,
                ))


def train():
    model.train()
    one_epoch_process(train_loader, is_train=True)


def validate():
    model.eval()
    with torch.no_grad():
        one_epoch_process(val_loader, is_train=False)


with tqdm(range(epoch_num)) as pbar_epoch:
    for epoch in pbar_epoch:
        pbar_epoch.set_description("[Epoch %d]" % (epoch))
        train()
        validate()
Results.
[train]: 100%|██████████| 782/782 [00:59<00:00, 13.10it/s, | loss=0.7773 acc top1=78.2380 top5=97.6020 err top1=21.7620 top5=2.3980]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.31it/s, | loss=0.5266 acc top1=82.5600 top5=99.2100 err top1=17.4400 top5=0.7900]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.99it/s, | loss=0.4180 acc top1=85.9420 top5=99.4820 err top1=14.0580 top5=0.5180]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.71it/s, | loss=0.4284 acc top1=85.2400 top5=99.4400 err top1=14.7600 top5=0.5600]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.82it/s, | loss=0.3465 acc top1=88.2740 top5=99.6420 err top1=11.7260 top5=0.3580]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.58it/s, | loss=0.4238 acc top1=85.8700 top5=99.4100 err top1=14.1300 top5=0.5900]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.98it/s, | loss=0.3014 acc top1=89.8680 top5=99.6780 err top1=10.1320 top5=0.3220]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.73it/s, | loss=0.3660 acc top1=87.6900 top5=99.6300 err top1=12.3100 top5=0.3700]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.99it/s, | loss=0.2664 acc top1=90.9200 top5=99.7340 err top1=9.0800 top5=0.2660]
[val]: 100%|██████████| 157/157 [00:05<00:00, 31.18it/s, | loss=0.3526 acc top1=88.3100 top5=99.6000 err top1=11.6900 top5=0.4000]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.88it/s, | loss=0.2407 acc top1=91.9240 top5=99.8040 err top1=8.0760 top5=0.1960]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.45it/s, | loss=0.3370 acc top1=89.0900 top5=99.6700 err top1=10.9100 top5=0.3300]
[train]: 100%|██████████| 782/782 [00:59<00:00, 13.20it/s, | loss=0.2209 acc top1=92.5220 top5=99.8440 err top1=7.4780 top5=0.1560]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.61it/s, | loss=0.2876 acc top1=90.2700 top5=99.7300 err top1=9.7300 top5=0.2700]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.88it/s, | loss=0.2077 acc top1=92.9760 top5=99.8880 err top1=7.0240 top5=0.1120]
[val]: 100%|██████████| 157/157 [00:05<00:00, 31.10it/s, | loss=0.3089 acc top1=90.0600 top5=99.7000 err top1=9.9400 top5=0.3000]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.97it/s, | loss=0.1881 acc top1=93.5180 top5=99.8740 err top1=6.4820 top5=0.1260]
[val]: 100%|██████████| 157/157 [00:05<00:00, 30.59it/s, | loss=0.3117 acc top1=90.1800 top5=99.7800 err top1=9.8200 top5=0.2200]
[train]: 100%|██████████| 782/782 [01:00<00:00, 12.86it/s, | loss=0.1755 acc top1=93.9980 top5=99.9020 err top1=6.0020 top5=0.0980]
[val]: 100%|██████████| 157/157 [00:05<00:00, 31.12it/s, | loss=0.3198 acc top1=89.7400 top5=99.7600 err top1=10.2600 top5=0.2400]
[Epoch 9]: 100%|██████████| 10/10 [10:57<00:00, 65.71s/it]
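The whole point of ABN is the attention map, and get_attn() returns the map stored by the most recent forward pass. A hedged sketch of visualizing it with matplotlib, not part of the original article; it assumes the val_loader defined above, and the inputs are still normalized, so their colors will look off.
import matplotlib.pyplot as plt

model.eval()
image, label = next(iter(val_loader))
with torch.no_grad():
    model(image.to(device))
attn = model.get_attn()  # shape [B, 1, 8, 8]

fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i in range(8):
    axes[0, i].imshow(image[i].permute(1, 2, 0).numpy())     # input (normalized)
    axes[1, i].imshow(attn[i, 0].cpu().numpy(), cmap="jet")  # attention map
    axes[0, i].axis("off")
    axes[1, i].axis("off")
plt.show()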