3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【pytorch-lightning入門】自前datasetで~Denoising, Coloring, Normalization, そして拡大カラー画像生成で遊んでみた♬

Last updated at Posted at 2021-01-24

前回の続きとして、Cifar10だけど、自前datasetにしていろいろ処理を施しつつpytorch-lightningでDataloaderに流し込む手法で利用してみた。
ありそうでここまでも余り公開されているものは少ないと思うので、精度などまだまだ改善の余地があるがまとめておこうと思う。

現在の段階で以下のような結果となっている。これは、以前参考にあるように同じような手法でKerasで実施したものと同様な結果である。
コードの見やすさは、格段に分かり易くなったと思う。、
【参考】
【画像生成】AutoencoderでDenoising, Coloring, そして拡大カラー画像生成♬

1 2
オリジナル; out_data original_images0_cifar10_Gray2ClolarizationResize1000_100.png
インプット; ノイズ有gray; out_data1 original_images_cifar10_Gray2ClolarizationResize1000_100.png
生成画像; preds autoencode_preds_cifar10_Gray2ClolarizationResize1000_100.png
比較用正規化画像    (target画像); out_data2 normalized_images1_cifar10_Gray2ClolarizationResize1000_100.png

やったこと

・pytorch-lightningでdatasetを利用する(省略)
・pytorch-lightningの最終的なコード解説

・pytorch-lightningでdatasetを利用する(省略)

今回は、段階的開発を実施したので、開発上の各段階のプロセスを記述し、その都度のコードは提示しないこととする。
①datasetを利用して、Dataloaderはmain()の中で定義してCifar10画像を出力してみる
⇒ここは、前回の記事参照
②①を利用して、pytorch-lightningのデータ部分をmain()の中で定義してpytorch-lightningとしてautoencoderを動かす
⇒ここは、【pytorch-lightning入門】ゼロから作るMNIST及びCifar10のAutoencoder♬に①を導入したので参照のこと
③datasetを利用して、通常のDataloaderの定義にしたがって使えるようにする
⇒ここは、④の解説と合わせて以下のコード解説で解説する
④入力と比較用のdatasetとするためtransformsなどを微調整
ここでは、この④のプロセスで完成したdatasetの解説をしたいと思う。

④入力と比較用のdatasetとするためtransformsなどを微調整

最終的に以下のようなdatasetクラスを作成した。
大切なことをまとめると以下のようになります。
return out_data, out_data1, out_data2, out_labelとあるように返却値は、

  1. out_data;元のデータ、
  2. out_data1;transform1で処理したデータ、
  3. out_data2;transform2で処理したデータ
  4. 元データのlabel
    を返しています。
    つまり、ここでは3種類の画像データを生成していますが、個数や処理も自由に設定して、何個でも(複数ソースなど種類が異なっても)処理して返却(利用)できるということです。
    ・返却値は、全てself.ts2 = transform=transforms.ToTensor()で処理されたものです。
     そのために、読み出し当初にその処理を実施しています。
    ・self.dataは、通常と同じく、images, labelのデータを含んでいます。labelは前回同様に変換しただけとしていますが今回は利用していないので利用するときに変更が発生する可能性はあります。
    ・imagesのデータを最終的にどのような形式で返すべきかが大切で、以下のシークエンスで処理しています。
    ・まず、imageをToTensor()で読み込んで一度out_dataに格納し、それを処理してout_data, out_data1, out_data2などを出力することとしました。つまり、out_data等も最終的にはToTensor()形式で出力されています。これは、前回数値データでは、self.data = torch.from_numpy(np.array(x)).float()で出力したことに対応しています。
    ・downloadは、通常と同じようにtrain=TrueとFalseの二種類を格納しています。二本必要かどうかは未確認ですが、一本だけでも学習もテストも動くのでtrain=Falseの方は不要なようです
