0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【色付け入門】coloringの精度向上出来たかも♬

Last updated at Posted at 2021-02-07

前回記事でAutoencoder利用して、簡単な色付けをやってみた。
今回は、深堀してさらに綺麗な色付けを出来るか遊んでみたので、まとめておく。
アイディアのほとんどは、以下の参考①、そして理論的な話は参考②を参考としています。
【参考】
Image Colorization with Convolutional Neural Networks
色空間の変換(4) YCbCr/YPbPr 色空間
結果は、cifar10に対して、以下のような絵が得られました。

1
出力画像(生成後規格化)autoencode_preds_cifar10_Gray2ClolarizationNormalizeResize4LYCC05_20.png
入力画像original_images_cifar10_Gray2ClolarizationNormalizeResize4LYCC05_20.png
規格化元画像original_images0_norm_cifar10_Gray2ClolarizationNormalizeResize4LYCC05_20.png
元画像original_images0_cifar10_Gray2ClolarizationNormalizeResize4LYCC05_20.png

やったこと

・色空間ってなんだ
・変換関数を試してみる
・pytoachでやってみる

・色空間ってなんだ

これについては、参考②でも書いておりますが、実際の画像は参考③を見るとある程度想像つくと思います。参考②③では、いろいろな色空間に変換して、それぞれのチャンネルの頻度を見て、明るさ調整をしています。
【参考】
【画像処理】暗視カメラ出来たかも♬
【画像処理】静止画・カメラ動画の輝度ヒストグラム調整と明度自動調整をやってみる♬
今回は、これをグレー画像とそれ以外の色付け要素に分解して、グレー画像をもとに色付け要素だけを学習して、元のグレー画像に重畳する。このようにして、カラー画像を復元しようというものです。絵に描くと、以下の通りです。
(参考①から引用)
coloring.jpg
参考①では、色空間としてCIELABを利用しています。
しかし、同じものを利用するのも発展性が無いということで、参考②を頼りに別の色空間を利用することとしました。
色空間は、元々rgbであれば3チャンネルで、色の表現方法はそのほかにも以下の参考のようにいろいろ提案・利用されています。この多様性は、見え方や表現方法の違いから多くの派生が生まれているようです。
【参考】
色空間
Lab色空間
YUV

そして、参考⑥によれば、今回色付けで参考にした上図の根拠は、以下のようです。
「Lab色空間は人間の視覚を近似するよう設計されている。知覚的均等性を重視しており、L成分値は人間の明度の知覚と極めて近い。したがって、カラーバランス調整を正確に行うために出力曲線を a および b の成分で表現したり、コントラストの調整のためにL成分を使ったりといった利用が可能である」
ということで、LとABを分離し、上記のような学習から色塗りを実施しています。
参考①では、以下のように記載しています。
「This colorspace contains exactly the same information as RGB, but it will make it easier for us to separate out the lightness channel from the other two (which we call A and B). 」

一方で、参考②のような話もあります。つまり、
「テレビ・ビデオや画像のJPEG圧縮に使われる色空間に YCbCr がある。このうち,Y は輝度信号,Cb と Cr は色差信号を表す。YCbCr は,色空間というよりも信号の伝送・記録方式というほうがいいかもしれない。YCbCr と似たものに YPbPr がある。これらの違いについては,
1.YPbPr はアナログ信号を表し,YCbCr はデジタルデータを表す [海外での解釈?]
2.YCbCr は標準テレビ (SDTV) の信号を表し,YPbPr はハイビジョン (HDTV) の信号を表す [日本での解釈?]
という解釈があって混乱しているが,ここでは 1. の解釈 (YPbPr はアナログ,YCbCr はデジタル) に従うことにする。」
参考⑦より、いづれにしても以下の通りで、テレビなどでは以下の方式ということです。
「YUVやYCbCrやYPbPrとは、輝度信号Yと、2つの色差信号を使って表現される色空間。」
ということで、今回は、YCbCrを採用することとしました。
現存の白黒画像は、本来白黒映像か白黒写真由来のものだろうと想像し、その場合はこのYCbCrで表現できると考えました。
※以下で見るように基本的にはLABでもYCbCrもどちらも関数一個でrgbから変換できるので、対象によって学習を変えればいいし、そもそもLとYを比較するとあまり主張するほどの違いはないようです。

