Python
機械学習
DeepLearning
ディープラーニング
Keras

[失敗編]AIに助けてもらって天の川を描く

7/7日に、ふと思い立って天の川を描きたくなりました。しかし、残念なことに、デザイナーからジョブチェンしてから長い期間が経過しており、ペンタブにまったく触っておらずもう描ける気がしません。画材もすべて知り合いに譲ってしまいました。
そもそも、Qiita記事を書いたりはタブレットで対応できてしまいますし、計算もクラウドのものを使っているため、デスクトップが壊れてからはノートPCですら買っておらず、デザインツールも入れていません。今手元にあるのは、Win10タブレットと、ペイントのみ・・

では、諦めるしかないのか。

いえ、折角データサイエンスを勉強しているので、AIに助けてもらう形で天の川を描けないか、抗ってみましょう。

こんなの作りたい

image.png

↑みたいな、ペイントでも描ける、白黒2色の画像をAIに食べさせると、

↑みたいな画像を作ってくれるAIモデル。(白い点の部分が星といった具合に)

※フリー素材をお借りしました : 手を伸ばせばつかめそうな宮古島の星空(天の川)

使うものや、そのほか勉強ついでに検証すること

  • pix2pix : 論文
    • ※線画を着色したり、夏の画像を冬にしたり、色々できるGANsの論文です。
  • pix2pixのKeras implementation のgithubリポジトリ
  • Google Colaboratoryで、データセットや重みなどのデータのやりとりを色々検証してみます。
  • 結構pix2pixがデータセットが少なくても結構それっぽいものができてしまう経験を以前しましたが、さらに少ないとどうなるのか・・・ということで、今回かなり少ないデータセットで進めてみます。もちろん、数が多い方が汎化性能などで好ましいことは分かりますが、どのくらいで許容できるレベルになるのか、という点にも興味があります。休日にデータセットを集めたりアノテーションといったしんどい作業やりたくない
  • Colaboratoryで700エポックくらい回してみます。それなりに長時間になるため、2回に分けて、2回目は学習済みの重みのデータを読み込んでから学習をスタートさせる必要があります。そのため、pix2pixのコードを読んで、調整しつつColaboratoryのノート上に持ってきます。

どのくらいデータを集めたの?

  • 天の川といったキーワードで、手動で70強程度の画像を集めました。
  • それらをトリミングして、反転したりして400程度のデータセットにし、train、test、valと割り振ります。
    • ※本当はtrainの割り振りを多くすべきだとは思いますが、前述のKerasのコードが割り振り最低100件といった感じになていたため、結果的にtrainの件数がそこまで多くなっていません。(200弱)
    • ※今振り返ると、Colaboratoryのノートにもっていったので、その際にコード調整しても良かった感はあります。

CIFAR10だったりのメジャーなデータセットが数万件といった具合なので、それらと比べると圧倒的に数が少ない、少数精鋭(?)な感じですが、どうなるのか・・

結果

数は力だった・・あまり綺麗ではない出力結果となりました。(やっぱり、数千件程度はないと厳しいか・・という印象)

まあでも、半分のサイズで表示すると少しそれっぽく見えます。残念なので、以降は半分のサイズでお送り致します。

食べさせた画像 :

生成された画像 :

食べさせた画像 :

生成された画像 :

食べさせた画像(なんとなく、下の方に山的なものが映っている画像を作りたかった) :

生成された画像 :

なんとなく宇宙っぽいけれども・・まだ合成画像感が結構残っています。今回はうまくいかなかったですが、機会があればデータセットをもっと増やしたり、精査したりして再チャレンジしたいところ。(色々得られた点があるので、今回はこれでOKとして終わりにします。)

技術的なお話

  • ※結構長く、煩雑に、そしてあーじゃないこーじゃない、と途中やっていたので、コードは綺麗ではありません。(綺麗にする前に力尽きました)
  • なので、「ここ記述が変だ」とか「ここうまく動かない」といった個所も結構あると思いますが、今回はマサカリは弱めにお願いします。(個人メモ的な側面が・・)
  • 細かいところが不要な方は、この時点でいいねを押していただいてからブラウザバック推奨です!

Colaboratoryでのファイルの受け渡し

Colaboratoryに本格的にお世話になり始めていますが、ファイルの受け渡しが問題になりました。
主に、データセットのアップロードと、学習後の重みのデータなどの保存が課題となります。