Cifar10を処理後提供するためのDatasetのコード
class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, data_num,train_=True, transform1 = None, transform2 = None,train = True):
                
        self.transform1 = transform1
        self.transform2 = transform2
        #self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transform=transforms.ToTensor()
        self.train = train_
        
        self.data_dir = './'
        self.data_num = data_num
        self.data = []
        self.label = []

        # download
        CIFAR10(self.data_dir, train=True, download=True)
        #CIFAR10(self.data_dir, train=False, download=True)
        self.data =CIFAR10(self.data_dir, train=self.train, transform=self.ts2)

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        out_data = self.data[idx][0]
        out_label_ =  self.data[idx][1]
        out_label = torch.from_numpy(np.array(out_label_)).long()
        if self.transform1:
            out_data1 = self.transform1(out_data)
        if self.transform2:
            out_data2 = self.transform2(out_data)
        return out_data, out_data1, out_data2, out_label
### ・pytorch-lightningの最終的なコード解説 ここまで来ると、pytorch-lightningのコードの汎用性を享受できます。すなわち、特別なことをしなければ、以下のように前回のコードと酷似したコードで動かせました。 以下のとおり、利用するLibも一緒です。 dataset以外に画像出力とガウスノイズを重畳するためのMyAddGaussianNoise(object)を定義しています。
Cifar10を処理後提供するためのDatasetのコード
import os
import time
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import torchvision
from torchvision.datasets import CIFAR10 #MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torchsummary import summary

from net_encoder_decoder1D2DResize import Encoder, Decoder

def imshow(img,file='', text_=''):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.detach().numpy() #img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.text(x = 3, y = 2, s = text_, c = "red")
    plt.pause(3)
    if file != '':
        plt.savefig(file+'.png')
    plt.close()

from pytorch_lightning.callbacks import Callback    
class MyPrintingCallback(Callback):
    def on_epoch_end(self, trainer, pl_module):
        print('')

class MyAddGaussianNoise(object):
    def __init__(self, mean=0., std=0.1):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)  
    

class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, data_num,train_=True, transform1 = None, transform2 = None,train = True):
                
        self.transform1 = transform1
        self.transform2 = transform2
        self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transform=transforms.ToTensor()
        self.train = train_
        
        self.data_dir = './'
        self.data_num = data_num
        self.data = []
        self.label = []

        # download
        CIFAR10(self.data_dir, train=True, download=True)
        #CIFAR10(self.data_dir, train=False, download=True)
        self.data =CIFAR10(self.data_dir, train=self.train, transform=self.ts2)

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        out_data = self.data[idx][0]
        out_label =  np.array(self.data[idx][1])
        if self.transform1:
            out_data1 = self.transform1(out_data)
        if self.transform2:
            out_data2 = self.transform2(out_data)
        return out_data, out_data1, out_data2, out_label