・変換関数を試してみる

ここでは、opencvの以下の変換関数を利用しました。
RGB画像のYCbCrとLABへの変換関数と逆関数は以下のとおりです。
なお、試した結果が膨大になってしまったことと、気づきが別途あったので近日中に別記事とします。

関数 適用
cv2.cvtColor(image0, cv2.COLOR_BGR2RGB) jpgやpngの画像データをRGB形式に変換※PILのImage.open()で読み込むと変換できない
cv2.cvtColor(image1, cv2.COLOR_BGR2YCR_CB) jpgやpngの画像データをYCrCb形式に変換
Y, Cr, Cb = cv2.split(orgYCrCb) YCrCb形式の画像を要素Y,Cr,Cbに分離
YCC = cv2.merge((Y,Cr,Cb)) 要素Y,Cr,Cbを合成する
cv2.cvtColor(YCC, cv2.COLOR_YCR_CB2BGR) YCrCb形式の画像YCCをBGR形式に変換
cv2.cvtColor(image1, cv2.COLOR_BGR2LAB) jpgやpngの画像データをLAB形式に変換
L, A, B = cv2.split(orgLAB) LAB形式の画像を要素L,A,Bに分離
LAB = cv2.merge((L,A,B)) 要素L,A,Bを合成する
cv2.cvtColor(LAB, cv2.COLOR_LAB2BGR) LAB形式の画像LABをBGR形式に変換

また、pytorchのtorchvision.transformsの変換関数、および親和性の高いPILの以下の関数を利用します。pytorchの変換は、torch-tensorまたは、PIL形式のファイルを変換します。

1 2
ts = transforms.ToPILImage() PIL形式に変換
ts2 = transforms.ToTensor() torch.tensorに変換
ts3 = transforms.Grayscale() grayscaleに変換
ts3 = transforms.Normalize(mean,std) 平均meanと標準偏差stdに規格化する
image_=ts(image0).convert('L') grayscaleに変換
image_g=ts3(ts(image0)) PIL形式に変換した後、graysclaleに変換
Y_ = ts(Y).convert('RGB') 1chデータYをPIL形式に変換した後、RGB形式に変換

・pytoachでやってみる

参考①のnetworkとvgg16を組み合わせた以下のおまけのnetworkで学習してみました。
ほぼ、前回の参考⑧と同様ですが、以下の変更を実施しています。
※始める前の見通しでは非常に簡単なお話だと思っていましたが、かなり突っ込んでしまったので、あっさり記載します
①Y,CrCbで学習する
②Networkを変更する
【参考】
【pytorch-lightning入門】自前datasetで~Denoising, Coloring, Normalization, そして拡大カラー画像生成で遊んでみた♬

Dataloaderの準備①Y,CrCbで学習するためにデータ準備

まず、dataset及びDataloaderは以下のとおりとしました。
今回も、既存のCifar10のデータを利用します。

ここで画像処理をどの段階でやるのが効率的か悩みましたが、最終的に以下の通りとしました。
※現時点では、さらにいろいろ遊んでみると精度が上がるし、以下の記述も変更すべきとなる可能性があると思っています
1.BGRからYCrCb空間に変換する
2.YとCrCb成分を別々に提供する
3.当初、同一関数でreturn Y, CCのように渡していましたが、挙動が怪しいので別関数としました
4.どこまでの処理をデータ読込とこの関数で実施するのかを悩みましたが、二つの関数で共通な画像サイズ拡大と規格化は最初のデータ読込時に実施することとしました。つまりこの関数ではその後の処理のみをすることとしました(なので、一部無駄コードを残しています)
5.画像の値域制御のための処理(img = img / 2 + 0.5;unnormalize処理)は、画像が綺麗に見える値(ここでは1/4;通常plt.imshowの描画範囲から1/2だが規格化すると超える場合がある)を採用しています
6.Yについては、当初元画像からgray変換実施してgray画像を返していたが、後ほど合成するので同時処理して、gray変換するように変更した。たぶん、こちらが正解だろうと思う。なお、元画像が同一ならばgrayとYそして、L(参考①のCIELAB空間のL)は、(見栄えの確認だが)同一だということを確認した

