LoginSignup
5
6

More than 3 years have passed since last update.

AIに猫ちゃんを生成させてほっこりするぞ!~猫といく潜在変数空間の旅~

Last updated at Posted at 2020-12-01

さいしょに

この記事は顔学2020アドベントカレンダーの4日目の記事です.
顔学といっても顔は人だけのものではありません.今回は猫ちゃんを中心に人以外の顔の生成について,画像生成の最先端技術を絡めてお話します.

StyleGAN2の学習済モデル

NVIDIAのStyleGAN2は学習済モデルが豊富です.今までの投稿でも何度も出現しているStyleGAN2ですが,まだ人の顔の生成モデルしか紹介していなかったので今回は別のモデルを試してみた結果を紹介します.
stylegan2-teaser-1024x256.png
(引用: https://github.com/NVlabs/stylegan2)

過去のStyleGAN2関連の記事

FFHQデータセット

高解像度の多種多様な人間の顔を含むデータセットで,標準で語られるStyleGAN2のモデルはFFHQで学習されています.

ffhq-teaser.png
(引用: https://github.com/NVlabs/ffhq-dataset)

LSUNデータセット

LSUNデータセットはシーン認識用に作成された大規模な画像データセットです.この中のカテゴリに猫ちゃんや馬,教会の画像などが含まれています.

teaser_web.jpg
(引用: https://www.yf.io/p/lsun)

コード

今回はLSUNの人以外の画像で学習されたモデルを使う方法と,実行結果を紹介します.

この記事で紹介するコードは,StyleGAN2 Google Colab Exampleのコメントを日本語に直して解説を少し加えたものです.

動作環境

動作環境はGoogle Colaboratoryです.

ライブラリのインポート

必要なライブラリをインポートします.Colabではデフォルトでtensorflow2.x系がインストールされているので1.x系を利用することを明示的に示します.

%tensorflow_version 1.x
import tensorflow as tf
import argparse
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import re
import sys
from io import BytesIO
import IPython.display
import numpy as np
from math import ceil
from PIL import Image, ImageDraw
import imageio
import pretrained_networks

事前準備

まず,StyleGAN2のNVIDIAの公式実装をGitHubからクローンし,動作環境がそろっていることを確認します.

> git clone https://github.com/NVlabs/stylegan2.git
> cd stylegan2
> nvcc test_nvcc.cu -o test_nvcc -run
> !nvidia-smi -L

学習済モデルのロード

利用したい学習済モデルのパスを指定します.ローカルのあるモデルを指定したい場合はパスを相対パスが必要です.

利用可能なモデルは以下の通りです.基本的にはconfigの設定はfを選択するようにしてください.aはただのStyleGANです.

  • 1024×1024 pixelの人の顔

    • stylegan2-ffhq-config-a.pkl
    • stylegan2-ffhq-config-b.pkl
    • stylegan2-ffhq-config-c.pkl
    • stylegan2-ffhq-config-d.pkl
    • stylegan2-ffhq-config-e.pkl
    • stylegan2-ffhq-config-f.pkl
  • 512×384 pixelの車

    • stylegan2-car-config-a.pkl
    • stylegan2-car-config-b.pkl
    • stylegan2-car-config-c.pkl
    • stylegan2-car-config-d.pkl
    • stylegan2-car-config-e.pkl
    • stylegan2-car-config-f.pkl
  • 256x256 pixelの馬

    • stylegan2-horse-config-a.pkl
    • stylegan2-horse-config-f.pkl
  • 256x256 pixelの教会

    • stylegan2-church-config-a.pkl
    • stylegan2-church-config-f.pkl
  • 256x256 pixelの猫ちゃん

    • stylegan2-cat-config-f.pkl
    • stylegan2-cat-config-a.pkl
network_pkl = "gdrive:networks/stylegan2-ffhq-config-f.pkl"
_G, _D, Gs = pretrained_networks.load_networks(network_pkl) # Gsにモデルの全容が入っている
noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')] 

dnnlibを扱う上で便利な関数

NVIDIAの実装にはdnnlibという独自のライブラリが含まれています.こちらはドキュメントなども公開されておらず,理解するのにかなり時間がかかるのでこれを意識しないためにも関数でラップしてあげます.

幸いにもExampleにはラップ済の関数が用意されているのでこれを利用させてもらいます.

# dnnlibを利用する上で便利な関数

# W空間の潜在変数のリストから画像を生成する
def generate_images_in_w_space(dlatents, truncation_psi):
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False
    Gs_kwargs.truncation_psi = truncation_psi
    dlatent_avg = Gs.get_var('dlatent_avg') # [component]

    imgs = []
    for row, dlatent in log_progress(enumerate(dlatents), name = "Generating images"):
        #row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(truncation_psi, [-1, 1, 1]) + dlatent_avg
        dl = (dlatent-dlatent_avg)*truncation_psi   + dlatent_avg
        row_images = Gs.components.synthesis.run(dlatent,  **Gs_kwargs)
        imgs.append(PIL.Image.fromarray(row_images[0], 'RGB'))
    return imgs       

# Z空間の潜在変数のリストから画像を生成する
def generate_images(zs, truncation_psi):
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False
    if not isinstance(truncation_psi, list):
        truncation_psi = [truncation_psi] * len(zs)

    imgs = []
    for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"):
        Gs_kwargs.truncation_psi = truncation_psi[z_idx]
        noise_rnd = np.random.RandomState(1) # fix noise
        tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
        images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
        imgs.append(PIL.Image.fromarray(images[0], 'RGB'))
    return imgs

# シード値からZ空間の潜在変数を作る
def generate_zs_from_seeds(seeds):
    zs = []
    for seed_idx, seed in enumerate(seeds):
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
        zs.append(z)
    return zs

# シード値のリストから画像を生成する
def generate_images_from_seeds(seeds, truncation_psi):
    return generate_images(generate_zs_from_seeds(seeds), truncation_psi)

# 生成した画像を保存する
def saveImgs(imgs, location):
  for idx, img in log_progress(enumerate(imgs), size = len(imgs), name="Saving images"):
    file = location+ str(idx) + ".png"
    img.save(file)

# 生成した画像を出力する
def imshow(a, format='png', jpeg_fallback=True):
  a = np.asarray(a, dtype=np.uint8)
  str_file = BytesIO()
  PIL.Image.fromarray(a).save(str_file, format)
  im_data = str_file.getvalue()
  try:
    disp = IPython.display.display(IPython.display.Image(im_data))
  except IOError:
    if jpeg_fallback and format != 'jpeg':
      print ('Warning: image was too large to display in format "{}"; '
             'trying jpeg instead.').format(format)
      return imshow(a, format='jpeg')
    else:
      raise
  return disp

def showarray(a, fmt='png'):
    a = np.uint8(a)
    f = StringIO()
    PIL.Image.fromarray(a).save(f, fmt)
    IPython.display.display(IPython.display.Image(data=f.getvalue()))


def clamp(x, minimum, maximum):
    return max(minimum, min(x, maximum))

def drawLatent(image,latents,x,y,x2,y2, color=(255,0,0,100)):
  buffer = PIL.Image.new('RGBA', image.size, (0,0,0,0))

  draw = ImageDraw.Draw(buffer)
  cy = (y+y2)/2
  draw.rectangle([x,y,x2,y2],fill=(255,255,255,180), outline=(0,0,0,180))
  for i in range(len(latents)):
    mx = x + (x2-x)*(float(i)/len(latents))
    h = (y2-y)*latents[i]*0.1
    h = clamp(h,cy-y2,y2-cy)
    draw.line((mx,cy,mx,cy+h),fill=color)
  return PIL.Image.alpha_composite(image,buffer)


def createImageGrid(images, scale=0.25, rows=1):
   w,h = images[0].size
   w = int(w*scale)
   h = int(h*scale)
   height = rows*h
   cols = ceil(len(images) / rows)
   width = cols*w
   canvas = PIL.Image.new('RGBA', (width,height), 'white')
   for i,img in enumerate(images):
     img = img.resize((w,h), PIL.Image.ANTIALIAS)
     canvas.paste(img, (w*(i % cols), h*(i // cols))) 
   return canvas

# Z空間の潜在変数をW空間にマッピングネットワークを利用して射影する
def convertZtoW(latent, truncation_psi=0.7, truncation_cutoff=9):
  dlatent = Gs.components.mapping.run(latent, None) # [seed, layer, component]
  dlatent_avg = Gs.get_var('dlatent_avg') # [component]
  for i in range(truncation_cutoff):
    dlatent[0][i] = (dlatent[0][i]-dlatent_avg)*truncation_psi + dlatent_avg

  return dlatent

# 与えられた潜在変数をステップ数で線形補間する
def interpolate(zs, steps):
   out = []
   for i in range(len(zs)-1):
    for index in range(steps):
     fraction = index/float(steps) 
     out.append(zs[i+1]*fraction + zs[i]*(1-fraction))
   return out

# https://github.com/alexanderkuk/log-progress より進捗可視化用機能を拝借している
def log_progress(sequence, every=1, size=None, name='Items'):
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            is_iterator = True
    if size is not None:
        if every is None:
            if size <= 200:
                every = 1
            else:
                every = int(size / 200)     # 0.5%ごと
    else:
        assert every is not None, 'sequence is iterator, set every'

    if is_iterator:
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = '{name}: {index} / ?'.format(
                        name=name,
                        index=index
                    )
                else:
                    progress.value = index
                    label.value = u'{name}: {index} / {size}'.format(
                        name=name,
                        index=index,
                        size=size
                    )
            yield record
    except:
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(
            name=name,
            index=str(index or '?')
        )

シード値から画像の生成

# シードを生成する
seeds = np.random.randint(9999320, size=8)
# シード値から潜在変数に変換して画像を生成
imshow(createImageGrid(generate_images_from_seeds(seeds, 0.7), 0.7 , 2))

vats.png

Z空間上で線形補間

# Z潜在空間で線形補間
zs = generate_zs_from_seeds([5015289 , 9148088])
number_of_steps = 5

imgs = generate_images(interpolate(zs,number_of_steps), 1.0)
imshow(createImageGrid(imgs, 0.4 , 1))

cats_1.png

W空間上で線形補完

本来はZ空間の潜在変数をマッピングネットワークに通してW空間に射影して画像を生成するため,W空間のほうが次元数も高く元データの潜在変数の偏りに強い滑らかな空間になっています.なので,W空間で直接線形補完を行うほうが画像の変化もより滑らかで連続的になります.

# Z空間の潜在変数をW空間に射影してから線形補完(より滑らかな空間で補完可能)
zs = generate_zs_from_seeds([5015289 , 9148088])

dls = []
for z in zs:
  dls.append(convertZtoW(z ,truncation_psi=1.0))

number_of_steps = 5

imgs = generate_images_in_w_space(interpolate(dls,number_of_steps), 1.0)
imshow(createImageGrid(imgs, 0.4 , 1))

cats_2.png

W空間で猫と旅をする

最後に,複数の潜在変数間を線形補間して動画を作成してみます.
なんか目が回って宇宙猫みも感じますね.

# Z空間の潜在変数をW空間に射影してから線形補間(より滑らかな空間で補間可能)
zs = generate_zs_from_seeds([42165,6149575,3487643,3766864 ,3857159,5360657,3720613 ])

dls = []
for z in zs:
  dls.append(convertZtoW(z ,truncation_psi=1.0))

number_of_steps = 10

imgs = generate_images_in_w_space(interpolate(dls,number_of_steps), 1.0)
%mkdir out
movieName = 'out/mov.mp4'

with imageio.get_writer(movieName, mode='I') as writer:
    for image in log_progress(list(imgs), name = "Creating animation"):
        writer.append_data(np.array(image))

ezgif.com-gif-maker.gif

さいごに

今日は少しテイストを変えて,猫の顔生成について技術的側面を多めで書いてみました.何か質問があればどしどしお願いします!

ちなみに馬や教会の生成結果はこちらです.さらにちなみに修論は進んでません(やばい).

ezgif.com-gif-maker (2).gif
ezgif.com-gif-maker (1).gif

今日使ったコードはGoogle Colabで公開しています.
こちらからいろんなシード値で猫を生成してみたり,馬なども試してみてください.

参考

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6