ここで注意すべきなのは、`training_step(self, batch, batch_idx):`と`validation_step(self, batch, batch_idx):`で定義されている ` _, x, x_ , y = batch`がdatasetの出力と連携しているところです。その後、出力変数と該当するものの F.mse_loss()を計算しています。 self.trans1とself.trans2で、どういう画像処理をするかを決めています。 ※記述順は特に意味はありません
上記のdatasetを利用したautoencoderのコード
class LitAutoEncoder(pl.LightningModule):

    def __init__(self, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        self.data_num =50000 #50000
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        self.dims = (32*2, 32*2) 
        self.mean, self.std =[0.5,0.5,0.5], [0.25,0.25,0.25]
        self.trans2 = torchvision.transforms.Compose([
            torchvision.transforms.Normalize(self.mean, self.std),
            torchvision.transforms.Resize(self.dims)
        ])
        self.trans1 =  torchvision.transforms.Compose([
            torchvision.transforms.Normalize(self.mean, self.std),
            MyAddGaussianNoise(0., 0.5),
            torchvision.transforms.Grayscale()
        ])
        
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop. It is independent of forward
        _, x, x_ , y = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x_)
        self.log('train_loss', loss, prog_bar = True)
        return loss

    def validation_step(self, batch, batch_idx):
        _, x, x_, y = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x_)
        self.log('test_loss', loss, prog_bar = True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 
        return optimizer
    
    def setup(self, stage=None): #train, val, testデータ分割
        # Assign train/val datasets for use in dataloaders
        cifar10_full =ImageDataset(self.data_num, train=True, transform1=self.trans1, transform2=self.trans2)
        n_train = int(len(cifar10_full)*0.8)
        n_val = int(len(cifar10_full)*0.1)
        n_test = len(cifar10_full)-n_train -n_val
        
        self.cifar10_train, self.cifar10_val, self.cifar10_test = torch.utils.data.random_split(cifar10_full, [n_train, n_val, n_test])
        
    
    def train_dataloader(self):
        self.trainloader = DataLoader(self.cifar10_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
        return self.trainloader
    
    def val_dataloader(self):
        self.valloader = DataLoader(self.cifar10_val, shuffle=False, batch_size=32, num_workers=0)
        return self.valloader
    
    def test_dataloader(self):
        self.testloader = DataLoader(self.cifar10_test, shuffle=False, batch_size=32, num_workers=0)
        return self.testloader
main()の処理順は、 ・学習準備 ・学習 ・テスト ・学習済checkpointの保存 ・結果確認のための初期画像出力  (上記はvalloaderの結果を表示していますがtestloaderでも同様な結果が得られています) ・学習済モデル読込, freeze(), eval() ・最後に、生成画像などを出力しています
上記を動かすためのmain()コード
def main():    
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #for gpu
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)
    pl.seed_everything(0)

    # model
    autoencoder = LitAutoEncoder()
    autoencoder = autoencoder.to(device) #for gpu
    print(autoencoder)
    summary(autoencoder.encoder,(1,32,32))
    summary(autoencoder,(1,32,32))
    
    trainer = pl.Trainer(max_epochs=10, gpus=1, callbacks=[MyPrintingCallback()]) ####epoch
    
    trainer.fit(autoencoder)    
    print('training_finished')
    
    results = trainer.test(autoencoder)
    print(results)

    dataiter = iter(autoencoder.valloader) #autoencoder.testloader
    _,images, _, labels = dataiter.next()
    # show images
    imshow(torchvision.utils.make_grid(images.reshape(32,1,32,32)), 'cifar10_results',text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # print labels
    print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

    # torchscript
    #torch.jit.save(autoencoder.to_torchscript(), "model_cifar10.pt")
    trainer.save_checkpoint("example_cifar10.ckpt")

    PATH = 'example_cifar10.ckpt'
    pretrained_model = autoencoder.load_from_checkpoint(PATH)
    pretrained_model.freeze()
    pretrained_model.eval()

    latent_dim,ver = "Gray2ClolarizationResize1000", "10"  #####save condition
    dataiter = iter(autoencoder.valloader)  #autoencoder.testloader
    images0,images, images1, labels = dataiter.next()
    # show images
    imshow(torchvision.utils.make_grid(images.reshape(32,1,32,32)),'original_images_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # show images0
    imshow(torchvision.utils.make_grid(images0.reshape(32,3,32,32)),'original_images0_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # show images1
    imshow(torchvision.utils.make_grid(images1.reshape(32,3,32*2,32*2)),'normalized_images1_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))     

    encode_img = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,1,32,32))
    decode_img = pretrained_model.decoder(encode_img)
    imshow(torchvision.utils.make_grid(decode_img.cpu().reshape(32,3,32*2,32*2)), 'autoencode_preds_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))    
今回は、pytorch-lightningの動きを確認する目的で実施しているので、基本的に比較的軽いnetwork(前回のcifar10のautoencoderとほぼ同様)としています。 前回の参考のautoencoderと異なるのは、gray画像を入力とするため、入力部分を以下としています。 `nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 3, padding = 1),` また、decoder()も、拡大画像を出力するために、一段増やしています。 【参考】 ・[【pytorch-lightning入門】ゼロから作るMNIST及びCifar10のAutoencoder♬](https://qiita.com/MuAuan/items/a062d0c245c8f4836399)
上記のGray⇒Resizeで利用するencoder-decoderのコード
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = 64,
                               kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(256)
        )

    def forward(self, x):
        x = self.encoder(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels = 256, out_channels = 64,
                                          kernel_size = 2, stride = 2, padding = 0),
            nn.ConvTranspose2d(in_channels = 64, out_channels = 16,
                                          kernel_size = 2, stride = 2),
            nn.ConvTranspose2d(in_channels = 16, out_channels = 3,
                                          kernel_size = 2, stride = 2)
        )

    def forward(self, x):
        x = self.decoder(x)
        return x
実行結果は、学習データを減らしてテストデータを増やして対応するかを実験した例です。通常のように学習データを増やして、...など試してみましたが、全体にこのままでは前回結果や上記結果を大きく改善はできないようです。
実行結果の例(学習データを減らしています)
>python autoencoder_colorizationResize_dataset.py
cuda:0
LitAutoEncoder(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ReLU()
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (decoder): Decoder(
    (decoder): Sequential(
      (0): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))
      (1): ConvTranspose2d(64, 16, kernel_size=(2, 2), stride=(2, 2))
      (2): ConvTranspose2d(16, 3, kernel_size=(2, 2), stride=(2, 2))
    )
  )
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]             640
              ReLU-2           [-1, 64, 32, 32]               0
         MaxPool2d-3           [-1, 64, 16, 16]               0
       BatchNorm2d-4           [-1, 64, 16, 16]             128
            Conv2d-5          [-1, 256, 16, 16]         147,712
              ReLU-6          [-1, 256, 16, 16]               0
         MaxPool2d-7            [-1, 256, 8, 8]               0
       BatchNorm2d-8            [-1, 256, 8, 8]             512
================================================================
Total params: 148,992
Trainable params: 148,992
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 2.50
Params size (MB): 0.57
Estimated Total Size (MB): 3.07
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]             640
              ReLU-2           [-1, 64, 32, 32]               0
         MaxPool2d-3           [-1, 64, 16, 16]               0
       BatchNorm2d-4           [-1, 64, 16, 16]             128
            Conv2d-5          [-1, 256, 16, 16]         147,712
              ReLU-6          [-1, 256, 16, 16]               0
         MaxPool2d-7            [-1, 256, 8, 8]               0
       BatchNorm2d-8            [-1, 256, 8, 8]             512
           Encoder-9            [-1, 256, 8, 8]               0
  ConvTranspose2d-10           [-1, 64, 16, 16]          65,600
  ConvTranspose2d-11           [-1, 16, 32, 32]           4,112
  ConvTranspose2d-12            [-1, 3, 64, 64]             195
          Decoder-13            [-1, 3, 64, 64]               0