特に、通常は学習中ずっと見ながら待っているわけにもいかず、Colaboratoryが12時間経過でノート以外のファイルがすべて削除されるという仕様の都合、学習の後のノートの最後などに、残す必要があるファイルを退避させる処理が必要になります。(学習が終わった後にその退避用のセルが実行されるようにしておくなど)

今回は、以下のものを試してみました。

  1. GoogleDriveをマウントしてみる。
  2. Colaboratoryで用意されている、upload関数やdownload関数を使ってみる。
  3. GoogleDriveで公開URLを設定して、wgetコマンドなどでアップロードする。
  4. 外部のクラウドサービスのものをマウントしてみる。

※今回APIは利用しませんでしたが、APIも良さそうです。

1.GoogleDriveをマウントしてみる。

同じGoogleのサービス同士なので、結構使いやすいのかな?と思いましたが、最初結構インストールなどが色々必要な点と、認証的なところでセルの実行以外にも毎回コピペ作業が発生する?印象を受けて、しっくりきませんでした。(実験を何度も繰り返す際に、すべてのセル実行だけで完結するようにしておきたい)

今後、よりシンプルになってくれることを期待しつつ(GCPもどんど進化していっている点も加味し)、今回は見送りました。

2. Colaboratoryで用意されている、upload関数やdownload関数を使ってみる。

google.colab.files.upload関数やgoogle.colab.files.download関数を使う方法です。

Firefoxでは動かないので、Chromeを使う必要があるといった制約はあるものの、インストールなども不要で、セルの実行だけで対応できるのは魅力です。

しかし・・現状速度が遅すぎます。数十MB程度のファイルのアップロードなどで、数分~10分くらいかかったりします。すぐに実験したい、という際に結構しんどく感じます。なんどか使っていましたが、途中でしんどくなって、この方法は見送りました。

どうやら、この方法は本当に小さいファイル(ノートなど)向けで、重みやデータセットを扱うには不向きなようです。

3. GoogleDriveで公開URLを設定して、wgetコマンドなどでアップロードする。

URLが分かるとそのファイルにアクセスできてしまうので、仕事では使えませんが、今回は公開になっても問題ないファイルなので、アップロードはこちらは使ってみました。

GoogleDriveで対象のファイルで右クリックして、「共有可能なリンクを取得」を選択すると発行されるURLの、id部分の文字列を以下のコマンドに設定してColaboratory上で実行するだけです。公開になっても大丈夫なファイルであれば、とてもお手軽にアップロードができ、処理もすぐに終わります。

!wget "https://drive.google.com/uc?export=download&id=<対象ファイルのid部分のパラメーター>" -O "<保存先のファイルパス>"

ただ、アップロードはできても重みなどの退避のためのダウンロードがこれだとできません。

4.外部のクラウドサービスのものをマウントしてみる。

google Colaboratoryでのデータ永続化の記事にある、webdavのものを参考にさせていただきました。

1つのセルの実行だけでマウントできますし、毎度のコピペも不要で、普通に読み書きやLinuxのコマンドで操作ができます。

100MBなどのファイルだと、少しweb上の画面に正常にアップし終わるまでは時間がかかるようですが、問題ない範囲です。

一時的な重みの退避対応やデータセットのアップであれば無料枠の10~15GBで十分なので、シンプルに使えるというのもあり、今後はしばらくこちらを使っていこうと思います。(容量増やしたくなったら有料プランに)

天の川の画像を集める

数が多いと、仕事ではクローラーやスクレイピングのスクリプトを書いたりしますが、今回は70件程度なので気合で集めました。

途中で、Google画像検索で検索結果で気に入った画像をリストに保存できる機能が追加になっているのに気づき、便利に使わせていただきました。

トリミングや反転などを行う

この辺りは、ローカルのjupyterで実行しています。

pillow==5.2.0 を使用しました。
以降の記述で、Imageクラスのclose関数の記述は省略していますが、実際に大量の画像に対して処理を行う際には適宜close関数を呼び出して扱ってください。

今回は縦横256pxの画像で対応しているため、横長画像であれば縦幅をまずは256pxになるように縮小(resize関数)し、特定の領域をずらして3画像ずつ(左端・中央・右端といった具合に)正方形の画像をトリミング(crop関数)していきました。

サンプルコード :

