
[Intro to pytorch-lightning] Building MNIST and Cifar10 Autoencoders from Scratch♬

Posted at 2020-12-30

The idea is to redo what I previously did in Keras, so first I'll take on the classic autoencoder.
Let's start with a quick look at pytorch-lightning itself:
Step 1: Add these imports
Step 2: Define a LightningModule (nn.Module subclass)
Step 3: Train!
And without changing a single line of code, you could run on GPUs/TPUs
For advanced users, you can still own complex training loops
I'll begin by running these steps.

[Reference]
Step 1: Add these imports@PyTorchLightning/pytorch-lightning
Stitching these together gives the following code, which ran as-is.

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
# trainer = pl.Trainer()
trainer = pl.Trainer(max_epochs=1, gpus=1)
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))    

# torchscript: export the trained model (re-instantiating LitAutoEncoder() here
# would save untrained weights, so export the fitted model directly)
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
print('training_finished')
>python simple_autoencoder.py
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
C:\Users\user\Anaconda3\lib\site-packages\pytorch_lightning\utilities\distributed.py:49: UserWarning: you passed in a val_dataloader but have no validation_step. Skipping validation loop
  warnings.warn(*args, **kwargs)
2020-12-29 22:52:28.353096: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2020-12-29 22:52:28.353187: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 100 K
1 | decoder | Sequential | 101 K
---------------------------------------
202 K     Trainable params
0         Non-trainable params
202 K     Total params
Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████| 55000/55000 [02:57<00:00, 310.28it/s, loss=0.0406, v_num=98]
training_finished

It warns that validation_step is missing and prints the cudart dlerror, but everything in the code ran: epoch 0 completed, and you can see the loss decreasing.
I'll pass over this code without further commentary and take it as our starting point.

What I did

・Complete the MNIST autoencoder
・Apply it to Cifar10

・Complete the MNIST autoencoder

・Launch from a main() function
・Add validation_step and test_step to class LitAutoEncoder
・Move data loading into class LitAutoEncoder
(・Split the data into train, val, and test)
・Check the initial images
・Display the results
・Use a callback to control the training progress display
・Turn the networks into classes

・Launch from a main() function

This is standard boilerplate, so introduce the following:

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))    

・Add validation_step and test_step to class LitAutoEncoder

Once these are added, validation_step runs automatically at the end of each training epoch, and test_step runs when trainer.test() is called after training (see the sketch after the code below).
Even if you interrupt training partway with Ctrl-C, execution still continues from test_step onward.

class LitAutoEncoder(pl.LightningModule):
...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss) # only this line differs: log the key as 'test_loss'
        return loss
    
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)
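
For reference, here is a minimal sketch of how these hooks are driven, matching the main() shown later: fit() runs training_step and validation_step every epoch, while test_step only runs once trainer.test() is called.

autoencoder = LitAutoEncoder()
trainer = pl.Trainer(max_epochs=10, gpus=1)
trainer.fit(autoencoder)              # training_step + validation_step each epoch
results = trainer.test(autoencoder)   # test_step runs over the test dataloader
print(results)                        # e.g. [{'test_loss': 0.13}]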

・Move data loading into class LitAutoEncoder

This is the same code as in my first Lightning article.

class LitAutoEncoder(pl.LightningModule):
...
    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None): # split into train / val / test
        # Assign train/val datasets for use in dataloaders
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        n_train = int(len(mnist_full)*0.8)
        n_val = len(mnist_full)-n_train
        self.mnist_train, self.mnist_val = torch.utils.data.random_split(mnist_full, [n_train, n_val])
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        self.trainloader = DataLoader(self.mnist_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
        return self.trainloader
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, shuffle=False, batch_size=32, num_workers=0)
    
    def test_dataloader(self):
        self.testloader = DataLoader(self.mnist_test, shuffle=False, batch_size=32, num_workers=0)
        return self.testloader

・Check the initial images

Add the following code:

# functions to show an image
import numpy as np
import matplotlib.pyplot as plt

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.pause(1)
    plt.close()

Then add self.classes = ... to __init__(self, data_dir='./'), and call imshow() inside train_dataloader() to display a batch:

class LitAutoEncoder(pl.LightningModule):

    def __init__(self, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
...
    def train_dataloader(self):
        self.trainloader = DataLoader(self.mnist_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
        # get some random training images
        dataiter = iter(self.trainloader)
        images, labels = next(dataiter)
        # show images
        imshow(torchvision.utils.make_grid(images))
        # print labels
        print(' '.join('%5s' % self.classes[labels[j]] for j in range(4)))
        return self.trainloader

Note that calling imshow() after training also worked, as shown below.
That is why the loader is stored with a self. prefix, as self.trainloader.
(This sets up the results display later.)

trainer.fit(autoencoder)    

dataiter = iter(autoencoder.trainloader)
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

・Display the results

First, save the trained model as in reference ② and load it back.
(You could continue without saving, but I want to run predictions standalone later.)
The script also calls torch.jit.save, though that file is not used this time. What is it for? According to the docs for TORCH.JIT.SAVE:
"Save an offline version of this module for use in a separate process. The saved module serializes all of the methods, submodules, parameters, and attributes of this module. It can be loaded into the C++ API using torch::jit::load(filename) or into the Python API with torch.jit.load."
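
As a hedged illustration (this article does not actually use the file), the "model.pt" saved earlier could be loaded in a separate process, with no LitAutoEncoder class definition required:

import torch

# load the module saved with torch.jit.save; the original class is not needed
scripted = torch.jit.load("model.pt")
scripted.eval()

x = torch.randn(1, 28 * 28)       # dummy flattened MNIST-sized input
with torch.no_grad():
    embedding = scripted(x)       # forward() returns the latent code
print(embedding.shape)            # torch.Size([1, 3]) for the 3-dim latent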

[Reference]
Manual saving@SAVING AND LOADING WEIGHTS

trainer.save_checkpoint("example.ckpt")

PATH = 'example.ckpt'
pretrained_model = autoencoder.load_from_checkpoint(PATH)
pretrained_model.freeze()
pretrained_model.eval()
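
Incidentally, load_from_checkpoint is a classmethod, so it does not need an existing instance; a minimal equivalent sketch:

# equivalent: call load_from_checkpoint on the class itself
pretrained_model = LitAutoEncoder.load_from_checkpoint('example.ckpt')
pretrained_model.freeze()
pretrained_model.eval()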

Let's run the autoencoder with this pretrained_model.
But first, output the original test images, which the model never saw during training.

latent_dim,ver = 3, 1
dataiter = iter(autoencoder.testloader)
images, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images),'original_images_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(8)))

original_images_3_1.png

encode_img = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,28*28))
decode_img = pretrained_model.decoder(encode_img)
imshow(torchvision.utils.make_grid(decode_img.cpu().reshape(32,1,28,28)), 'original_autoencode_preds_mnist_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(8)))

Below is the result after training for just one epoch. The reconstructions are still pretty sloppy.
original_autoencode_preds_mnist_3_1.png

・Use a callback to control the training progress display

When training for more epochs, the progress display from the second epoch onward overwrites the previous line, which I didn't like, so I used a callback to insert a line break at the end of each epoch.

from pytorch_lightning.callbacks import Callback    
class MyPrintingCallback(Callback):
    def on_epoch_end(self, trainer, pl_module):
        print('')
...
def main():    
    autoencoder = LitAutoEncoder()

    #trainer = pl.Trainer()
    trainer = pl.Trainer(max_epochs=10, gpus=1, callbacks=[MyPrintingCallback()])
    trainer.fit(autoencoder)

・Turn the networks into classes

This part is the same as last time.
The difference is that there are two networks, encoder() and decoder(); in other words, you can use as many networks as you need.

from net_encoder_decoder_mnist import Encoder, Decoder

class LitAutoEncoder(pl.LightningModule):
    def __init__(self, data_dir='./'):
...
        self.encoder = Encoder() #nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 32))
        self.decoder = Decoder() #nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

The networks are defined separately as follows:

net_encoder_decoder_mnist.py
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), 
            nn.ReLU(), 
            nn.Linear(128, 32)
        )
    def forward(self, x):
        x = self.encoder(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(), 
            nn.Linear(128, 28 * 28)
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

With this, we've taken the first step with autoencoders.
The results at epoch=10 are shown below.
With latent-space channels = 3:
original_autoencode_preds_mnist_3_10.png
With latent-space channels = 32:
original_autoencode_preds_mnist_32_10.png
And the ground truth (input) is:
original_images_3_10.png
The full code is given in the appendix.

・Apply it to Cifar10

・Switch the data to cifar10
・Support Cifar10's 3-channel input
・Extend the network to a CNN

・Switch the data to cifar10

This follows the appendix listing "Autoencoder for cifar10 in pytorch-lightning; flatten-model".
Adapting the dataset and the input dimensions is all it takes; the essential changes are sketched below.
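
As a minimal sketch of those changes (the full listing is in the appendix): swap the dataset class and normalization, and enlarge the flattened input the Linear layers see from 1*28*28 = 784 to 3*32*32 = 3072.

from torchvision import transforms
from torchvision.datasets import CIFAR10   # was MNIST

# CIFAR10 images are 3x32x32, so the flattened size becomes 3*32*32 = 3072
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 3-channel normalize
])
dataset = CIFAR10('./', train=True, download=True, transform=transform)
print(dataset[0][0].reshape(-1).shape)     # torch.Size([3072])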

The result for epoch=1 with a 32-dimensional latent space is as follows.
Original:
original_images_cifar10_32_1.png
Encoded-and-decoded images:
original_autoencode_preds_cifar10_32_1.png

>python simple_autoencoder3.py
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Files already downloaded and verified
Files already downloaded and verified
2020-12-30 16:44:59.037624: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2020-12-30 16:44:59.037743: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 397 K
1 | decoder | Decoder | 400 K
------------------------------------
797 K     Trainable params
0         Non-trainable params
797 K     Total params
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:12<00:00, 126.63it/s, loss=0.0542, v_num=145]
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:12<00:00, 125.70it/s, loss=0.0542, v_num=145]
training_finished
 bird  ship   car  frog
Files already downloaded and verified
Files already downloaded and verified
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:01<00:00, 185.81it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0541, device='cuda:0')}
--------------------------------------------------------------------------------
[{'test_loss': 0.054066576063632965}]
  cat  ship  ship plane
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
elapsed time: 34.649 [sec]

The network is as follows:

net_encoder_decoder_1D_cifar10.py
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(3*32 * 32, 128), 
            nn.ReLU(), 
            nn.Linear(128, 32)
        )
    def forward(self, x):
        x = self.encoder(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(), 
            nn.Linear(128, 3*32 * 32)
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

・Support Cifar10 (3-channel input) ・Extend the network to a CNN

First, the network changes: the design is up to you, but here is one solution.

import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = 64,
                               kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(256)
        )

    def forward(self, x):
        x = self.encoder(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels = 256, out_channels = 16,
                                          kernel_size = 2, stride = 2, padding = 0),
            nn.ConvTranspose2d(in_channels = 16, out_channels = 3,
                                          kernel_size = 2, stride = 2)
        )

    def forward(self, x):
        x = self.decoder(x)
        return x
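
To see why these shapes line up: each MaxPool2d(2, 2) in the encoder halves the spatial size (32 → 16 → 8), and each stride-2 ConvTranspose2d in the decoder doubles it back (8 → 16 → 32). A quick sanity check, assuming the Encoder and Decoder classes above:

import torch

enc, dec = Encoder(), Decoder()
x = torch.randn(1, 3, 32, 32)   # one CIFAR10-sized image
z = enc(x)
print(z.shape)                  # torch.Size([1, 256, 8, 8]) -- the conv latent
x_hat = dec(z)
print(x_hat.shape)              # torch.Size([1, 3, 32, 32]) -- back to input size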

Now let's print the model structure with the following (summary comes from the torchsummary package):

from torchsummary import summary

summary(autoencoder.encoder, (3, 32, 32))
summary(autoencoder.decoder, (256, 8, 8))
summary(autoencoder, (3, 32, 32))
print(autoencoder)

Running this produces the results below.
Note that print(autoencoder) is the one that gives relatively detailed information.

>python simple_autoencoder4.py
cuda:0
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
              ReLU-2           [-1, 64, 32, 32]               0
         MaxPool2d-3           [-1, 64, 16, 16]               0
       BatchNorm2d-4           [-1, 64, 16, 16]             128
            Conv2d-5          [-1, 256, 16, 16]         147,712
              ReLU-6          [-1, 256, 16, 16]               0
         MaxPool2d-7            [-1, 256, 8, 8]               0
       BatchNorm2d-8            [-1, 256, 8, 8]             512
================================================================
Total params: 150,144
Trainable params: 150,144
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 2.50
Params size (MB): 0.57
Estimated Total Size (MB): 3.08
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ConvTranspose2d-1           [-1, 16, 16, 16]          16,400
   ConvTranspose2d-2            [-1, 3, 32, 32]             195
================================================================
Total params: 16,595
Trainable params: 16,595
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.06
Forward/backward pass size (MB): 0.05
Params size (MB): 0.06
Estimated Total Size (MB): 0.18
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
              ReLU-2           [-1, 64, 32, 32]               0
         MaxPool2d-3           [-1, 64, 16, 16]               0
       BatchNorm2d-4           [-1, 64, 16, 16]             128
            Conv2d-5          [-1, 256, 16, 16]         147,712
              ReLU-6          [-1, 256, 16, 16]               0
         MaxPool2d-7            [-1, 256, 8, 8]               0
       BatchNorm2d-8            [-1, 256, 8, 8]             512
           Encoder-9            [-1, 256, 8, 8]               0
  ConvTranspose2d-10           [-1, 16, 16, 16]          16,400
  ConvTranspose2d-11            [-1, 3, 32, 32]             195
          Decoder-12            [-1, 3, 32, 32]               0
================================================================
Total params: 166,739
Trainable params: 166,739
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 2.70
Params size (MB): 0.64
Estimated Total Size (MB): 3.35
----------------------------------------------------------------
LitAutoEncoder(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ReLU()
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (decoder): Decoder(
    (decoder): Sequential(
      (0): ConvTranspose2d(256, 16, kernel_size=(2, 2), stride=(2, 2))
      (1): ConvTranspose2d(16, 3, kernel_size=(2, 2), stride=(2, 2))
    )
  )
)

And the autoencoder results in this case are as follows:
original_images_cifar10_16384_1.png
original_autoencode_preds_cifar10_16384_1.png
A single epoch already yields fairly clean images.
And after 10 epochs the result is as follows.
The reconstruction quality is now quite good.
original_autoencode_preds_cifar10_16384_10.png
Incidentally, the whole run from training to output finished in a little over three minutes.

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 150 K
1 | decoder | Decoder | 16.6 K
------------------------------------
166 K     Trainable params
0         Non-trainable params
166 K     Total params
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 94.14it/s, loss=0.0133, v_num=150]
Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 94.88it/s, loss=0.0102, v_num=150]
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 95.17it/s, loss=0.00861, v_num=150]
Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 95.22it/s, loss=0.0105, v_num=150]
Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 95.17it/s, loss=0.00785, v_num=150]
Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 94.58it/s, loss=0.00782, v_num=150]
Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 94.79it/s, loss=0.0073, v_num=150]
Epoch 7: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 94.88it/s, loss=0.00723, v_num=150]
Epoch 8: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 94.70it/s, loss=0.00721, v_num=150]
Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 94.77it/s, loss=0.00689, v_num=150]
Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████| 1563/1563 [00:16<00:00, 94.25it/s, loss=0.00689, v_num=150]
training_finished
 bird   dog  frog horse
Files already downloaded and verified
Files already downloaded and verified
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:01<00:00, 174.33it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0043, device='cuda:0')}
--------------------------------------------------------------------------------
[{'test_loss': 0.004303526598960161}]
  cat  ship  ship plane
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
elapsed time: 189.072 [sec]

In fact, if you look at the code below, you can see that forward(self, x) differs from before.
I changed it so that summary() can print the whole network.
Note that the previous version also trains fine and, as we have seen, can still compute preds.

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        x = self.decoder(embedding)
        return x

The latent space could also be flattened to one dimension, but thinking about when that would actually be needed, I suspect a plain autoencoder is fine as-is.
(It can be converted to suit the use case; a sketch follows.)
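
If a 1D latent vector were ever needed, one hedged sketch (my own variation, not code from this article; latent_dim is a hypothetical choice) is to append nn.Flatten plus a Linear layer to the encoder and mirror them in front of the decoder:

import torch
import torch.nn as nn

latent_dim = 128  # hypothetical latent size

# flatten the 256x8x8 conv features into a 1D latent vector...
to_vec = nn.Sequential(nn.Flatten(), nn.Linear(256 * 8 * 8, latent_dim))
# ...and map it back to a 256x8x8 feature map before the ConvTranspose2d stack
to_map = nn.Sequential(nn.Linear(latent_dim, 256 * 8 * 8),
                       nn.Unflatten(1, (256, 8, 8)))

z = to_vec(torch.randn(4, 256, 8, 8))
print(z.shape)           # torch.Size([4, 128])
print(to_map(z).shape)   # torch.Size([4, 256, 8, 8])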

Summary

・Built a Cifar10 autoencoder with pytorch-lightning
・Obtained good quality, much as in the Keras version
・The base code closely resembles the previous classification code, so it feels highly reusable

・Using this as a base, I plan to try denoising, coloring, cGAN, and more

Appendix

Autoencoder for mnist in pytorch-lightning

import os
import time
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from net_encoder_decoder_mnist import Encoder, Decoder

# functions to show an image
def imshow(img,file='', text_=''):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.detach().numpy() #img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.text(x = 3, y = 2, s = text_, c = "red")
    #plt.imshow(npimg)
    plt.pause(3)
    if file != '':
        plt.savefig(file+'.png')
    plt.close()

from pytorch_lightning.callbacks import Callback    
class MyPrintingCallback(Callback):
    def on_epoch_end(self, trainer, pl_module):
        print('')

    
class LitAutoEncoder(pl.LightningModule):

    def __init__(self, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))])
        
        self.encoder = Encoder() #nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 32))
        self.decoder = Decoder() #nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss)
        return loss
    
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
 
    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None): # split into train / val / test
        # Assign train/val datasets for use in dataloaders
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        n_train = int(len(mnist_full)*0.8)
        n_val = len(mnist_full)-n_train
        self.mnist_train, self.mnist_val = torch.utils.data.random_split(mnist_full, [n_train, n_val])
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        self.trainloader = DataLoader(self.mnist_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
        # get some random training images
        return self.trainloader
    
    def val_dataloader(self):
        return DataLoader(self.mnist_val, shuffle=False, batch_size=32, num_workers=0)
    
    def test_dataloader(self):
        self.testloader = DataLoader(self.mnist_test, shuffle=False, batch_size=32, num_workers=0)
        return self.testloader

    
def main():    
    autoencoder = LitAutoEncoder()

    #trainer = pl.Trainer()
    trainer = pl.Trainer(max_epochs=10, gpus=1, callbacks=[MyPrintingCallback()])
    trainer.fit(autoencoder) #, DataLoader(train), DataLoader(val))    
    print('training_finished')

    dataiter = iter(autoencoder.trainloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images),'mnist_initial',text_='original')
    # print labels
    print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    results = trainer.test(autoencoder)
    print(results)

    dataiter = iter(autoencoder.testloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images), 'mnist_results',text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # print labels
    print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

    # torchscript
    torch.jit.save(autoencoder.to_torchscript(), "model.pt")
    trainer.save_checkpoint("example.ckpt")

    PATH = 'example.ckpt'
    pretrained_model = autoencoder.load_from_checkpoint(PATH)
    pretrained_model.freeze()
    pretrained_model.eval()

    latent_dim,ver = 32, 10
    dataiter = iter(autoencoder.testloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images),'original_images_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(8)))

    encode_img = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,28*28))
    decode_img = pretrained_model.decoder(encode_img)
    imshow(torchvision.utils.make_grid(decode_img.cpu().reshape(32,1,28,28)), 'original_autoencode_preds_mnist_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(8)))

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))    

Execution results

>python simple_autoencoder2.py
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2020-12-30 15:54:08.278635: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2020-12-30 15:54:08.278724: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 104 K
1 | decoder | Decoder | 105 K
------------------------------------
209 K     Trainable params
0         Non-trainable params
209 K     Total params
Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 142.55it/s, loss=0.185, v_num=142]
Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 143.58it/s, loss=0.178, v_num=142]
Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 143.63it/s, loss=0.161, v_num=142]
Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 143.62it/s, loss=0.156, v_num=142]
Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 142.94it/s, loss=0.151, v_num=142]
Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 143.26it/s, loss=0.142, v_num=142]
Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 143.26it/s, loss=0.136, v_num=142]
Epoch 7: 100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 143.29it/s, loss=0.14, v_num=142]
Epoch 8: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 143.57it/s, loss=0.138, v_num=142]
Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 143.63it/s, loss=0.132, v_num=142]
Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:13<00:00, 142.42it/s, loss=0.132, v_num=142]
training_finished
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
    6     1     1     5
Testing: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:01<00:00, 208.94it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.1300, device='cuda:0')}
--------------------------------------------------------------------------------
[{'test_loss': 0.13000936806201935}]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
    7     2     1     0
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
elapsed time: 149.029 [sec]

Autoencoder for cifar10 in pytorch-lightning; flatten-model

simple_autoencoder3.py
import os
import time
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import torchvision
from torchvision.datasets import CIFAR10 #MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from net_encoder_decoder_1D_cifar10 import Encoder, Decoder
# from net_encoder_decoder_mnist import Encoder, Decoder

# functions to show an image
def imshow(img,file='', text_=''):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.detach().numpy() #img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.text(x = 3, y = 2, s = text_, c = "red")
    #plt.imshow(npimg)
    plt.pause(3)
    if file != '':
        plt.savefig(file+'.png')
    plt.close()

from pytorch_lightning.callbacks import Callback    
class MyPrintingCallback(Callback):
    def on_epoch_end(self, trainer, pl_module):
        print('')

    
class LitAutoEncoder(pl.LightningModule):

    def __init__(self, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        #self.classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
        self.dims = (3, 32, 32) #(1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        #self.transform=transforms.Compose([transforms.ToTensor(),
        #                      transforms.Normalize((0.1307,), (0.3081,))])
        
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss)
        return loss
    
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
 
    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None): # split into train / val / test
        # Assign train/val datasets for use in dataloaders
        cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
        n_train = int(len(cifar10_full)*0.8)
        n_val = len(cifar10_full)-n_train
        self.cifar10_train, self.cifar10_val = torch.utils.data.random_split(cifar10_full, [n_train, n_val])
        self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        self.trainloader = DataLoader(self.cifar10_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
        # get some random training images
        return self.trainloader
    
    def val_dataloader(self):
        return DataLoader(self.cifar10_val, shuffle=False, batch_size=32, num_workers=0)
    
    def test_dataloader(self):
        self.testloader = DataLoader(self.cifar10_test, shuffle=False, batch_size=32, num_workers=0)
        return self.testloader

    
def main():    
    autoencoder = LitAutoEncoder()

    #trainer = pl.Trainer()
    trainer = pl.Trainer(max_epochs=1, gpus=1, callbacks=[MyPrintingCallback()])
    trainer.fit(autoencoder) #, DataLoader(train), DataLoader(val))    
    print('training_finished')

    dataiter = iter(autoencoder.trainloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images),'cifar10_initial',text_='original')
    # print labels
    print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    results = trainer.test(autoencoder)
    print(results)

    dataiter = iter(autoencoder.testloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images), 'cifar10_results',text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # print labels
    print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

    # torchscript
    torch.jit.save(autoencoder.to_torchscript(), "model_cifar10.pt")
    trainer.save_checkpoint("example_cifar10.ckpt")

    PATH = 'example_cifar10.ckpt'
    pretrained_model = autoencoder.load_from_checkpoint(PATH)
    pretrained_model.freeze()
    pretrained_model.eval()

    latent_dim,ver = 32, 1
    dataiter = iter(autoencoder.testloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images),'original_images_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(8)))

    encode_img = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,3*32*32))
    decode_img = pretrained_model.decoder(encode_img)
    imshow(torchvision.utils.make_grid(decode_img.cpu().reshape(32,3,32,32)), 'original_autoencode_preds_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(8)))

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))    

The code above corresponds to the following model:

net_encoder_decoder_1D_cifar10.py
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(3*32 * 32, 128), 
            nn.ReLU(), 
            nn.Linear(128, 32)
        )
    def forward(self, x):
        x = self.encoder(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(), 
            nn.Linear(128, 3*32 * 32)
        )

    def forward(self, x):
        x = self.decoder(x)
        return x

Autoencoder for cifar10 in pytorch-lightning; 2D-model

import os
import time
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import torchvision
from torchvision.datasets import CIFAR10 #MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torchsummary import summary

from net_encoder_decoder2D import Encoder, Decoder
# from net_encoder_decoder_1D_cifar10 import Encoder, Decoder
# from net_encoder_decoder_mnist import Encoder, Decoder

# functions to show an image
def imshow(img,file='', text_=''):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.detach().numpy() #img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.text(x = 3, y = 2, s = text_, c = "red")
    #plt.imshow(npimg)
    plt.pause(3)
    if file != '':
        plt.savefig(file+'.png')
    plt.close()

from pytorch_lightning.callbacks import Callback    
class MyPrintingCallback(Callback):
    def on_epoch_end(self, trainer, pl_module):
        print('')

    
class LitAutoEncoder(pl.LightningModule):

    def __init__(self, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        #self.classes = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')
        self.dims = (3, 32, 32) #(1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        #self.transform=transforms.Compose([transforms.ToTensor(),
        #                      transforms.Normalize((0.1307,), (0.3081,))])
        
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        x = self.decoder(embedding)
        return x

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop. It is independent of forward
        x, y = batch
        #x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        #x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss)
        return loss
    
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
 
    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None): # split into train / val / test
        # Assign train/val datasets for use in dataloaders
        cifar10_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
        n_train = int(len(cifar10_full)*0.8)
        n_val = len(cifar10_full)-n_train
        self.cifar10_train, self.cifar10_val = torch.utils.data.random_split(cifar10_full, [n_train, n_val])
        self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
    
    def train_dataloader(self):
        self.trainloader = DataLoader(self.cifar10_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
        # get some random training images
        return self.trainloader
    
    def val_dataloader(self):
        return DataLoader(self.cifar10_val, shuffle=False, batch_size=32, num_workers=0)
    
    def test_dataloader(self):
        self.testloader = DataLoader(self.cifar10_test, shuffle=False, batch_size=32, num_workers=0)
        return self.testloader

    
def main():    
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #for gpu
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)
    pl.seed_everything(0)
    # model
    autoencoder = LitAutoEncoder()
    autoencoder = autoencoder.to(device)  #for gpu
    summary(autoencoder.encoder,(3,32,32))
    summary(autoencoder.decoder,(256,8,8))
    summary(autoencoder,(3,32,32))
    print(autoencoder)

    #trainer = pl.Trainer()
    trainer = pl.Trainer(max_epochs=10, gpus=1, callbacks=[MyPrintingCallback()])
    
    
    trainer.fit(autoencoder) #, DataLoader(train), DataLoader(val))    
    print('training_finished')

    dataiter = iter(autoencoder.trainloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images),'cifar10_initial',text_='original')
    # print labels
    print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    results = trainer.test(autoencoder)
    print(results)

    dataiter = iter(autoencoder.testloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images), 'cifar10_results',text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # print labels
    print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

    # torchscript
    # torch.jit.save(autoencoder.to_torchscript(), "model_cifar10.pt")
    trainer.save_checkpoint("example_cifar10.ckpt")

    PATH = 'example_cifar10.ckpt'
    pretrained_model = autoencoder.load_from_checkpoint(PATH)
    pretrained_model.freeze()
    pretrained_model.eval()

    latent_dim,ver = 16384, 10
    dataiter = iter(autoencoder.testloader)
    images, labels = next(dataiter)
    # show images
    imshow(torchvision.utils.make_grid(images),'original_images_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(8)))

    encode_img = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,3,32,32))
    decode_img = pretrained_model.decoder(encode_img)
    imshow(torchvision.utils.make_grid(decode_img.cpu().reshape(32,3,32,32)), 'original_autoencode_preds_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(8)))

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))    