================================================================
Total params: 218,899
Trainable params: 218,899
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 3.06
Params size (MB): 0.84
Estimated Total Size (MB): 3.90
----------------------------------------------------------------
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Files already downloaded and verified
2021-01-23 19:00:38.267639: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2021-01-23 19:00:38.267755: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

  | Name      | Type     | Params
---------------------------------------
0 | encoder   | Encoder  | 148 K
1 | decoder   | Decoder  | 69.9 K
---------------------------------------
218 K     Trainable params
0         Non-trainable params
218 K     Total params
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 46.22it/s, loss=0.301, v_num=62, test_loss=0.229, train_loss=0.284]
Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.88it/s, loss=0.213, v_num=62, test_loss=0.171, train_loss=0.201]
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 49.16it/s, loss=0.176, v_num=62, test_loss=0.153, train_loss=0.212]
Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 47.85it/s, loss=0.164, v_num=62, test_loss=0.155, train_loss=0.139]
Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.13it/s, loss=0.166, v_num=62, test_loss=0.142, train_loss=0.15]
Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.43it/s, loss=0.164, v_num=62, test_loss=0.15, train_loss=0.211]
Epoch 6: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.61it/s, loss=0.163, v_num=62, test_loss=0.142, train_loss=0.135]
Epoch 7: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 47.80it/s, loss=0.149, v_num=62, test_loss=0.138, train_loss=0.141]
Epoch 8: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.01it/s, loss=0.152, v_num=62, test_loss=0.152, train_loss=0.132]
Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.83it/s, loss=0.156, v_num=62, test_loss=0.134, train_loss=0.182]
Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.77it/s, loss=0.156, v_num=62, test_loss=0.134, train_loss=0.182]
training_finished
Files already downloaded and verified
Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1407/1407 [00:26<00:00, 53.37it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.1358, device='cuda:0')}
--------------------------------------------------------------------------------
[{'test_loss': 0.13576959073543549}]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
truck  ship   cat truck
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
elapsed time: 81.340 [sec]
### まとめ ・自前datasetを利用して、Denoising, Cloloring, Normalization, 拡大カラー画像生成で遊んでみた ・pytorch-lightningでも実行できるところまで到達し、前回のKerasでの結果と同様な結果を得た ・今回は、拡大カラー画像生成について主に記述しましたが、おまけに掲載したように、DenoisingやNormalization及び拡大カラー画像(Grayからではなくカラー画像入力)は精度の良い結果を得ているので試してください