from PIL import Image
raw_img = Image.open(raw_file_path)
resize_ratio = 256 / raw_img.height
width = int(raw_img.width * resize_ratio)
_resized_img = raw_img.resize(size=(width, 256), resample=Image.LANCZOS)
resized_img = _resized_img.convert(mode='RGB')
center_box = (
    int(resized_img.width / 2 - 128),
    0,
    int(resized_img.width / 2 - 128) + 256,
    256)
center_img = resized_img.crop(box=center_box)

反転などもして画像の件数の水増しを行いましたが、左右の反転はtranspose関数で、FLIP_LEFT_RIGHTを指定するだけです。

reflected_img = original_img.transpose(method=Image.FLIP_LEFT_RIGHT)

2値変換した画像を用意する

ペイントの画像を食べさせる都合、ペイントで作るのが楽な白黒の2値の画像を用意します。

今回はピクセルの色が、100未満であれば0(黒)、それ以外であれば255(白)となるように指定しました。

サンプルコード :

color_img = Image.open(color_img_file_path)
gray_img = color_img.convert(mode='L')
binary_img = gray_img.point(lambda x: 0 if x < 100 else 255)

pix2pix用の画像を用意する

pix2pixで用意されているスクリプトのためのデータセットは、左に実際の画像(GROUND TRUTH、今回は集めた天の川の画像)、右に学習用の入力画像(今回は生成した白黒の画像)の幅512px高さ256pxの画像を用意する必要があります。paste関数などを使って作っていきましょう。

canvas_img = Image.new(
    mode='RGB', size=(512, 256), color='#000000')

color_img = Image.open(fp=COLOR_IMG_DIR_PATH + img_name)
canvas_img.paste(im=color_img, box=(0, 0))

binary_img = Image.open(fp=BINARY_IMG_DIR_PATH + img_name)
canvas_img.paste(im=binary_img, box=(256, 0))

フォルダの割り振りとzip化

pix2pix用に、test、train、valという名称のディレクトリに、それぞれ画像を割り振っていきます。
基本的にtrainが一番多くなるように、train7割、残り1割5分ずつといった画像数ずつランダムに割り振っていきます。ただし、前述のとおりgithubのコードが100件未満だと調整しないとエラーで怒られてしまう都合、今回はtestとvalを100件にしているため、train比率が少なめになっていました。

それぞれ、拡張子はjpgで、以下のように1から順番に割り振っていきます。

  • train/
    • 1.jpg
    • 2.jpg
    • ...
  • test/
    • 1.jpg
    • 2.jpg
    • ...
  • val/
    • 1.jpg
    • 2.jpg
    • ...

割り振りが終わったらアップ用にzip化しておきます。(もしくはマウントした領域にそのままアップしてもいいかもしれません)

指定のディレクトリのzip化はshutilモジュールのmake_archive関数を使いました。
base_nameにはzipのファイル名を指定します。

import shutil

shutil.make_archive(
    base_name='pix2pix_milky_way_dataset', format='zip',
    root_dir='<train、testなどのフォルダを含んでいるディレクトリ>')

pix2pixの環境を作る

ここからはColaboratory上の作業となります。
cloneしたり、アップしたファイルを参照してデータセット作成用のスクリプトを流したりしていきます。