Datasetのコードに提供するためのrgb2YCrCb変換のコード
class rgb2YCrCb(object):
    def __init__(self):
        self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transform=transforms.ToTensor()
        mean, std =[0.5,0.5,0.5], [0.25,0.25,0.25]
        self.ts3 = torchvision.transforms.Normalize(mean, std)
        pass
    
    def __call__(self, tensor):
        tensor = tensor  / 4 + 0.5     # unnormalize
        orgYCrCb = cv2.cvtColor(np.float32(self.ts(tensor)), cv2.COLOR_BGR2YCR_CB)
        Y, Cr,Cb = cv2.split(orgYCrCb)
        CC = cv2.merge((Cr,Cb))
        CC = np.array(CC).reshape(2,32*2,32*2) #(2,32*2,32*2)
        #print(CC.shape)
        return np.array(CC)
    
    def __repr__(self):
        return self.__class__.__name__
    
class rgb2YCrCb_(object):
    def __init__(self):
        self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transform=transforms.ToTensor()
        mean, std =[0.5,0.5,0.5], [0.25,0.25,0.25]
        self.ts3 = torchvision.transforms.Normalize(mean, std)
        pass
    
    def __call__(self, tensor):
        #tensor = self.ts3(self.ts2(self.ts(tensor)))  / 4 + 0.5     # unnormalize        
        tensor = tensor  / 4 + 0.5     # unnormalize
        orgYCrCb = cv2.cvtColor(np.float32(self.ts(tensor)), cv2.COLOR_BGR2YCR_CB)
        Y, Cr,Cb = cv2.split(orgYCrCb)
        CC = cv2.merge((Cr,Cb))
        Y = np.array(Y).reshape(1,32*2,32*2) #(1,32*2,32*2)
        #print(Y.shape)
        return Y
ここは、前回との大きな違いは、`self.data =CIFAR10(self.data_dir, train=self.train, transform=self.ts3)`のself.ts3として、ToTensor()と同時に規格化を実施することとした。それ以外は、前回のコードとほぼ同一です。 つまり、 ・return out_data, out_data1, out_data2, out_labelとあるように返却値は、 1. out_data;元のデータ、 2. out_data1;transform1で処理したデータ、 3. out_data2;transform2で処理したデータ 4. 元データのlabel を返しています。
Cifar10を処理後提供するためのDatasetのコード
class ImageDataset(torch.utils.data.Dataset):

    def __init__(self, data_num, train_=True, transform1 = None, transform2 = None, train = True):
                
        self.transform1 = transform1
        self.transform2 = transform2
        self.ts = torchvision.transforms.ToPILImage()
        self.ts2 = transforms.ToTensor()
        mean, std =[0.5,0.5,0.5], [0.25,0.25,0.25] #(0.485, 0.456, 0.406), (0.229, 0.224, 0.225) 後者はtorchvisionで掲載されているImagenetの平均、標準偏差だが前者を採用
        self.ts3 =  transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self.train = train_
        
        self.data_dir = './'
        self.data_num = data_num
        self.data = []
        self.label = []

        # download
        CIFAR10(self.data_dir, train=True, download=True)
        self.data =CIFAR10(self.data_dir, train=self.train, transform=self.ts3)

    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        out_data = self.data[idx][0]
        out_label_ =  self.data[idx][1]
        out_label = torch.from_numpy(np.array(out_label_)).long()
        if self.transform1:
            out_data1 = self.transform1(out_data)
        if self.transform2:
            out_data2 = self.transform2(out_data)
        return out_data, out_data1, out_data2, out_label