・拡大カラー画像生成をさらに工夫したいと思う

おまけ

今回は、Networkは貧弱ですが、また事前にNormalizationしてしまっていますが、Denoisingと拡大画像生成には成功していることが分かります。
※なお、NormalizationはDLでやるより、今回のようにtransformsの関数でやるのが筋です

1 2
オリジナル; out_data original_images0_cifar10_ClolarizationResize1000_100.png
インプット; ノイズ有; out_data1 original_images_cifar10_ClolarizationResize1000_100.png
生成画像; preds autoencode_preds_cifar10_ClolarizationResize1000_100.png
比較用正規化画像    (target画像); out_data2 normalized_images1_cifar10_ClolarizationResize1000_100.png
Denoising, Colaring, Normaliztionのコード全体
import os
import time
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

import torchvision
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torchsummary import summary

from net_encoder_decoder1D2DResize import Encoder, Decoder

def imshow(img,file='', text_=''):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.detach().numpy() #img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.text(x = 3, y = 2, s = text_, c = "red")
    plt.pause(3)
    if file != '':
        plt.savefig(file+'.png')
    plt.close()

from pytorch_lightning.callbacks import Callback    
class MyPrintingCallback(Callback):
    def on_epoch_end(self, trainer, pl_module):
        print('')

class MyAddGaussianNoise(object):
    def __init__(self, mean=0., std=0.1):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)  
    

class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, data_num,train_=True, transform1 = None, transform2 = None,train = True):
                
        self.transform1 = transform1
        self.transform2 = transform2
        self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transform=transforms.ToTensor()
        self.train = train_
        
        self.data_dir = './'
        self.data_num = data_num
        self.data = []
        self.label = []

        # download
        CIFAR10(self.data_dir, train=True, download=True)
        #CIFAR10(self.data_dir, train=False, download=True)
        self.data =CIFAR10(self.data_dir, train=self.train, transform=self.ts2)

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        out_data = self.data[idx][0]
        out_label_ =  self.data[idx][1]
        out_label = torch.from_numpy(np.array(out_label_)).long()
        
        if self.transform1:
            out_data1 = self.transform1(out_data)
        if self.transform2:
            out_data2 = self.transform2(out_data)
        return out_data, out_data1, out_data2, out_label
    
