As a continuation of the previous article, I again used Cifar10, but this time through a home-made dataset, applying various processing to it and feeding it into a Dataloader with pytorch-lightning.
Surprisingly little of this kind seems to have been published, so although there is still plenty of room to improve accuracy and so on, I want to write it up here.
At the current stage the results are as follows; they are comparable to what I previously obtained with Keras using a similar approach (see the reference below).
The readability of the code has become dramatically better, I think.
References
・【画像生成】AutoencoderでDenoising, Coloring, そして拡大カラー画像生成♬
(Images omitted. The result table showed four rows:)
・Original: out_data
・Input (noisy grayscale): out_data1
・Generated image: preds
・Normalized comparison (target) image: out_data2
What I did
・Use the dataset from pytorch-lightning (omitted)
・Walk through the final pytorch-lightning code
Use the dataset from pytorch-lightning (omitted)
This time the development was done in stages, so I will describe each stage of the process rather than showing the code at every step.
① Use the dataset, define the Dataloader inside main(), and output Cifar10 images
⇒ See the previous article for this part
② Building on ①, define the data part inside main() and run the autoencoder under pytorch-lightning
⇒ For this, see 【pytorch-lightning入門】ゼロから作るMNIST及びCifar10のAutoencoder♬, into which ① was incorporated
③ Make the dataset usable through the standard Dataloader definition
⇒ This is explained together with ④ in the code walkthrough below
④ Fine-tune the transforms etc. so that the dataset yields both input data and comparison data
Here I will explain the dataset as completed in step ④.
④ Fine-tune the transforms etc. so that the dataset yields both input data and comparison data
In the end I created the dataset class shown below.
The important points are as follows.
・return out_data, out_data1, out_data2, out_label
As this return statement shows, the returned values are:
- out_data: the original data
- out_data1: the data processed by transform1
- out_data2: the data processed by transform2
- the label of the original data
In other words, three kinds of image data are generated here, but the number and the processing are completely up to you: any number of outputs, even of different kinds (e.g. from multiple sources), can be processed and returned (see the usage sketch after the dataset code below).
・All returned values have first been processed by self.ts2 = transforms.ToTensor(), which is applied when the data is initially read.
・self.data holds images and labels, as usual. The label is simply converted as last time; it is not used this time, so some changes may become necessary once you actually use it.
・What matters is the final format in which the image data is returned, and it is processed in the following sequence.
・First, each image is read with ToTensor() and stored in out_data; out_data is then processed into out_data1, out_data2, and so on, so every output ends up in ToTensor() format. This corresponds to what I did last time for numeric data with self.data = torch.from_numpy(np.array(x)).float() (a small sketch of what ToTensor() produces follows this list).
・As usual, the download part can fetch both train=True and train=False. I have not confirmed whether both are needed, but training and testing run with just one, so the train=False download appears to be unnecessary.
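As a quick aside, here is a minimal sketch (my own snippet, not part of the article's code) of what ToTensor() actually produces: it turns a PIL image (HWC, uint8) into a CHW float tensor scaled into [0, 1], which is the format all four return values share.

```python
import numpy as np
from PIL import Image
from torchvision import transforms

# a dummy 32x32 RGB image standing in for one Cifar10 image
img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))

t = transforms.ToTensor()(img)
print(t.shape, t.dtype)                              # torch.Size([3, 32, 32]) torch.float32
print(t.min().item() >= 0.0, t.max().item() <= 1.0)  # True True
```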
Dataset code for serving Cifar10 after processing
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_num, train=True, transform1=None, transform2=None):
self.transform1 = transform1
self.transform2 = transform2
#self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transforms.ToTensor()
        self.train = train
self.data_dir = './'
self.data_num = data_num
self.data = []
self.label = []
# download
CIFAR10(self.data_dir, train=True, download=True)
#CIFAR10(self.data_dir, train=False, download=True)
self.data =CIFAR10(self.data_dir, train=self.train, transform=self.ts2)
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.data[idx][0]
out_label_ = self.data[idx][1]
out_label = torch.from_numpy(np.array(out_label_)).long()
        # fall back to the untouched tensor when a transform is not given,
        # so the four-tuple returned below is always defined
        out_data1 = self.transform1(out_data) if self.transform1 else out_data
        out_data2 = self.transform2(out_data) if self.transform2 else out_data
return out_data, out_data1, out_data2, out_label
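To make the return contract concrete, here is a minimal usage sketch (mine; the transforms are illustrative, not the article's exact ones) that pulls one batch through a DataLoader and checks the four returned pieces:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# illustrative transforms; the real ones below add noise, grayscale and resize
trans1 = transforms.Normalize([0.5] * 3, [0.25] * 3)
trans2 = transforms.Resize((64, 64))

ds = ImageDataset(1000, train=True, transform1=trans1, transform2=trans2)
loader = DataLoader(ds, batch_size=32, shuffle=True)

out_data, out_data1, out_data2, out_label = next(iter(loader))
print(out_data.shape)   # torch.Size([32, 3, 32, 32])  original ToTensor images
print(out_data1.shape)  # torch.Size([32, 3, 32, 32])  after transform1
print(out_data2.shape)  # torch.Size([32, 3, 64, 64])  after transform2
print(out_label.shape)  # torch.Size([32])
```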
Full code: imports, helpers, and the Dataset for serving processed Cifar10
import os
import time
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import CIFAR10 #MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torchsummary import summary
from net_encoder_decoder1D2DResize import Encoder, Decoder
def imshow(img,file='', text_=''):
    # approximate unnormalize; the Normalize in the transforms uses std 0.25,
    # so some values land outside [0, 1] and matplotlib clips them (see the log)
    img = img / 2 + 0.5
npimg = img.detach().numpy() #img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.text(x = 3, y = 2, s = text_, c = "red")
plt.pause(3)
if file != '':
plt.savefig(file+'.png')
plt.close()
from pytorch_lightning.callbacks import Callback
class MyPrintingCallback(Callback):
def on_epoch_end(self, trainer, pl_module):
print('')
class MyAddGaussianNoise(object):
def __init__(self, mean=0., std=0.1):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_num, train=True, transform1=None, transform2=None):
self.transform1 = transform1
self.transform2 = transform2
self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transforms.ToTensor()
        self.train = train
self.data_dir = './'
self.data_num = data_num
self.data = []
self.label = []
# download
CIFAR10(self.data_dir, train=True, download=True)
#CIFAR10(self.data_dir, train=False, download=True)
self.data =CIFAR10(self.data_dir, train=self.train, transform=self.ts2)
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.data[idx][0]
out_label = np.array(self.data[idx][1])
        # fall back to the untouched tensor when a transform is not given
        out_data1 = self.transform1(out_data) if self.transform1 else out_data
        out_data2 = self.transform2(out_data) if self.transform2 else out_data
return out_data, out_data1, out_data2, out_label
Autoencoder code using the above dataset
class LitAutoEncoder(pl.LightningModule):
def __init__(self, data_dir='./'):
super().__init__()
self.data_dir = data_dir
self.data_num =50000 #50000
# Hardcode some dataset specific attributes
self.num_classes = 10
self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
self.dims = (32*2, 32*2)
self.mean, self.std =[0.5,0.5,0.5], [0.25,0.25,0.25]
self.trans2 = torchvision.transforms.Compose([
torchvision.transforms.Normalize(self.mean, self.std),
torchvision.transforms.Resize(self.dims)
])
self.trans1 = torchvision.transforms.Compose([
torchvision.transforms.Normalize(self.mean, self.std),
MyAddGaussianNoise(0., 0.5),
torchvision.transforms.Grayscale()
])
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
x = self.encoder(x)
x = self.decoder(x)
return x
def training_step(self, batch, batch_idx):
# training_step defined the train loop. It is independent of forward
_, x, x_ , y = batch
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x_)
self.log('train_loss', loss, prog_bar = True)
return loss
def validation_step(self, batch, batch_idx):
_, x, x_, y = batch
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x_)
self.log('test_loss', loss, prog_bar = True)
return loss
def test_step(self, batch, batch_idx):
# Here we just reuse the validation_step for testing
return self.validation_step(batch, batch_idx)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
    def setup(self, stage=None):  # split into train / val / test data
# Assign train/val datasets for use in dataloaders
cifar10_full =ImageDataset(self.data_num, train=True, transform1=self.trans1, transform2=self.trans2)
n_train = int(len(cifar10_full)*0.8)
n_val = int(len(cifar10_full)*0.1)
n_test = len(cifar10_full)-n_train -n_val
self.cifar10_train, self.cifar10_val, self.cifar10_test = torch.utils.data.random_split(cifar10_full, [n_train, n_val, n_test])
def train_dataloader(self):
self.trainloader = DataLoader(self.cifar10_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
return self.trainloader
def val_dataloader(self):
self.valloader = DataLoader(self.cifar10_val, shuffle=False, batch_size=32, num_workers=0)
return self.valloader
def test_dataloader(self):
self.testloader = DataLoader(self.cifar10_test, shuffle=False, batch_size=32, num_workers=0)
return self.testloader
main() code to run the above
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #for gpu
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
pl.seed_everything(0)
# model
autoencoder = LitAutoEncoder()
autoencoder = autoencoder.to(device) #for gpu
print(autoencoder)
summary(autoencoder.encoder,(1,32,32))
summary(autoencoder,(1,32,32))
trainer = pl.Trainer(max_epochs=10, gpus=1, callbacks=[MyPrintingCallback()]) ####epoch
trainer.fit(autoencoder)
print('training_finished')
results = trainer.test(autoencoder)
print(results)
dataiter = iter(autoencoder.valloader) #autoencoder.testloader
    _, images, _, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images.reshape(32,1,32,32)), 'cifar10_results',text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
# print labels
print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
# torchscript
#torch.jit.save(autoencoder.to_torchscript(), "model_cifar10.pt")
trainer.save_checkpoint("example_cifar10.ckpt")
PATH = 'example_cifar10.ckpt'
pretrained_model = autoencoder.load_from_checkpoint(PATH)
pretrained_model.freeze()
pretrained_model.eval()
latent_dim,ver = "Gray2ClolarizationResize1000", "10" #####save condition
dataiter = iter(autoencoder.valloader) #autoencoder.testloader
    images0, images, images1, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images.reshape(32,1,32,32)),'original_images_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
# show images0
imshow(torchvision.utils.make_grid(images0.reshape(32,3,32,32)),'original_images0_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
# show images1
imshow(torchvision.utils.make_grid(images1.reshape(32,3,32*2,32*2)),'normalized_images1_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
encode_img = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,1,32,32))
decode_img = pretrained_model.decoder(encode_img)
imshow(torchvision.utils.make_grid(decode_img.cpu().reshape(32,3,32*2,32*2)), 'autoencode_preds_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
if __name__ == '__main__':
start_time = time.time()
main()
print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))
Encoder-decoder code used above for Gray ⇒ Resize (enlargement)
import torch.nn as nn
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(in_channels = 1, out_channels = 64,
kernel_size = 3, padding = 1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.BatchNorm2d(64),
nn.Conv2d(64, 256, 3, 1, 1),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.BatchNorm2d(256)
)
def forward(self, x):
x = self.encoder(x)
return x
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
self.decoder = nn.Sequential(
nn.ConvTranspose2d(in_channels = 256, out_channels = 64,
kernel_size = 2, stride = 2, padding = 0),
nn.ConvTranspose2d(in_channels = 64, out_channels = 16,
kernel_size = 2, stride = 2),
nn.ConvTranspose2d(in_channels = 16, out_channels = 3,
kernel_size = 2, stride = 2)
)
def forward(self, x):
x = self.decoder(x)
return x
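As a quick sanity check (my snippet, assuming the Encoder/Decoder above), tracing a dummy 1×32×32 grayscale input reproduces the shapes that torchsummary reports below: 256×8×8 at the bottleneck, and a 3×64×64 color image out of the three stride-2 transposed convolutions.

```python
import torch

enc, dec = Encoder(), Decoder()
x = torch.randn(1, 1, 32, 32)  # one noisy grayscale Cifar10-sized image
z = enc(x)
y = dec(z)
print(z.shape)  # torch.Size([1, 256, 8, 8])
print(y.shape)  # torch.Size([1, 3, 64, 64])
```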
Example run results (training data reduced)
>python autoencoder_colorizationResize_dataset.py
cuda:0
LitAutoEncoder(
(encoder): Encoder(
(encoder): Sequential(
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(5): ReLU()
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(decoder): Decoder(
(decoder): Sequential(
(0): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))
(1): ConvTranspose2d(64, 16, kernel_size=(2, 2), stride=(2, 2))
(2): ConvTranspose2d(16, 3, kernel_size=(2, 2), stride=(2, 2))
)
)
)
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 32, 32] 640
ReLU-2 [-1, 64, 32, 32] 0
MaxPool2d-3 [-1, 64, 16, 16] 0
BatchNorm2d-4 [-1, 64, 16, 16] 128
Conv2d-5 [-1, 256, 16, 16] 147,712
ReLU-6 [-1, 256, 16, 16] 0
MaxPool2d-7 [-1, 256, 8, 8] 0
BatchNorm2d-8 [-1, 256, 8, 8] 512
================================================================
Total params: 148,992
Trainable params: 148,992
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 2.50
Params size (MB): 0.57
Estimated Total Size (MB): 3.07
----------------------------------------------------------------
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 32, 32] 640
ReLU-2 [-1, 64, 32, 32] 0
MaxPool2d-3 [-1, 64, 16, 16] 0
BatchNorm2d-4 [-1, 64, 16, 16] 128
Conv2d-5 [-1, 256, 16, 16] 147,712
ReLU-6 [-1, 256, 16, 16] 0
MaxPool2d-7 [-1, 256, 8, 8] 0
BatchNorm2d-8 [-1, 256, 8, 8] 512
Encoder-9 [-1, 256, 8, 8] 0
ConvTranspose2d-10 [-1, 64, 16, 16] 65,600
ConvTranspose2d-11 [-1, 16, 32, 32] 4,112
ConvTranspose2d-12 [-1, 3, 64, 64] 195
Decoder-13 [-1, 3, 64, 64] 0
================================================================
Total params: 218,899
Trainable params: 218,899
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 3.06
Params size (MB): 0.84
Estimated Total Size (MB): 3.90
----------------------------------------------------------------
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Files already downloaded and verified
2021-01-23 19:00:38.267639: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2021-01-23 19:00:38.267755: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
| Name | Type | Params
---------------------------------------
0 | encoder | Encoder | 148 K
1 | decoder | Decoder | 69.9 K
---------------------------------------
218 K Trainable params
0 Non-trainable params
218 K Total params
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 46.22it/s, loss=0.301, v_num=62, test_loss=0.229, train_loss=0.284]
Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.88it/s, loss=0.213, v_num=62, test_loss=0.171, train_loss=0.201]
Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 49.16it/s, loss=0.176, v_num=62, test_loss=0.153, train_loss=0.212]
Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 47.85it/s, loss=0.164, v_num=62, test_loss=0.155, train_loss=0.139]
Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.13it/s, loss=0.166, v_num=62, test_loss=0.142, train_loss=0.15]
Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.43it/s, loss=0.164, v_num=62, test_loss=0.15, train_loss=0.211]
Epoch 6: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.61it/s, loss=0.163, v_num=62, test_loss=0.142, train_loss=0.135]
Epoch 7: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 47.80it/s, loss=0.149, v_num=62, test_loss=0.138, train_loss=0.141]
Epoch 8: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.01it/s, loss=0.152, v_num=62, test_loss=0.152, train_loss=0.132]
Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.83it/s, loss=0.156, v_num=62, test_loss=0.134, train_loss=0.182]
Epoch 9: 100%|████████████████████████████████████████████████████████████████████████████████████| 156/156 [00:03<00:00, 48.77it/s, loss=0.156, v_num=62, test_loss=0.134, train_loss=0.182]
training_finished
Files already downloaded and verified
Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1407/1407 [00:26<00:00, 53.37it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.1358, device='cuda:0')}
--------------------------------------------------------------------------------
[{'test_loss': 0.13576959073543549}]
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
truck ship cat truck
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
elapsed time: 81.340 [sec]
・I would like to refine the enlarged color image generation further
Bonus
This time the network is weak, and the data is normalized in advance, but you can see that denoising and enlarged image generation succeed.
※Note that normalization is properly done with transforms functions, as here, rather than inside the DL network.
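For reference, a minimal sketch (my own, using the same mean/std as above) of doing normalization and its exact inverse purely with transforms; expressing the inverse as another Normalize with mean' = -mean/std and std' = 1/std is a standard idiom:

```python
import torch
from torchvision import transforms

mean, std = [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]
normalize = transforms.Normalize(mean, std)
# inverse of Normalize(mean, std): x * std + mean, written as another Normalize
unnormalize = transforms.Normalize(
    [-m / s for m, s in zip(mean, std)],
    [1 / s for s in std],
)

x = torch.rand(3, 32, 32)                   # a fake image in [0, 1]
x_rec = unnormalize(normalize(x))
print(torch.allclose(x, x_rec, atol=1e-5))  # True
```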
(Images omitted. The result table showed four rows:)
・Original: out_data
・Input (with noise): out_data1
・Generated image: preds
・Normalized comparison (target) image: out_data2
Full code for Denoising, Coloring, and Normalization
import os
import time
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torchsummary import summary
from net_encoder_decoder1D2DResize import Encoder, Decoder
def imshow(img,file='', text_=''):
    # approximate unnormalize; the Normalize in the transforms uses std 0.25,
    # so some values land outside [0, 1] and matplotlib clips them (see the log)
    img = img / 2 + 0.5
npimg = img.detach().numpy() #img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.text(x = 3, y = 2, s = text_, c = "red")
plt.pause(3)
if file != '':
plt.savefig(file+'.png')
plt.close()
from pytorch_lightning.callbacks import Callback
class MyPrintingCallback(Callback):
def on_epoch_end(self, trainer, pl_module):
print('')
class MyAddGaussianNoise(object):
def __init__(self, mean=0., std=0.1):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_num, train=True, transform1=None, transform2=None):
self.transform1 = transform1
self.transform2 = transform2
self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transforms.ToTensor()
        self.train = train
self.data_dir = './'
self.data_num = data_num
self.data = []
self.label = []
# download
CIFAR10(self.data_dir, train=True, download=True)
#CIFAR10(self.data_dir, train=False, download=True)
self.data =CIFAR10(self.data_dir, train=self.train, transform=self.ts2)
def __len__(self):
return self.data_num
def __getitem__(self, idx):
out_data = self.data[idx][0]
out_label_ = self.data[idx][1]
out_label = torch.from_numpy(np.array(out_label_)).long()
        # fall back to the untouched tensor when a transform is not given
        out_data1 = self.transform1(out_data) if self.transform1 else out_data
        out_data2 = self.transform2(out_data) if self.transform2 else out_data
return out_data, out_data1, out_data2, out_label
class LitAutoEncoder(pl.LightningModule):
def __init__(self, data_dir='./'):
super().__init__()
self.data_dir = data_dir
self.data_num =50000 #50000
# Hardcode some dataset specific attributes
self.num_classes = 10
self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
self.dims = (32*2, 32*2)
self.mean, self.std =[0.5,0.5,0.5], [0.25,0.25,0.25]
self.trans2 = torchvision.transforms.Compose([
torchvision.transforms.Normalize(self.mean, self.std),
torchvision.transforms.Resize(self.dims)
])
self.trans1 = torchvision.transforms.Compose([
torchvision.transforms.Normalize(self.mean, self.std),
MyAddGaussianNoise(0., 0.5),
#torchvision.transforms.Grayscale()
])
self.encoder = Encoder()
self.decoder = Decoder()
        # accuracy metrics: instantiated here but not actually used by this autoencoder
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
x = self.encoder(x)
x = self.decoder(x)
return x
def training_step(self, batch, batch_idx):
# training_step defined the train loop. It is independent of forward
_,x,x_ , y = batch
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x_)
self.log('train_loss', loss, prog_bar = True)
return loss
def validation_step(self, batch, batch_idx):
_,x, x_, y = batch
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x_)
self.log('test_loss', loss, prog_bar = True)
return loss
def test_step(self, batch, batch_idx):
# Here we just reuse the validation_step for testing
return self.validation_step(batch, batch_idx)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
    def setup(self, stage=None):  # split into train / val / test data
# Assign train/val datasets for use in dataloaders
cifar10_full =ImageDataset(self.data_num, train=True, transform1=self.trans1, transform2=self.trans2)
n_train = int(len(cifar10_full)*0.02)
n_val = int(len(cifar10_full)*0.08)
n_test = len(cifar10_full)-n_train -n_val
self.cifar10_train, self.cifar10_val, self.cifar10_test = torch.utils.data.random_split(cifar10_full, [n_train, n_val, n_test])
def train_dataloader(self):
self.trainloader = DataLoader(self.cifar10_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
return self.trainloader
def val_dataloader(self):
self.valloader = DataLoader(self.cifar10_val, shuffle=False, batch_size=32, num_workers=0)
return self.valloader
def test_dataloader(self):
self.testloader = DataLoader(self.cifar10_test, shuffle=False, batch_size=32, num_workers=0)
return self.testloader
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #for gpu
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
pl.seed_everything(0)
# model
autoencoder = LitAutoEncoder()
autoencoder = autoencoder.to(device) #for gpu
print(autoencoder)
summary(autoencoder.encoder,(3,32,32))
summary(autoencoder,(3,32,32))
trainer = pl.Trainer(max_epochs=100, gpus=1, callbacks=[MyPrintingCallback()]) ####epoch
trainer.fit(autoencoder)
print('training_finished')
results = trainer.test(autoencoder)
print(results)
dataiter = iter(autoencoder.testloader)
    _, images, _, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images.reshape(32,3,32,32)), 'cifar10_results',text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
# print labels
print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
# torchscript
#torch.jit.save(autoencoder.to_torchscript(), "model_cifar10.pt")
trainer.save_checkpoint("example_cifar10.ckpt")
PATH = 'example_cifar10.ckpt'
pretrained_model = autoencoder.load_from_checkpoint(PATH)
pretrained_model.freeze()
pretrained_model.eval()
latent_dim,ver = "ClolarizationResize1000", "100" #####save condition
dataiter = iter(autoencoder.testloader)
    images0, images, images1, labels = next(dataiter)
# show images
imshow(torchvision.utils.make_grid(images.reshape(32,3,32,32)),'original_images_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
# show images0
imshow(torchvision.utils.make_grid(images0.reshape(32,3,32,32)),'original_images0_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
# show images1
imshow(torchvision.utils.make_grid(images1.reshape(32,3,32*2,32*2)),'normalized_images1_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
encode_img = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,3,32,32))
decode_img = pretrained_model.decoder(encode_img)
imshow(torchvision.utils.make_grid(decode_img.cpu().reshape(32,3,32*2,32*2)), 'autoencode_preds_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
if __name__ == '__main__':
start_time = time.time()
main()
print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))