画像の平均と標準偏差の値について、参考⑨に以下の回答があります。 「これは ImageNet という画像分類データセットの RGB の平均と標準偏差です。torchvision の事前学習済みモデルを利用するならこれを使う必要がありますが、自作のモデルでゼロから学習するなら、自分のデータセットの平均と分散を計算しておいてそれを使うということになります。」 【参考】 ⑨[PyTorch における正規化について](https://teratail.com/questions/295871) #### ②Networkを変更する pytorch-lightningのいいところは、その部分を検討できるところ。ということで、ここで、networkを参考①のnetworkに変更しようと思います。しかし、残念ながらencoder部分の入力はtensorにしなさいというバグが出て、そちらはうまくいきませんでした。 ということで、decoderを置き換えます。 が、これが学習してくれずに色が塗れません。ということで、現状以下のコードに落ち着きました。
VGG16でencodeし、BnとReluで強化したdecoderのコード
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self): 
        super(Encoder, self).__init__()
        num_classes = 10
        self.block1_output = nn.Sequential (
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.block2_output = nn.Sequential (
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.block3_output = nn.Sequential (
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.block4_output = nn.Sequential (
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.block5_output = nn.Sequential (
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            #nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(512*4*4, 512),  #512 * 7 * 7, 4096),
            nn.ReLU(True),
            #nn.Dropout(),
            nn.Linear(512, 32 ),  #4096, 4096),
            nn.ReLU(True),
            #nn.Dropout(),
            nn.Linear(32, num_classes)  #4096
        )
        
    def forward(self, x):
        x1 = self.block1_output(x)
        x2 = self.block2_output(x1)
        x3 = self.block3_output(x2)
        x4 = self.block4_output(x3)
        x = self.block5_output(x4)
        x0 = x.view(x.size(0), -1)
        y = self.classifier(x0)
        return x4, y

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2)),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels = 256, out_channels = 128,
                                          kernel_size = 2, stride = 2, padding = 0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            #nn.ConvTranspose2d(in_channels = 128, out_channels = 64,
            #                              kernel_size = 2, stride = 2),
            nn.ConvTranspose2d(in_channels = 128, out_channels = 16,
                                          kernel_size = 2, stride = 2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels = 16, out_channels = 2,
                                          kernel_size = 2, stride = 2)
        )        

    def forward(self, x):
        #x = x.reshape(32,256,4,4)
        x = self.decoder(x)
        return x
そして、学習コードは以下の通りです。 1.今回は、datasetが少し複雑でpytorch-lighitningのDataloaderの構造は使えませんでした。 そこで、以下のようにmain()でDataloaderを定義しています。
    data_num = 50000
    cifar10_full =ImageDataset(data_num, train=True, transform1=trans1, transform2=trans2)
    n_train = int(len(cifar10_full)*0.95)
    n_val = int(len(cifar10_full)*0.04)
    n_test = len(cifar10_full)-n_train -n_val
    cifar10_train, cifar10_val, cifar10_test = torch.utils.data.random_split(cifar10_full, [n_train, n_val, n_test])
    
    trainloader = DataLoader(cifar10_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
    valloader = DataLoader(cifar10_val, shuffle=False, batch_size=32, num_workers=0)
    testloader = DataLoader(cifar10_test, shuffle=False, batch_size=32, num_workers=0)

2.データ処理trans1とtrans2などを定義
以下のように処理が複雑になったので、最初に変換処理を定義しています。

    ts = transforms.ToPILImage()
    ts2 = transforms.ToTensor()
    mean, std =[0.5,0.5,0.5], [0.25,0.25,0.25]
    ts3 =  transforms.Normalize(mean, std)
    ts4 =  transforms.Resize((64,64))
    trans2 =  transforms.Compose([
        transforms.Resize((64,64)),
        rgb2YCrCb(), #CrCb
    ])
    trans1 =  transforms.Compose([
        transforms.Resize((64,64)),
        rgb2YCrCb_(),  #Y
    ])

学習コードの順番は前回と同様です。
main()の処理順は
・変換処理準備
・Dataloader読込、
・学習準備
・学習
・テスト
・学習済checkpointの保存
・結果確認のための初期画像出力
・学習済モデル読込, freeze(), eval()
・最後に、生成画像などを出力しています
 epoch毎に生成したいので、for文で回しています

学習コードのmain()コード
def main():
    ts = transforms.ToPILImage()
    ts2 = transforms.ToTensor()
    mean, std =[0.5,0.5,0.5], [0.25,0.25,0.25]
    ts3 =  transforms.Normalize(mean, std)
    ts4 =  transforms.Resize((64,64))
    trans2 =  transforms.Compose([
        transforms.Resize((64,64)),
        rgb2YCrCb(), #CrCb
    ])
    trans1 =  transforms.Compose([
        transforms.Resize((64,64)),
        rgb2YCrCb_(),  #Y
    ])

    data_num = 50000
    cifar10_full =ImageDataset(data_num, train=True, transform1=trans1, transform2=trans2)
    n_train = int(len(cifar10_full)*0.95)
    n_val = int(len(cifar10_full)*0.04)
    n_test = len(cifar10_full)-n_train -n_val
    cifar10_train, cifar10_val, cifar10_test = torch.utils.data.random_split(cifar10_full, [n_train, n_val, n_test])
    
    trainloader = DataLoader(cifar10_train, shuffle=True, drop_last = True, batch_size=32, num_workers=0)
    valloader = DataLoader(cifar10_val, shuffle=False, batch_size=32, num_workers=0)
    testloader = DataLoader(cifar10_test, shuffle=False, batch_size=32, num_workers=0)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #for gpu
    # Assuming that we are on a CUDA machine, this should print a CUDA device:
    print(device)
    pl.seed_everything(0)

    # model
    autoencoder = LitAutoEncoder()
    autoencoder = autoencoder.to(device) #for gpu
    print(autoencoder)
    summary(autoencoder.encoder,(1,32*2,32*2))
    summary(autoencoder.decoder,(512,4,4))
    #summary(autoencoder,(1,32*2,32*2))
    
    trainer = pl.Trainer(max_epochs=1, gpus=1, callbacks=[MyPrintingCallback()]) ####epoch
    sk = 0
    for i in range(100):
        trainer.fit(autoencoder, trainloader, valloader)    
        print('training_finished')
    
        results = trainer.test(autoencoder, testloader)
        print(results)
        if sk%10==0:
            dataiter = iter(trainloader)
            _,images, images_, labels = dataiter.next()
            print(images.shape, images_.shape)

            images0 = []
            for i in range(32):
                print(i, images[i].shape, images_[i].shape)
                YCC_ = cv2.merge((np.array(images[i]).reshape(64,64),np.array(images_[i]).reshape(64,64,2)))
                images0_ = cv2.cvtColor(YCC_, cv2.COLOR_YCR_CB2BGR)
                images0.append(ts2(images0_/255.))
            # show images 
            imshow(torchvision.utils.make_grid(images0), 'cifar10_results',text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4))) #3
            # print labels
            print(' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))

            path_ = './simple_coloring/'
            PATH = path_+'example_cifar4Ln100_{}.ckpt'.format(sk)
            trainer.save_checkpoint(PATH)

            pretrained_model = autoencoder.load_from_checkpoint(PATH)
            pretrained_model.freeze()
            pretrained_model.eval()

            latent_dim,ver = "Gray2Clolor", "1_{}".format(sk)  #####save condition
            dataiter = iter(testloader)
            images0,images, images1, labels = dataiter.next() #original, Y, CrCb, label
            # show images
            imshow(torchvision.utils.make_grid(images.reshape(32,1,32*2,32*2)/255.),path_+'1_Y_cifar10_{}_{}'.format(latent_dim,0),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
            # show images0
            imshow(torchvision.utils.make_grid(images0.reshape(32,3,32,32)),path_+'2_original_cifar10_{}_{}'.format(latent_dim,0),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
            # show images0
            imshow(torchvision.utils.make_grid(ts4(images0).reshape(32,3,32*2,32*2)),path_+'3_original_normx2_cifar10_{}_{}'.format(latent_dim,0),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))    
            # show images1
            #imshow(torchvision.utils.make_grid(images1.reshape(32,3,32*2,32*2)),'normalized_images1_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))     

            encode_img,_ = pretrained_model.encoder(images[0:32].to('cpu').reshape(32,1,32*2,32*2)) #3
            decode_img = pretrained_model.decoder(encode_img)
            decode_img_cpu = decode_img.cpu()
            images2 = []
            for i in range(32):
                print(i, images[i].shape, decode_img_cpu[i].shape)
                YCC_ = cv2.merge((np.array(images[i].reshape(64,64)),np.array(decode_img_cpu[i].reshape(64,64,2))))
                images2_ = cv2.cvtColor(YCC_, cv2.COLOR_YCR_CB2BGR)
                images2.append(ts3(ts2(images2_/255.)))
                #images2.append(ts2(images2_/255.))
            imshow(torchvision.utils.make_grid(images2), path_+'4_preds_cifar10_{}_{}'.format(latent_dim,ver),text_ =' '.join('%5s' % autoencoder.classes[labels[j]] for j in range(4)))
        sk += 1

if __name__ == '__main__':
    start_time = time.time()
    main()
    print('elapsed time: {:.3f} [sec]'.format(time.time() - start_time))    
### 結果 小さな画像の比較だと、少し色合いが異なる部分もあるが、かなり再現しているようだ。
VGG16+decoder
出力画像(生成後規格化)4_preds_cifar10_Gray2Clolor_1_90.png
入力画像1_Y_cifar10_Gray2Clolor_0.png
規格化元画像2_original_cifar10_Gray2Clolor_0.png
規格化元拡大画像3_original_normx2_cifar10_Gray2Clolor_0.png
拡大してみると、馬の絵でも外郭は最初の絵よりは滑らかになり全体はよく再現しているように見える。しかし、規格化元拡大画像と比較すると背中辺りのテカリが協調されすぎていて、不自然である。
この傾向は、他の画像でも散見され、まだまだ改善が必要なようだ。
VGG16+decoder
出力画像(生成後規格化)image_preds14.png
入力画像image_gray14.png
規格化元画像image_original14_.png
規格化元拡大画像image_originalx2_14.png

テストとして別画像の色塗り

別画像は、以下の参考⑦のWikipediaのページのものを利用させていただいた。
【参考】
⑦(再掲)YUV
cifar10を100epoch学習したモデルで色塗り
色塗りは、何(文字通り学習データ)をどのように(networkの種類)、どの程度(epoch数)学習したかによって変化する。
以下に、二つの例を示す。
※残念ながら、元絵とは比較にならない悪い状況である(これはどうやら解像度に依存してようだ)が、塗りだけを比較すれば違いが分かると思う
※どちらも、草の緑が全く再現できてない

VGG16+decoder
生成画像YCC_preds_4.png
入力グレー画像YCCoriginal_gray.png
元画像YCC.png
VGG16+decoder2
YCC_preds_4.png

まとめ

・画像を輝度を表すグレー画像と色差を表すCrCbに分割し、YからCrCbを学習することにより、色塗りで遊んでみた
・前回のグレーから全体を学習する手法と比較すると、精度向上が図れた
・networkをVGG16とdecoderにBn, Reluを導入したモデルで比較的よい学習結果を得た

・色空間については、別途報告する
・精度が不十分で、学習データの(配色の)統一や解像度の向上をするとさらなる改善が見込めるような気がする

おまけ

decoder()は、3種類記載しているが、実質利用しているものは、decoder()のみでdecoder2, decoder3は切り替えて利用する。

>python simple_coloring_YCC.py
Files already downloaded and verified
cuda:0
LitAutoEncoder(
  (encoder): Encoder(
    (block1_output): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (block2_output): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (block3_output): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (block4_output): Sequential(
      (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (block5_output): Sequential(
      (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
    )
    (classifier): Sequential(
      (0): Linear(in_features=8192, out_features=512, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=512, out_features=32, bias=True)
      (3): ReLU(inplace=True)
      (4): Linear(in_features=32, out_features=10, bias=True)
    )
  )
  (decoder): Decoder(
    (decoder2): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
      (1): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
      (2): ConvTranspose2d(128, 16, kernel_size=(2, 2), stride=(2, 2))
      (3): ConvTranspose2d(16, 2, kernel_size=(2, 2), stride=(2, 2))
    )
    (decoder): Sequential(
      (0): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): ConvTranspose2d(128, 16, kernel_size=(2, 2), stride=(2, 2))
      (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): ConvTranspose2d(16, 2, kernel_size=(2, 2), stride=(2, 2))
    )
    (decoder3): Sequential(
      (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Upsample(scale_factor=2.0, mode=nearest)
      (4): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU()
      (10): Upsample(scale_factor=2.0, mode=nearest)
      (11): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (13): ReLU()
      (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (16): ReLU()
      (17): Upsample(scale_factor=2.0, mode=nearest)
      (18): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (19): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (20): ReLU()
      (21): Conv2d(32, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (22): Upsample(scale_factor=2.0, mode=nearest)
    )
  )
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 64, 64]             640
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
            Conv2d-4           [-1, 64, 64, 64]          36,928
       BatchNorm2d-5           [-1, 64, 64, 64]             128
              ReLU-6           [-1, 64, 64, 64]               0
         MaxPool2d-7           [-1, 64, 32, 32]               0
            Conv2d-8          [-1, 128, 32, 32]          73,856
       BatchNorm2d-9          [-1, 128, 32, 32]             256
             ReLU-10          [-1, 128, 32, 32]               0
           Conv2d-11          [-1, 128, 32, 32]         147,584
      BatchNorm2d-12          [-1, 128, 32, 32]             256
             ReLU-13          [-1, 128, 32, 32]               0
        MaxPool2d-14          [-1, 128, 16, 16]               0
           Conv2d-15          [-1, 256, 16, 16]         295,168
      BatchNorm2d-16          [-1, 256, 16, 16]             512
             ReLU-17          [-1, 256, 16, 16]               0
           Conv2d-18          [-1, 256, 16, 16]         590,080
      BatchNorm2d-19          [-1, 256, 16, 16]             512
             ReLU-20          [-1, 256, 16, 16]               0
           Conv2d-21          [-1, 256, 16, 16]         590,080
      BatchNorm2d-22          [-1, 256, 16, 16]             512
             ReLU-23          [-1, 256, 16, 16]               0
        MaxPool2d-24            [-1, 256, 8, 8]               0
           Conv2d-25            [-1, 512, 8, 8]       1,180,160
      BatchNorm2d-26            [-1, 512, 8, 8]           1,024
             ReLU-27            [-1, 512, 8, 8]               0
           Conv2d-28            [-1, 512, 8, 8]       2,359,808
      BatchNorm2d-29            [-1, 512, 8, 8]           1,024
             ReLU-30            [-1, 512, 8, 8]               0
           Conv2d-31            [-1, 512, 8, 8]       2,359,808
      BatchNorm2d-32            [-1, 512, 8, 8]           1,024
             ReLU-33            [-1, 512, 8, 8]               0
        MaxPool2d-34            [-1, 512, 4, 4]               0
           Conv2d-35            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-36            [-1, 512, 4, 4]           1,024
             ReLU-37            [-1, 512, 4, 4]               0
           Conv2d-38            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-39            [-1, 512, 4, 4]           1,024
             ReLU-40            [-1, 512, 4, 4]               0
           Conv2d-41            [-1, 512, 4, 4]       2,359,808
      BatchNorm2d-42            [-1, 512, 4, 4]           1,024
             ReLU-43            [-1, 512, 4, 4]               0
           Linear-44                  [-1, 512]       4,194,816
             ReLU-45                  [-1, 512]               0
           Linear-46                   [-1, 32]          16,416
             ReLU-47                   [-1, 32]               0
           Linear-48                   [-1, 10]             330
================================================================
Total params: 18,933,546
Trainable params: 18,933,546
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 26.26
Params size (MB): 72.23
Estimated Total Size (MB): 98.50
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
   ConvTranspose2d-1            [-1, 256, 8, 8]         524,544
       BatchNorm2d-2            [-1, 256, 8, 8]             512
              ReLU-3            [-1, 256, 8, 8]               0
   ConvTranspose2d-4          [-1, 128, 16, 16]         131,200
       BatchNorm2d-5          [-1, 128, 16, 16]             256
              ReLU-6          [-1, 128, 16, 16]               0
   ConvTranspose2d-7           [-1, 16, 32, 32]           8,208
       BatchNorm2d-8           [-1, 16, 32, 32]              32
              ReLU-9           [-1, 16, 32, 32]               0
  ConvTranspose2d-10            [-1, 2, 64, 64]             130
================================================================
Total params: 664,882
Trainable params: 664,882
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.03
Forward/backward pass size (MB): 1.56
Params size (MB): 2.54
Estimated Total Size (MB): 4.13
----------------------------------------------------------------
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2021-02-07 19:03:18.472890: W tensorflow/stream_executor/platform/default/dso_loader.cc:60] Could not load dynamic library 'cudart64_110.dll'; dlerror: cudart64_110.dll not found
2021-02-07 19:03:18.472992: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 18.9 M
1 | decoder | Decoder | 3.1 M
------------------------------------
22.0 M    Trainable params
0         Non-trainable params
22.0 M    Total params
0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?