
[Intro to pytorch-lightning] My First Lit♬

Posted: 2020-12-28

Exactly one year ago I wrote an article on pytorch.
It was fairly easy to get running back then, so I assumed this time would be easy too, and got caught a bit off guard.
The introductory page feels a little thin to me.
But once things clicked, I realized the animation in reference ① is actually excellent.
So I suppose I'll walk through that here.
*In the end, I could only write up what I actually did.

References:
PyTorchLightning/pytorch-lightning
STEP-BY-STEP WALK-THROUGH
The environment this time is as follows.

>python -m pip install pytorch-lightning
...
Successfully installed fsspec-0.8.5 pytorch-lightning-1.1.2 tensorboard-2.4.0 tensorboard-plugin-wit-1.7.0 tqdm-4.55.0

>python -m pip install tensorboard
Requirement already satisfied:...
>>> import tensorboard
>>> tensorboard.__version__
'2.4.0'

>python
Python 3.7.5 (default, Oct 31 2019, 15:18:51) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import pytorch_lightning
>>> pytorch_lightning.__version__
'1.1.2'
>>> import torch
>>> torch.__version__
'1.7.1'

What I did

・The essence of pytorch-lightning
・A look back at Pytorch
・pytorch-lightning

・The essence of pytorch-lightning

Lightning Philosophy
Lightning structures your deep learning code in 4 parts:
・Research code
・Engineering code
・Non-essential code
・Data code
Lightning takes these pieces from your pytorch code, rearranges them, and consolidates them into a single Class.
That is what the animation above shows.
*Since that animation can apparently be embedded here, I'm starting to feel there is little left for me to explain.

So you just analyze your Pytorch code along the lines above, drop the pieces into the following frame, and it runs.

pytorch-lightning template
*I made it actually working MNIST classification code

import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
import pytorch_lightning as pl

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms

from torch.optim import Adam

class LitMNIST(LightningModule): # inheriting from LightningModule is required

    def __init__(self, data_dir='./'): # anything that needs initializing goes here
        super().__init__()
        self.data_dir=data_dir
        self.transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))])

        # mnist images are (1, 28, 28) (channels, width, height)
        # layers that define the network; used in forward
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x): # inference
        batch_size, channels, width, height = x.size()
        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx): # training
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx): # run to check generalization
        x, t = batch
        y = self(x)
        loss = F.nll_loss(y, t)
        preds = torch.argmax(y, dim=1)

        # Calling self.log will surface up scalars for you in TensorBoard
        # logging; records self.val_acc etc. defined above
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_acc(y,t), prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx): # run the test; check the final loss and acc
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self): # define the optimizer
        return Adam(self.parameters(), lr=1e-3)    

    def prepare_data(self): # download the data
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None): # split into train/val/test data
        # Assign train/val datasets for use in dataloaders
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        n_train = int(len(mnist_full)*0.8)
        n_val = len(mnist_full)-n_train
        self.mnist_train, self.mnist_val = torch.utils.data.random_split(mnist_full, [n_train, n_val])
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self): # build the train dataloader
        # note: this reloads the full training set, so the split from setup() is not used in this minimal template
        mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
        return DataLoader(mnist_train, batch_size=64)

    def val_dataloader(self): # build the val dataloader (the MNIST test set is reused here)
        mnist_val = MNIST(self.data_dir, train=False, transform=self.transform)
        return DataLoader(mnist_val, batch_size=64)

    def test_dataloader(self): # build the test dataloader
        mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        return DataLoader(mnist_test, batch_size=64)

# define the device for GPU use
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# fix the random seed
pl.seed_everything(0)
# instantiate the net
net = LitMNIST()
# move it to the GPU
model = net.to(device)
# define the number of epochs etc. for training
trainer = Trainer(gpus=1, max_epochs=10)
# fit (train) the model
trainer.fit(model)
# evaluate on the test data
results = trainer.test()
# print the evaluation results
print(results)

>python Lit_MNIST.py
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2020-12-28 19:01:18.414240: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2020-12-28 19:01:18.414340: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

  | Name      | Type     | Params
---------------------------------------
0 | layer_1   | Linear   | 100 K
1 | layer_2   | Linear   | 33.0 K
2 | layer_3   | Linear   | 2.6 K
3 | train_acc | Accuracy | 0
4 | val_acc   | Accuracy | 0
5 | test_acc  | Accuracy | 0
---------------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
Epoch 9: 100%|████| 1095/1095 [00:12<00:00, 87.94it/s, loss=0.02, v_num=62, val_loss=0.105, val_acc=0.977, my_loss_step=0.000118, my_loss_epoch=0.0252]
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [00:01<00:00, 107.82it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_acc': tensor(0.9758, device='cuda:0'),
 'val_loss': tensor(0.0901, device='cuda:0')}
--------------------------------------------------------------------------------
[{'val_loss': 0.09011910110712051, 'val_acc': 0.9757999777793884}]
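
Incidentally, Lightning also saves checkpoints automatically under ./lightning_logs/version_x/checkpoints/, so the trained model can be reloaded later. A minimal sketch, assuming a checkpoint from a run like the one above (the exact filename depends on your run):

ckpt_path = 'lightning_logs/version_62/checkpoints/epoch=9.ckpt'  # illustrative path; check your own folder
model = LitMNIST.load_from_checkpoint(ckpt_path)
model.eval()  # switch to inference mode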

To check the logs with tensorboard, run the command below and open
http://localhost:6006/
in a browser.

>tensorboard --logdir ./lightning_logs
...
TensorBoard 2.4.0 at http://localhost:6006/ (Press CTRL+C to quit)

[Figure: tensorboard.png, the TensorBoard view of the logged metrics]
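
Incidentally, if you want runs to land somewhere other than the default ./lightning_logs, you can pass the Trainer an explicit TensorBoardLogger. A minimal sketch (the 'mnist' experiment name is just an example):

from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir='./lightning_logs', name='mnist')  # writes to ./lightning_logs/mnist/version_x
trainer = Trainer(gpus=1, max_epochs=10, logger=logger)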

・A look back at Pytorch

Stopping here would leave things looking pretty good, but some things I could do in plain pytorch are still missing above:
・with the code above, the network definition is cluttered
・I want to control training; at the very least I want to change the learning rate
・for GANs, I want to support multiple optimizers
So I solved them one by one.

・With the code above, the network definition is cluttered

Let's supply the network as a separate class.
⇒ It worked.
It came out cleaner than I expected, and even a large network can be used as-is, which is convenient.
⇒ It makes it easy to concentrate on the Model (separate it out and concentrate on it), which is pytorch-lightning's original aim. A minimal sketch of such a standalone network class follows the snippet below.

    def __init__(self, data_dir='./'):
        super().__init__()
        ...
        self.net = VGG16()
        ...

    def forward(self, x):
        return self.net(x)
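
For reference, here is a minimal sketch of what such a standalone network class could look like. SimpleConvNet is a hypothetical stand-in for VGG16; any nn.Module whose input/output shapes match will do:

import torch.nn as nn
import torch.nn.functional as F

class SimpleConvNet(nn.Module):
    # hypothetical stand-in for VGG16, sized for CIFAR-10 (3x32x32) inputs
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)             # halves the spatial size
        self.fc1 = nn.Linear(64 * 8 * 8, 128)   # 32x32 -> 16x16 -> 8x8 after two pools
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))    # (b, 32, 16, 16)
        x = self.pool(F.relu(self.conv2(x)))    # (b, 64, 8, 8)
        x = x.view(x.size(0), -1)               # flatten
        x = F.relu(self.fc1(x))
        return self.fc2(x)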

・I want to control training; at the very least I want to change the learning rate

This ends up being the same area as the third item: it comes down to how the optimizer is defined.
The relevant reference is ③ below.
Reference ④ is not directly related, but it is a summary of how to set learning rates in pytorch.
References:
OPTIMIZATION
PyTorchのSchedulerまとめ

So I rewrote the code as follows.

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
        scheduler = {'scheduler': optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)}
        print("CF;Ir = ", optimizer.param_groups[0]['lr'])
        return [optimizer], [scheduler]
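
The scheduler dict also accepts extra keys for finer control. A sketch under the dict format documented for this version of Lightning ('interval' and 'frequency' control when the scheduler steps; a 'monitor' key is only needed for ReduceLROnPlateau):

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        scheduler = {'scheduler': optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2),
                     'interval': 'epoch',  # step the scheduler once per epoch (the default)
                     'frequency': 1}       # and do so every epoch
        return [optimizer], [scheduler]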

To confirm the rate really does change, I printed it with the code below.
Note that this runs every single step, so after checking once I commented the print out.

    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        print('Ir =', self.optimizers().param_groups[0]['lr'])
        return loss

I also tried to print it from
class MyPrintingCallback(Callback):
but could not get it to work.
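
As an alternative I did not get around to testing, Lightning ships a LearningRateMonitor callback (it is even imported in the bonus code below) that logs the current rate for TensorBoard automatically. A minimal sketch:

from pytorch_lightning.callbacks import LearningRateMonitor

lr_monitor = LearningRateMonitor(logging_interval='epoch')  # or 'step'
trainer = Trainer(gpus=1, max_epochs=10, callbacks=[lr_monitor])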

・For GANs, I want to support multiple optimizers

I have not actually done this yet, so it will be homework for next time.
But it is explained in the optimizer docs mentioned above, and it looks doable along the lines of the sketch after the references below.

References:
Use multiple optimizers (like GANs)
And when you want a really intricate setup, you end up adopting the following strategy:
Manual optimization
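
A minimal sketch of what those docs describe (self.generator and self.discriminator are hypothetical sub-modules; with multiple optimizers, this version of Lightning passes an optimizer_idx into training_step):

    def configure_optimizers(self):
        opt_g = Adam(self.generator.parameters(), lr=2e-4)
        opt_d = Adam(self.discriminator.parameters(), lr=2e-4)
        return [opt_g, opt_d]  # one entry per optimizer

    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:                      # generator step
            return self.generator_loss(batch)       # hypothetical helper
        else:                                       # discriminator step
            return self.discriminator_loss(batch)   # hypothetical helper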

Summary

・I played around with pytorch-lightning
・The simple things look easy, but I got the sense there is real depth here
・Getting the code to do exactly what I want will probably take time

・Next time I want to train a GAN using multiple optimizers
Incidentally, I attach the VGG16 code and its run results as a bonus.
Really I ought to read through the code itself, but this time I only looked at the optimizer part.
PyTorchLightning/pytorch-lightning

Bonus

>python torch_lightning_cifar10.py
cuda:0
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,584
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 16, 16]               0
        MaxPool2d-14            [-1, 128, 8, 8]               0
           Conv2d-15            [-1, 256, 8, 8]         295,168
      BatchNorm2d-16            [-1, 256, 8, 8]             512
             ReLU-17            [-1, 256, 8, 8]               0
           Conv2d-18            [-1, 256, 8, 8]         590,080
      BatchNorm2d-19            [-1, 256, 8, 8]             512
             ReLU-20            [-1, 256, 8, 8]               0
           Conv2d-21            [-1, 256, 8, 8]         590,080
      BatchNorm2d-22            [-1, 256, 8, 8]             512
             ReLU-23            [-1, 256, 8, 8]               0
        MaxPool2d-24            [-1, 256, 4, 4]               0
           Conv2d-25            [-1, 512, 4, 4]       1,180,160
      BatchNorm2d-26            [-1, 512, 4, 4]           1,024
             ReLU-27            [-1, 512, 4, 4]               0
           Conv2d-28            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-29            [-1, 512, 4, 4]           1,024
             ReLU-30            [-1, 512, 4, 4]               0
           Conv2d-31            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-32            [-1, 512, 4, 4]           1,024
             ReLU-33            [-1, 512, 4, 4]               0
        MaxPool2d-34            [-1, 512, 2, 2]               0
           Conv2d-35            [-1, 512, 2, 2]       2,359,808
      BatchNorm2d-36            [-1, 512, 2, 2]           1,024
             ReLU-37            [-1, 512, 2, 2]               0
           Conv2d-38            [-1, 512, 2, 2]       2,359,808
      BatchNorm2d-39            [-1, 512, 2, 2]           1,024
             ReLU-40            [-1, 512, 2, 2]               0
           Conv2d-41            [-1, 512, 2, 2]       2,359,808
      BatchNorm2d-42            [-1, 512, 2, 2]           1,024
             ReLU-43            [-1, 512, 2, 2]               0
        MaxPool2d-44            [-1, 512, 1, 1]               0
           Linear-45                  [-1, 512]         262,656
             ReLU-46                  [-1, 512]               0
           Linear-47                   [-1, 32]          16,416
             ReLU-48                   [-1, 32]               0
           Linear-49                   [-1, 10]             330
            VGG16-50                   [-1, 10]               0
================================================================
Total params: 15,002,538
Trainable params: 15,002,538
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 6.57
Params size (MB): 57.23
Estimated Total Size (MB): 63.82
----------------------------------------------------------------
Starting to init trainer!
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Trainer is init now
Files already downloaded and verified
Files already downloaded and verified
CF;Ir =  0.001
2020-12-28 20:36:51.786876: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2020-12-28 20:36:51.787050: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

  | Name      | Type     | Params
---------------------------------------
0 | net       | VGG16    | 15.0 M
1 | train_acc | Accuracy | 0
2 | val_acc   | Accuracy | 0
3 | test_acc  | Accuracy | 0
---------------------------------------
15.0 M    Trainable params
0         Non-trainable params
15.0 M    Total params
                                             dog   cat  deer plane


Epoch 0: 100%|██████████████████████████████████████████████████| 1563/1563 [01:08<00:00, 22.89it/s, loss=1.63, v_num=63, val_loss=1.58, val_acc=0.344]
Epoch 1: 100%|██████████████████████████████████████████████████| 1563/1563 [01:08<00:00, 22.72it/s, loss=1.41, v_num=63, val_loss=1.33, val_acc=0.481]
Epoch 2: 100%|███████████████████████████████████████████████████| 1563/1563 [01:08<00:00, 22.66it/s, loss=1.17, v_num=63, val_loss=1.07, val_acc=0.62]
Epoch 3: 100%|████████████████████████████████████████████████| 1563/1563 [01:09<00:00, 22.61it/s, loss=0.994, v_num=63, val_loss=0.922, val_acc=0.684]
Epoch 4: 100%|████████████████████████████████████████████████| 1563/1563 [01:09<00:00, 22.62it/s, loss=0.833, v_num=63, val_loss=0.839, val_acc=0.713]
Epoch 5: 100%|████████████████████████████████████████████████| 1563/1563 [01:09<00:00, 22.40it/s, loss=0.601, v_num=63, val_loss=0.649, val_acc=0.775]
Epoch 6: 100%|█████████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.14it/s, loss=0.569, v_num=63, val_loss=0.62, val_acc=0.788]
Epoch 7: 100%|██████████████████████████████████████████████████| 1563/1563 [01:09<00:00, 22.58it/s, loss=0.519, v_num=63, val_loss=0.595, val_acc=0.8]
Epoch 8: 100%|████████████████████████████████████████████████| 1563/1563 [01:09<00:00, 22.57it/s, loss=0.435, v_num=63, val_loss=0.599, val_acc=0.803]
Epoch 9: 100%|█████████████████████████████████████████████████| 1563/1563 [01:09<00:00, 22.57it/s, loss=0.43, v_num=63, val_loss=0.579, val_acc=0.809]
Epoch 10: 100%|████████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.29it/s, loss=0.289, v_num=63, val_loss=0.58, val_acc=0.814]
Epoch 11: 100%|███████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.08it/s, loss=0.344, v_num=63, val_loss=0.595, val_acc=0.817]
Epoch 12: 100%|███████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.09it/s, loss=0.337, v_num=63, val_loss=0.601, val_acc=0.816]
Epoch 13: 100%|██████████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.10it/s, loss=0.28, v_num=63, val_loss=0.6, val_acc=0.817]
Epoch 14: 100%|███████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.10it/s, loss=0.288, v_num=63, val_loss=0.603, val_acc=0.819]
Epoch 15: 100%|███████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.10it/s, loss=0.296, v_num=63, val_loss=0.605, val_acc=0.819]
Epoch 16: 100%|███████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.09it/s, loss=0.252, v_num=63, val_loss=0.604, val_acc=0.819]
do something when training ends███████████████████████████████| 1563/1563 [01:10<00:00, 22.10it/s, loss=0.253, v_num=63, val_loss=0.616, val_acc=0.817]
Epoch 18: 100%|███████████████████████████████████████████████| 1563/1563 [01:10<00:00, 22.13it/s, loss=0.221, v_num=63, val_loss=0.611, val_acc=0.819]
Files already downloaded and verified
Files already downloaded and verified
Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:05<00:00, 54.47it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS███████████████████████████████████████████████████████████████████████████████████████████ | 310/313 [00:05<00:00, 55.76it/s]
{'val_acc': tensor(0.8030, device='cuda:0'),
 'val_loss': tensor(0.6054, device='cuda:0')}
--------------------------------------------------------------------------------
[{'val_loss': 0.6053957343101501, 'val_acc': 0.8029999732971191}]
elapsed time: 1357.867 [sec]
import argparse
import time
import numpy as np
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
import torch.optim as optim
from torchsummary import summary
import matplotlib.pyplot as plt
import tensorboard
#from net_cifar10 import Net_cifar10
from net_vgg16 import VGG16
# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.pause(1)
    plt.close()
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import EarlyStopping, ProgressBar,LearningRateMonitor

class MyPrintingCallback(Callback):
    def on_init_start(self, trainer):
        print('Starting to init trainer!')
    def on_init_end(self, trainer):
        print('Trainer is init now')
    def on_epoch_end(self, trainer, pl_module):
        #print("pl_module = ",pl_module)
        #print("trainer = ",trainer)
        print('')
    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')    
class LitProgressBar(ProgressBar):
    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description('running validation ...')
        return bar
class Net(pl.LightningModule):
    def __init__(self, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.dims = (3, 32, 32)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        #self.net = Net_cifar10()
        self.net = VGG16()
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
    def forward(self, x):
        return self.net(x)
    def training_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        #print('Ir =', self.optimizers().param_groups[0]['lr'])
        # note: despite the 'test_' key names, these record training-time metrics
        self.log('test_loss', loss, on_step = False, on_epoch = True)
        self.log('test_acc', self.val_acc(y, t), on_step = False, on_epoch = True)
        return loss
    def validation_step(self, batch, batch_idx):
        x, t = batch
        y = self(x)
        loss = F.cross_entropy(y, t)
        preds = torch.argmax(y, dim=1)
        # Calling self.log will surface up scalars for you in TensorBoard
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_acc(y,t), prog_bar=True)
        return loss
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)
    def configure_optimizers(self):
        #optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
        optimizer = optim.Adam(self.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
        #scheduler = {'scheduler': optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: 0.95 ** epoch)}
        scheduler = {'scheduler': optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)}
        print("CF;Ir = ", optimizer.param_groups[0]['lr'])
        return [optimizer], [scheduler]
    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)
    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            n_train = int(len(cifar_full)*0.8)
            n_val = len(cifar_full)-n_train
            self.cifar_train, self.cifar_val = torch.utils.data.random_split(cifar_full, [n_train, n_val])
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
    def train_dataloader(self):
        #classes = tuple(np.linspace(0, 9, 10, dtype=np.uint8))
        classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        trainloader = DataLoader(self.cifar_train, batch_size=32)
        # get some random training images
        dataiter = iter(trainloader)
        images, labels = next(dataiter)  # grab one batch to display
        # show images
        imshow(torchvision.utils.make_grid(images))
        # print labels
        print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
        return trainloader
    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=32)
    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=32)

def main():
    '''
    main
    '''
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #for gpu
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)
    pl.seed_everything(0)
    # model
    net = Net()
    model = net.to(device)  #for gpu
    summary(model,(3,32,32))
    early_stopping = EarlyStopping('val_acc')  #('val_loss'
    bar = LitProgressBar(refresh_rate = 10, process_position = 1)
    #lr_monitor = LearningRateMonitor(logging_interval='step')

    trainer = pl.Trainer(gpus=1, max_epochs=20,callbacks=[MyPrintingCallback(),early_stopping, bar]) # progress_bar_refresh_rate=10
    trainer.fit(model)
    results = trainer.test()
    print(results)

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))    
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,584
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 16, 16]               0
        MaxPool2d-14            [-1, 128, 8, 8]               0
           Conv2d-15            [-1, 256, 8, 8]         295,168
      BatchNorm2d-16            [-1, 256, 8, 8]             512
             ReLU-17            [-1, 256, 8, 8]               0
           Conv2d-18            [-1, 256, 8, 8]         590,080
      BatchNorm2d-19            [-1, 256, 8, 8]             512
             ReLU-20            [-1, 256, 8, 8]               0
           Conv2d-21            [-1, 256, 8, 8]         590,080
      BatchNorm2d-22            [-1, 256, 8, 8]             512
             ReLU-23            [-1, 256, 8, 8]               0
        MaxPool2d-24            [-1, 256, 4, 4]               0
           Conv2d-25            [-1, 512, 4, 4]       1,180,160
      BatchNorm2d-26            [-1, 512, 4, 4]           1,024
             ReLU-27            [-1, 512, 4, 4]               0
           Conv2d-28            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-29            [-1, 512, 4, 4]           1,024
             ReLU-30            [-1, 512, 4, 4]               0
           Conv2d-31            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-32            [-1, 512, 4, 4]           1,024
             ReLU-33            [-1, 512, 4, 4]               0
        MaxPool2d-34            [-1, 512, 2, 2]               0
           Conv2d-35            [-1, 512, 2, 2]       2,359,808
      BatchNorm2d-36            [-1, 512, 2, 2]           1,024
             ReLU-37            [-1, 512, 2, 2]               0
           Conv2d-38            [-1, 512, 2, 2]       2,359,808
      BatchNorm2d-39            [-1, 512, 2, 2]           1,024
             ReLU-40            [-1, 512, 2, 2]               0
           Conv2d-41            [-1, 512, 2, 2]       2,359,808
      BatchNorm2d-42            [-1, 512, 2, 2]           1,024
             ReLU-43            [-1, 512, 2, 2]               0
        MaxPool2d-44            [-1, 512, 1, 1]               0
           Linear-45                  [-1, 512]         262,656
             ReLU-46                  [-1, 512]               0
           Linear-47                   [-1, 32]          16,416
             ReLU-48                   [-1, 32]               0
           Linear-49                   [-1, 10]             330
================================================================
Total params: 15,002,538
Trainable params: 15,002,538
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 6.57
Params size (MB): 57.23
Estimated Total Size (MB): 63.82
----------------------------------------------------------------
Ir =  0.00025
[1,   200] loss: 1.946  train_acc: 31.54 % val_acc: 31.29 %
[1,   400] loss: 1.691  train_acc: 36.00 % val_acc: 35.82 %
[1,   600] loss: 1.570  train_acc: 36.58 % val_acc: 36.43 %
...
[20,   800] loss: 0.017  train_acc: 99.66 % val_acc: 83.27 %
[20,  1000] loss: 0.016  train_acc: 99.61 % val_acc: 83.11 %
[20,  1200] loss: 0.014  train_acc: 99.61 % val_acc: 83.11 %
Finished Training
Accuracy: 83.27 %
GroundTruth:    cat  ship  ship plane
Predicted:    cat  ship  ship plane
Accuracy of plane : 84 %
Accuracy of   car : 91 %
Accuracy of  bird : 75 %
Accuracy of   cat : 66 %
Accuracy of  deer : 81 %
Accuracy of   dog : 78 %
Accuracy of  frog : 85 %
Accuracy of horse : 85 %
Accuracy of  ship : 94 %
Accuracy of truck : 92 %
elapsed time: 4545.419 [sec]
        MaxPool2d-44            [-1, 512, 1, 1]               0
           Linear-45                  [-1, 512]         262,656
             ReLU-46                  [-1, 512]               0
          Dropout-47                  [-1, 512]               0
           Linear-48                   [-1, 32]          16,416
             ReLU-49                   [-1, 32]               0
          Dropout-50                   [-1, 32]               0
           Linear-51                   [-1, 10]             330
================================================================
Total params: 15,002,538
Trainable params: 15,002,538
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 6.58
Params size (MB): 57.23
Estimated Total Size (MB): 63.82
----------------------------------------------------------------
Ir =  0.001
Ir =  0.00025
[1,   200] loss: 2.173  Accuracy: 16.72 %
[1,   400] loss: 2.036  Accuracy: 19.80 %
[1,   600] loss: 1.962  Accuracy: 22.53 %
[1,   800] loss: 1.892  Accuracy: 26.34 %
...
[20,   800] loss: 0.418  Accuracy: 75.94 %
[20,  1000] loss: 0.433  Accuracy: 75.89 %
[20,  1200] loss: 0.429  Accuracy: 75.61 %
Finished Training
Accuracy: 76.10 %
GroundTruth:    cat  ship  ship plane
Predicted:    cat  ship  ship  ship
Accuracy of plane : 76 %
Accuracy of   car : 90 %
Accuracy of  bird : 61 %
Accuracy of   cat : 46 %
Accuracy of  deer : 69 %
Accuracy of   dog : 62 %
Accuracy of  frog : 85 %
Accuracy of horse : 82 %
Accuracy of  ship : 91 %
Accuracy of truck : 88 %
elapsed time: 1903.856 [sec]
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 6, 28, 28]             456
         MaxPool2d-2            [-1, 6, 14, 14]               0
            Conv2d-3           [-1, 16, 10, 10]           2,416
         MaxPool2d-4             [-1, 16, 5, 5]               0
            Linear-5                  [-1, 120]          48,120
            Linear-6                   [-1, 84]          10,164
            Linear-7                   [-1, 10]             850
================================================================
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.31
----------------------------------------------------------------
Ir =  0.00025
[1,   200] loss: 2.126  train_acc: 29.71 % val_acc: 29.66 %
[1,   400] loss: 1.889  train_acc: 33.12 % val_acc: 33.43 %
[1,   600] loss: 1.801  train_acc: 35.55 % val_acc: 35.70 %
...
[20,   800] loss: 1.156  train_acc: 59.33 % val_acc: 55.87 %
[20,  1000] loss: 1.133  train_acc: 59.28 % val_acc: 55.71 %
[20,  1200] loss: 1.167  train_acc: 59.23 % val_acc: 55.81 %
Finished Training
Accuracy: 56.42 %
GroundTruth:    cat  ship  ship plane
Predicted:    cat   car   car  ship
Accuracy of plane : 65 %
Accuracy of   car : 74 %
Accuracy of  bird : 40 %
Accuracy of   cat : 36 %
Accuracy of  deer : 36 %
Accuracy of   dog : 45 %
Accuracy of  frog : 69 %
Accuracy of horse : 63 %
Accuracy of  ship : 73 %
Accuracy of truck : 57 %
elapsed time: 1277.975 [sec]
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
            Conv2d-2           [-1, 64, 32, 32]          36,928
         MaxPool2d-3           [-1, 64, 16, 16]               0
            Conv2d-4          [-1, 128, 16, 16]          73,856
            Conv2d-5          [-1, 128, 16, 16]         147,584
         MaxPool2d-6            [-1, 128, 8, 8]               0
            Conv2d-7            [-1, 256, 8, 8]         295,168
            Conv2d-8            [-1, 256, 8, 8]         590,080
            Conv2d-9            [-1, 256, 8, 8]         590,080
           Conv2d-10            [-1, 256, 8, 8]         590,080
        MaxPool2d-11            [-1, 256, 4, 4]               0
           Linear-12                 [-1, 1024]       4,195,328
           Linear-13                 [-1, 1024]       1,049,600
           Linear-14                   [-1, 10]          10,250
================================================================
Total params: 7,580,746
Trainable params: 7,580,746
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 2.23
Params size (MB): 28.92
Estimated Total Size (MB): 31.16
----------------------------------------------------------------
Ir =  0.00025
Accuracy: 79.23 %
GroundTruth:    cat  ship  ship plane
Predicted:    cat  ship  ship plane
Accuracy of plane : 78 %
Accuracy of   car : 90 %
Accuracy of  bird : 72 %
Accuracy of   cat : 64 %
Accuracy of  deer : 72 %
Accuracy of   dog : 69 %
Accuracy of  frog : 82 %
Accuracy of horse : 83 %
Accuracy of  ship : 88 %
Accuracy of truck : 93 %
elapsed time: 1068.527 [sec]
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 256, 28, 28]          19,456
         MaxPool2d-2          [-1, 256, 14, 14]               0
       BatchNorm2d-3          [-1, 256, 14, 14]             512
            Conv2d-4          [-1, 512, 10, 10]       3,277,312
         MaxPool2d-5            [-1, 512, 5, 5]               0
       BatchNorm2d-6            [-1, 512, 5, 5]           1,024
            Conv2d-7           [-1, 1924, 4, 4]       3,942,276
         MaxPool2d-8           [-1, 1924, 2, 2]               0
       BatchNorm2d-9           [-1, 1924, 2, 2]           3,848
           Linear-10                  [-1, 160]       1,231,520
           Linear-11                   [-1, 10]           1,610
================================================================
Total params: 8,477,558
Trainable params: 8,477,558
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 3.24
Params size (MB): 32.34
Estimated Total Size (MB): 35.59
----------------------------------------------------------------
Ir =  0.00025
[1,   200] loss: 1.651  train_acc: 49.06 % val_acc: 47.93 %
[1,   400] loss: 1.375  train_acc: 58.22 % val_acc: 55.22 %
[1,   600] loss: 1.222  train_acc: 62.61 % val_acc: 59.38 %
...
[20,   800] loss: 0.000  train_acc: 100.00 % val_acc: 79.74 %
[20,  1000] loss: 0.000  train_acc: 100.00 % val_acc: 79.77 %
[20,  1200] loss: 0.000  train_acc: 100.00 % val_acc: 79.79 %
Finished Training
Accuracy: 80.05 %
GroundTruth:    cat  ship  ship plane
Predicted:    cat  ship  ship plane
Accuracy of plane : 85 %
Accuracy of   car : 93 %
Accuracy of  bird : 73 %
Accuracy of   cat : 62 %
Accuracy of  deer : 74 %
Accuracy of   dog : 68 %
Accuracy of  frog : 88 %
Accuracy of horse : 86 %
Accuracy of  ship : 89 %
Accuracy of truck : 88 %
elapsed time: 3917.718 [sec]

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 256, 28, 28]          19,456
         MaxPool2d-2          [-1, 256, 14, 14]               0
       BatchNorm2d-3          [-1, 256, 14, 14]             512
            Conv2d-4          [-1, 512, 10, 10]       3,277,312
         MaxPool2d-5            [-1, 512, 5, 5]               0
       BatchNorm2d-6            [-1, 512, 5, 5]           1,024
            Linear-7                  [-1, 160]       2,048,160
            Linear-8                   [-1, 10]           1,610
================================================================
Total params: 5,348,074
Trainable params: 5,348,074
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 2.88
Params size (MB): 20.40
Estimated Total Size (MB): 23.30
----------------------------------------------------------------
Ir =  0.00025
[1,   200] loss: 1.654  train_acc: 51.47 % val_acc: 49.53 %
[1,   400] loss: 1.380  train_acc: 59.47 % val_acc: 55.60 %
[1,   600] loss: 1.223  train_acc: 63.02 % val_acc: 58.57 %
...
[20,   800] loss: 0.001  train_acc: 100.00 % val_acc: 77.33 %
[20,  1000] loss: 0.001  train_acc: 100.00 % val_acc: 77.37 %
[20,  1200] loss: 0.001  train_acc: 100.00 % val_acc: 77.42 %
Finished Training
Accuracy: 77.37 %
GroundTruth:    cat  ship  ship plane
Predicted:    cat  ship  ship plane
Accuracy of plane : 84 %
Accuracy of   car : 89 %
Accuracy of  bird : 65 %
Accuracy of   cat : 51 %
Accuracy of  deer : 71 %
Accuracy of   dog : 65 %
Accuracy of  frog : 85 %
Accuracy of horse : 83 %
Accuracy of  ship : 88 %
Accuracy of truck : 90 %
elapsed time: 3419.362 [sec]