※事前に以前の記事でもColaboratory上で動かせることを確認しています。(Google Colaboratoryでpix2pixを動かしてみる。

まずはColaboratoryでノートを開いて、GPUを有効にします。そののちに、用意したアカウントなどでディスクスペースのマウント対応をしておきます。

必要なライブラリなどをインストール。

!pip install parmap==1.5.1
!pip install tqdm==4.17.0
!pip install opencv_python==3.3.0.10
!pip install h5py==2.7.0
!pip install numpy==1.13.3
!pip install pydot==1.2.3
!pip install graphviz==0.8.1
!pip install pydot3==1.0.9
!pip install pydot-ng==1.0.0
!apt-get install graphviz

必要なコードのclone処理やディレクトリ操作。

import os
import shutil
import zipfile
from datetime import datetime

from IPython.display import display
from PIL import Image
from google import colab

if not os.path.exists('/content/workspace/'):
    os.makedirs('/content/workspace/')
os.chdir('/content/workspace/')
!git clone https://github.com/phillipi/pix2pix.git
!git clone https://github.com/tdeboissiere/DeepLearningImplementations.git
!rsync -a /content/workspace/DeepLearningImplementations/pix2pix/ /content/workspace/pix2pix/

用意したデータセットのzipをアップして展開する

  • train、test、valといったフォルダを含んだzipをColaboratory上にアップして、必要なフォルダに展開します。
  • pix2pix/datasets/ 以下に任意の名称で配置します。今回はmilky_wayという名称にしました。以下のようなディレクトリ構成になります。

  • /content/workspace/pix2pix/datasets/milky_way/

    • train/
    • test/
    • val/

※余談ですが、/content がColaboratoryでのルートディレクトリになります。作業中、頻繁に作業ディレクトリをchdir関数で切り替えたりしている都合、想定的な指定よりもルートディレクトリからの絶対パス的に指定するとエラーになりにくく良いな・・と途中で思いルートディレクトリから大部分を指定するようにしています。

  • 今回はGoogleDriveにアップして「共有可能なリンクを取得」し、そのIDを利用してwgetコマンドを使いました。マウントした領域から移してきて展開する形でも問題ありません。
def upload_zip_and_extract_from_google_drive(zip_file_id, extract_dest_dir_path):
    """
    GoogleDrive側で公開URLが設定されてあるzipファイルのアップロードを行い、
    指定のディレクトリにzip内のファイルの展開を行う。

    Notes
    -----
    Colaboratory側で用意されているupload関数が、大きいファイルの場合
    処理時間が相当長くなるので、代替としてwgetコマンドで対応している。

    Parameters
    ----------
    zip_file_id : str
        対象のファイルをGoogleDrive上で公開URLを設定した際に、URLの
        idパラメーターに設定されているハッシュ化された文字列。
    extract_dest_dir_path : str
        zipの内容の展開先のディレクトリ。
    """
    os.system(
        'wget "https://drive.google.com/uc?export=download&id=%s" -O "./temp.zip"' % \
        zip_file_id)
    with zipfile.ZipFile('./temp.zip') as uploaded_zip_file:
        uploaded_zip_file.extractall(extract_dest_dir_path)
    os.remove('./temp.zip')


upload_zip_and_extract_from_google_drive(
    zip_file_id='<対象のGoogleDrive上のIDの文字列>',
    extract_dest_dir_path='/content/workspace/pix2pix/datasets/milky_way/')

コマンドでpix2pixで必要な、HDF5形式のデータセットを用意します。

os.chdir('/content/workspace/pix2pix/src/data/')
!python make_dataset.py /content/workspace/pix2pix/datasets/milky_way/ 3 --img_size 256

※ここで、各フォルダで最低100件は画像が入っていないとエラーになったので、githubのコードをしばらく読んで悩んでいたり・・

学習を動かす

  • とりあえず、200エポックほど学習させます。
    • ※今回は、大体1エポックあたりGPUを使って70秒程度だったようです。
  • 元のコードで、5エポックごとに重みが保存されるようになっています。そのため、最後の保存された重みをマウントした領域に退避させる処理を書いておきます。(学習を始めたあと、寝ても大丈夫なように)
  • MOUNTED_DIR_PATH の値は、マウントした領域で、重みを保存するディレクトリのパスを設定しておいてください。
os.chdir('/content/workspace/pix2pix/src/model/')
def save_last_epoch_weights_to_mounted_dir():
    """
    最後のエポックの重みのデータをマウントされているディレクトリに
    コピーを行う。(12時間経過で消えてしまわないようにするため)

    Notes
    -----
    ある程度、アップが終わるまで時間がかかる。(それまでは、関数
    実行時点ではファイルサイズが0Bと表示されたり、lsコマンドでの
    リスト表示ができないのでアクセスする場合には注意。)
    """
    current_time_str = datetime.now().strftime('%Y%m%d%H%M%S')

    file_name_list = os.listdir('/content/workspace/pix2pix/models/CNN/')
    epoch_num_list = []
    for file_name in file_name_list:
        epoch_num_str = file_name.split('weights_epoch')[1]
        epoch_num_str = epoch_num_str.replace('.h5', '')
        epoch_num_list.append(int(epoch_num_str))
    last_epoch = max(epoch_num_list)
    print('last_epoch :', last_epoch)
    DCGAN_weights_file_name = 'DCGAN_weights_epoch%s.h5' % last_epoch
    disc_weights_file_name = 'disc_weights_epoch%s.h5' % last_epoch
    gen_weights_file_name = 'gen_weights_epoch%s.h5' % last_epoch

    print(datetime.now(), 'DCGAN weights copy started.')
    shutil.copy(
        src='/content/workspace/pix2pix/models/CNN/' + DCGAN_weights_file_name,
        dst=MOUNTED_DIR_PATH + current_time_str + DCGAN_weights_file_name)
    print(datetime.now(), 'discriminator weights copy started.')
    shutil.copy(
        src='/content/workspace/pix2pix/models/CNN/' + disc_weights_file_name,
        dst=MOUNTED_DIR_PATH + current_time_str + disc_weights_file_name)
    print(datetime.now(), 'generator weights copy started.')
    shutil.copy(
        src='/content/workspace/pix2pix/models/CNN/' + gen_weights_file_name,
        dst=MOUNTED_DIR_PATH + current_time_str + gen_weights_file_name)

以下のコマンドで学習がスタートします。

!python main.py 64 64 --backend tensorflow --nb_epoch 200 --dset milky_way

以下のセルも学習が終わったら実行するようにしておきます。

最後の保存されたエポックの保存処理 :

save_last_epoch_weights_to_mounted_dir()

どんな感じなのか、学習後の推論結果を表示(学習画像・推論画像・GROUND TRUTHの表示) :

display(Image.open('/content/workspace/pix2pix/figures/current_batch_validation.png'))

さらに学習を重ねる

  • 前述までの段階で、200エポック分の学習が終わっています。
  • ただ、最終的に表示された画像を見ると結構まだ微妙な印象があったため、追加でさらに500エポック学習させます。(以前pix2pixを動かした際にも、そうやって1200エポックあたりまで学習させたところ、ある程度安定してきたという経験もあります)
  • Colaboratory上で、pix2pixのモデルで重みのデータを読み込んでから学習がスタートするようにするため、githubのコードを参考に、調整を加えながらColaboratoryにコードを持ってきます。(コマンド経由で学習を行わずに、ノート上で動かしていきます。)
  • 前述の200エポック後、一度12時間経過してファイルの削除などがされていると思います。そのため、一度再度pix2pixのコードのcloneやマウント処理、GPUの有効化などをしてください。
    • ※もし12時間経過していない場合、おそらく後述の500エポック分学習させる過程で停止とファイルの削除が走ってしまうと思います。
  • 別のノートを用意して、進めていきます。
import os
import sys
sys.path.append('/content/workspace/pix2pix/src/utils/')
sys.path.append('/content/workspace/pix2pix/src/model/')
import time
import shutil
from datetime import datetime

from PIL import Image
from keras.utils import generic_utils
from keras.optimizers import Adam
import keras.backend as K
import numpy as np
from IPython.display import display
from PIL import Image

import data_utils
import models

MOUNTED_DIR_PATH = '<重みを保存したマウントしたディレクトリ>'

os.chdir('/content/workspace/pix2pix/src/model/')

以下は大体、githubのコードそのままで、コマンドラインを経由しない点などの調整を加えてあります。

X_full_train, X_sketch_train, X_full_val, X_sketch_val = data_utils.load_data(
    dset='milky_way', image_data_format='channels_last')
img_dim = X_full_train.shape[-3:]
nb_patch, img_dim_disc = data_utils.get_nb_patch(
    img_dim=(256, 256, 3), patch_size=(64, 64),
    image_data_format='channels_last')
opt_dcgan = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
opt_discriminator = Adam(lr=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
generator_model = models.load(
    model_name='generator_unet_upsampling',
    img_dim=(256, 256, 3), nb_patch=16, bn_mode=2, use_mbd=False, batch_size=4)
discriminator_model = models.load(
    model_name='DCGAN_discriminator', img_dim=(64, 64, 3), nb_patch=16,
    bn_mode=2, use_mbd=False, batch_size=4)

以下の部分で、前回学習させて退避させていた重みのデータを読み込ませています。(load_weights関数部分)

generator_model.load_weights(
    MOUNTED_DIR_PATH + '20180709105254gen_weights_epoch195.h5')
discriminator_model.load_weights(
    MOUNTED_DIR_PATH + '20180709105254disc_weights_epoch195.h5')
generator_model.compile(loss='mae', optimizer=opt_discriminator)
discriminator_model.trainable = False
DCGAN_model = models.DCGAN(
    generator=generator_model,
    discriminator_model=discriminator_model,
    img_dim=(256, 256, 3), patch_size=(64, 64),
    image_dim_ordering='channels_last')
DCGAN_model.load_weights(
    filepath=MOUNTED_DIR_PATH + '20180709105254DCGAN_weights_epoch195.h5')

def l1_loss(y_true, y_pred):
    return K.sum(K.abs(y_pred - y_true), axis=-1)

loss = [l1_loss, 'binary_crossentropy']
loss_weights = [1E1, 1]

DCGAN_model.compile(
    loss=loss, loss_weights=loss_weights, optimizer=opt_dcgan)
discriminator_model.trainable = True
discriminator_model.compile(
    loss='binary_crossentropy', optimizer=opt_discriminator)

gen_loss = 100
disc_loss = 100

学習エポック数の指定(nb_epoch)。前回、200エポック学習させる際に、5エポックごと(剰余が0のとき)に保存されるようにコードがなっていましたが、最初のエポックも保存される都合、それだと最後のエポックが保存されないじゃないか・・ということで、今回は501と1多くしてあります。ほかの設定などは、コマンドラインから経由したときのデフォルト値などを設定しています。

nb_epoch = 501
batch_size = 4
n_batch_per_epoch = 100
epoch_size = n_batch_per_epoch * batch_size
patch_size = (64, 64)
image_data_format = 'channels_last'
model_name = 'CNN'

学習の重みデータが同じディレクトリに保存されるので、以前のものは削除しておきます。

# 過去の学習データの重みデータを削除しておく。
shutil.rmtree('/content/workspace/pix2pix/models/CNN/',
              ignore_errors=True)
os.makedirs('/content/workspace/pix2pix/models/CNN/')

以下の部分もgithubのコードほぼそのままです。

ただ、それなりに長時間の学習となるため、途中で止まった場合などのことも加味して、5エポックごとではなく50エポックごとに重みを保存するようにし( if e % 50 == 0: の部分)、その時点での重みをマウントした領域へ退避し( save_last_epoch_weights_to_mounted_dir 部分 )、その時点での推論結果の画像をノート上に表示( display(Image.open 部分)するように調整してあります。

for e in range(nb_epoch):
    progbar = generic_utils.Progbar(epoch_size)
    batch_counter = 1
    start = time.time()

    for X_full_batch, X_sketch_batch in \
            data_utils.gen_batch(X_full_train, X_sketch_train, batch_size):

        X_disc, y_disc = data_utils.get_disc_batch(
            X_full_batch,
            X_sketch_batch,
            generator_model,
            batch_counter,
            patch_size,
            image_data_format,
            label_smoothing=0,
            label_flipping=0)
        disc_loss = discriminator_model.train_on_batch(X_disc, y_disc)

        X_gen_target, X_gen = next(
            data_utils.gen_batch(X_full_train, X_sketch_train, batch_size))
        y_gen = np.zeros((X_gen.shape[0], 2), dtype=np.uint8)
        y_gen[:, 1] = 1


        discriminator_model.trainable = False
        gen_loss = DCGAN_model.train_on_batch(X_gen, [X_gen_target, y_gen])
        discriminator_model.trainable = True

        batch_counter += 1
        progbar.add(
            batch_size, 
            values=[('D logloss', disc_loss),
                    ('G tot', gen_loss[0]),
                    ('G L1', gen_loss[1]),
                    ('G logloss', gen_loss[2])])

        if batch_counter % (n_batch_per_epoch / 2) == 0:
            data_utils.plot_generated_batch(
                X_full_batch, X_sketch_batch, generator_model,
                batch_size, image_data_format, 'training')
            X_full_batch, X_sketch_batch = next(data_utils.gen_batch(
                X_full_val, X_sketch_val, batch_size))
            data_utils.plot_generated_batch(
                X_full_batch, X_sketch_batch, generator_model,
                batch_size, image_data_format, 'validation')

        if batch_counter >= n_batch_per_epoch:
            break

    print('')
    print('Epoch %s/%s, Time: %s' % (e + 1, nb_epoch, time.time() - start))

    if e % 50 == 0:
        gen_weights_path = os.path.join(
            '../../models/%s/gen_weights_epoch%s.h5' % (model_name, e))
        generator_model.save_weights(gen_weights_path, overwrite=True)

        disc_weights_path = os.path.join(
            '../../models/%s/disc_weights_epoch%s.h5' % (model_name, e))
        discriminator_model.save_weights(disc_weights_path, overwrite=True)

        DCGAN_weights_path = os.path.join(
            '../../models/%s/DCGAN_weights_epoch%s.h5' % (model_name, e))
        DCGAN_model.save_weights(DCGAN_weights_path, overwrite=True)

        save_last_epoch_weights_to_mounted_dir()
        display(Image.open(
            '/content/workspace/pix2pix/figures/current_batch_validation.png'))

あとは学習終わるまで寝て待ちます。

推論!

まあ・・残念な結果に終わったのは、タイトルや結果を前述しているので、結果が分かっていてあまり面白くないところではあると思いますが、実際に学習が終わった後の重みのデータを参照して、ペイントで作った画像をノート上で食べさせる処理も載せておきます。

推論処理自体は、GPUが有効になっていなくてもすぐに終わります。(重みの読み込みは時間がかかりますが)
以降は別のノートを用意して進めています。

450エポック(+200エポックで、実質650エポック)の重みが500エポックのものより良さそうだったので、そちらを読み込んで使います。

import os
import sys
sys.path.append('/content/workspace/pix2pix/src/model/')
sys.path.append('/content/workspace/pix2pix/src/utils/')

from PIL import Image
import numpy as np

import data_utils
import models

MOUNTED_DIR_PATH = '<重みを保存したマウントしたディレクトリ>'

os.chdir('/content/workspace/pix2pix/src/model/')
generator_model = models.load(
    model_name='generator_unet_upsampling',
    img_dim=(256, 256, 3), nb_patch=16, bn_mode=2, use_mbd=False, batch_size=4)

discriminator_model = models.load(
    model_name='DCGAN_discriminator', img_dim=(64, 64, 3), nb_patch=16,
    bn_mode=2, use_mbd=False, batch_size=4)

generator_model.load_weights(
    MOUNTED_DIR_PATH + '20180710061705gen_weights_epoch450.h5')

discriminator_model.load_weights(
    MOUNTED_DIR_PATH + '20180710061705disc_weights_epoch450.h5')

DCGAN_model = models.DCGAN(
    generator=generator_model,
    discriminator_model=discriminator_model,
    img_dim=(256, 256, 3), patch_size=[64, 64],
    image_dim_ordering='channels_last')

DCGAN_model.load_weights(
    filepath=MOUNTED_DIR_PATH + '20180710061705DCGAN_weights_epoch450.h5')

ペイントでちまちま画像を作って、マウントされているところにアップしておきます。今回は 推論用入力画像/20180709_1.jpg といった具合に配置していったという前提で進めます。

def get_predicted_img(input_jpg_path):
    """
    対象のペイントで作成した画像から、天の川の推論後の
    画像を取得する。

    Parameters
    ----------
    input_jpg_path : str
        入力用のペイントで作成したjpgのパス。

    Returns
    -------
    predicted_img : Image
        推論された天の川画像。
    """
    img = Image.open(input_jpg_path)
    display(img)
    img_arr = np.array(img)
    normalized_img_arr = data_utils.normalization(X=img_arr)
    normalized_img_tensor = normalized_img_arr[np.newaxis, :, :, :]
    predicted_img_tensor = DCGAN_model.predict(x=normalized_img_tensor)
    predicted_img_arr = predicted_img_tensor[0][0]
    inversed_predicted_arr = data_utils.inverse_normalization(
        X=predicted_img_arr)
    inversed_predicted_arr *= 255
    inversed_predicted_arr = inversed_predicted_arr.astype(np.uint8)
    predicted_img = Image.fromarray(inversed_predicted_arr)
    img.close()
    return predicted_img

以下のように実行すると、スクショのように上に入力画像、下に推論結果の画像が表示されます。

get_predicted_img(
    input_jpg_path=MOUNTED_DIR_PATH + '推論用入力画像/20180709_2.jpg')

image.png

感想

  • 結果は残念でしたが、色々Colaboratoryのことを知れたり、pix2pixのコードを読んだりして勉強になり、そして楽しめました。
  • もっと複雑な画像で、クオリティが要求される条件でうまくいったことがあるので、単純にデータセットの件数と精査をあまりしていない点に依存するのだろうなぁ・・と思います。(いつか、個人では大変ですがしっかりとデータセット周りを対応した記事を上げたい)
  • タブレットで本格的なディープラーニングを動かせられて、Colaboratory楽しい!