class LitAutoEncoder(pl.LightningModule):

    def __init__(self, data_dir='./'):
        super().__init__()
        self.data_dir = data_dir
        self.data_num =50000 #50000
        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        self.dims = (32*2, 32*2) 
        self.mean, self.std =[0.5,0.5,0.5], [0.25,0.25,0.25]
        self.trans2 = torchvision.transforms.Compose([
            torchvision.transforms.Normalize(self.mean, self.std),
            torchvision.transforms.Resize(self.dims)
        ])
        self.trans1 =  torchvision.transforms.Compose([
            torchvision.transforms.Normalize(self.mean, self.std),
            MyAddGaussianNoise(0., 0.5),
            #torchvision.transforms.Grayscale()
        ])
        
        self.encoder = Encoder()
        self.decoder = Decoder()
        
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop. It is independent of forward
        _,x,x_ , y = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x_)
        self.log('train_loss', loss, prog_bar = True)
        return loss

    def validation_step(self, batch, batch_idx):
        _,x, x_, y = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x_)
        self.log('test_loss', loss, prog_bar = True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 
        return optimizer
    
    def setup(self, stage=None): #train, val, testデータ分割
        # Assign train/val datasets for use in dataloaders
        cifar10_full =ImageDataset(self.data_num, train=True, transform1=self.trans1, transform2=self.trans2)
        n_train = int(len(cifar10_full)*0.02)
        n_val = int(len(cifar10_full)*0.08)
        n_test = len(cifar10_full)-n_train -n_val
        
        self.cifar10_train, self.cifar10_val, self.cifar10_test = torch.utils.data.random_split(cifar10_full, [n_train, n_val, n_test])
        
    
    def train_dataloader(self):
        self.trainloader = DataLoader(self.cifar10_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
        return self.trainloader
    
    def val_dataloader(self):
        self.valloader = DataLoader(self.cifar10_val, shuffle=False, batch_size=32, num_workers=0)
        return self.valloader
    
    def test_dataloader(self):
        self.testloader = DataLoader(self.cifar10_test, shuffle=False, batch_size=32, num_workers=0)
        return self.testloader
    
    
def main():    
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #for gpu
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)
    pl.seed_everything(0)

    # model
    autoencoder = LitAutoEncoder()
    autoencoder = autoencoder.to(device) #for gpu
    print(autoencoder)
    summary(autoencoder.encoder,(3,32,32))
    summary(autoencoder,(3,32,32))
    
    trainer = pl.Trainer(max_epochs=100, gpus=1, callbacks=[MyPrintingCallback()]) ####epoch
    
    trainer.fit(autoencoder)    
    print('training_finished')
    
    results = trainer.test(autoencoder)
    print(results)

    dataiter = iter(autoencoder.testloader)
    _,images, _, labels = dataiter.next()
    # show images
    imshow(torchvision.utils.make_grid(images.reshape(32,3,32,32)), 'cifar10_results',text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # print labels
    print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

    # torchscript
    #torch.jit.save(autoencoder.to_torchscript(), "model_cifar10.pt")
    trainer.save_checkpoint("example_cifar10.ckpt")

    PATH = 'example_cifar10.ckpt'
    pretrained_model = autoencoder.load_from_checkpoint(PATH)
    pretrained_model.freeze()
    pretrained_model.eval()

    latent_dim,ver = "ClolarizationResize1000", "100"  #####save condition
    dataiter = iter(autoencoder.testloader)
    images0,images, images1, labels = dataiter.next()
    # show images
    imshow(torchvision.utils.make_grid(images.reshape(32,3,32,32)),'original_images_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # show images0
    imshow(torchvision.utils.make_grid(images0.reshape(32,3,32,32)),'original_images0_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
    # show images1
    imshow(torchvision.utils.make_grid(images1.reshape(32,3,32*2,32*2)),'normalized_images1_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))     

    encode_img = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,3,32,32))
    decode_img = pretrained_model.decoder(encode_img)
    imshow(torchvision.utils.make_grid(decode_img.cpu().reshape(32,3,32*2,32*2)), 'autoencode_preds_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))    
3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?