LoginSignup
4
1

More than 1 year has passed since last update.

MMGenerationを利用したGANの学習(cycle GAN)

Last updated at Posted at 2022-07-25

はじめに

前回MMGenerationの事前学習済みモデルを利用して、サンプルコードを動かしてみました。

今回は、CyacleGANを使って実際のトレーニングをしてみようと思います。

CyacleGANとは?という方は、こちらを参照してみてみください。

コードはこちらにおいています。
image.png
をクリックして、google colaboratoryで実行することができます。

MMGenerationを環境構築

まず、前回同様にMMGenerationの利用環境を構築します。
前回の記事と同じ内容です。)

使用するモジュールのインポート

import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
print(torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device = ", device)

MMCVのインストール

# MMCVのインストール
!pip install -U openmim
!mim install mmcv-full

MMGenerationのインストール

!git clone https://github.com/open-mmlab/mmgeneration.git
%cd /content/mmgeneration
!pip install -v -e .  # or "python setup.py develop"

コンフィグファイルの作成

MMGenerationでは、用意されたコンフィグファイルを使うことでモデルの切り替えを行います。
また、コンフィグファイルの内容を一部編集することで利用する用途に応じた学習設定を行うことができます。

コンフィグファイルの読み込み

今回は、「cyclegan_lsgan_id0_resnet_in_facades_b1x1_80k.py」というコンフィグファイルを利用します。
どのような用途を想定しているかはわかりませんが、他のものが「winter to summer」、「horse to zebra」だったので、オリジナルの学習を行う場合にはこれが一番良さそうだと思いました。


from mmcv import Config
cfg = Config.fromfile('./configs/cyclegan/cyclegan_lsgan_id0_resnet_in_facades_b1x1_80k.py')

学習用のデータについて

CycleGANでは、変換を行いたい2種類の画像を使って学習を行うことになりますが、コンフィグファイルを確認しても、2種類の画像のフォルダパスを直接指定することは出来なさそうでした。
取りあえずは、サンプルのテストデータと同じフォルダ構成にする必要がありそうです。

画像のルートフォルダ
|
| - trainA (グループAの学習用の画像フォルダ)
|
| - trainB (グループBの学習用の画像フォルダ)
|
| - testA  (グループAのテスト用の画像フォルダ)
|
| - testB  (グループBのテスト用の画像フォルダ)

コンフィグファイルの編集

コンフィグファイルにて学習用フォルダのルートパスを指定します。
(コンフィグで指定されているフォルダ「./data/unpaired_facades」にデータをおいても構いません。)

# データのパス
cfg.data.train.dataroot = #フォルダのルートパス
cfg.data.test.dataroot = #フォルダのルートパス
cfg.data.val.dataroot = #フォルダのルートパス
cfg.gpu_ids = range(0, 1)
cfg.seed = 123

学習の実施

学習方法については、ドキュメントにほぼ記載されていなかったのですが、「tools.train.py」の内容を参考に実施してみました。

必要モジュールのインポート

import argparse
import copy
import multiprocessing as mp
import os
import os.path as osp
import platform
import time
import warnings

import cv2
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash

from mmgen import __version__
from mmgen.apis import set_random_seed, train_model
from mmgen.datasets import build_dataset
from mmgen.models import build_model
from mmgen.utils import collect_env, get_root_logger

学習の実施

model = build_model(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

datasets = [build_dataset(cfg.data.train)]

timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())

meta = dict()
# log env info
env_info_dict = collect_env()
env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
dash_line = '-' * 60 + '\n'

meta['env_info'] = env_info
meta['config'] = cfg.pretty_text

train_model(
    model,
    datasets,
    cfg,
    distributed=False,
    timestamp=timestamp,
    meta=meta)

上記の内容で学習を実施することができました。

学習結果の確認

学習方法がよくなかったせいか、MMGenerationで用意している関数「sample_img2img_model」では、うまくモデルを呼び出すことができませんでした。
(コンフィグファイルがモデルに格納できていないようです。)

そのため、以下の内容で少しカスタムした関数を作りました。

sample_img2img_model2(モデル, 変換する画像のパス, target_domain=ターゲットのスタイル, **kwargs)

モデル:学習したモデル
変換する画像のパス:変換する画像のパス
target_domain;tarinAの画像とtrainBの画像のどちらのスタイルに変換するかを指定します。今回はtrainAは「mask」、trainB「photo」になります。

cfg2 = cfg

from mmgen.datasets.pipelines import Compose
from mmgen.models import BaseTranslationModel

from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmcv.utils import is_list_of

def sample_img2img_model2(model, image_path, target_domain=None, **kwargs):
    """Sampling from translation models.

    Args:
        model (nn.Module): The loaded model.
        image_path (str): File path of input image.
        style (str): Target style of output image.
    Returns:
        Tensor: Translated image tensor.
    """
    assert isinstance(model, BaseTranslationModel)

    # get source domain and target domain
    if target_domain is None:
        target_domain = model._default_domain
    source_domain = model.get_other_domains(target_domain)[0]

    #cfg = model._cfg
    cfg = cfg2
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = Compose(cfg.test_pipeline)

    # prepare data
    data = dict()
    # dirty code to deal with test data pipeline
    data['pair_path'] = image_path
    data[f'img_{source_domain}_path'] = image_path
    data[f'img_{target_domain}_path'] = image_path

    data = test_pipeline(data)
    if device.type == 'cpu':
        data = collate([data], samples_per_gpu=1)
        data['meta'] = []
    else:
        data = scatter(collate([data], samples_per_gpu=1), [device])[0]

    source_image = data[f'img_{source_domain}']
    # forward the model
    with torch.no_grad():
        results = model(
            source_image,
            test_mode=True,
            target_domain=target_domain,
            **kwargs)
    output = results['target']
    return output

kaggleのモネ風の画像作成コンペ(こちら)にて、学習したモデルにて手持ちの写真からモネ風の画像を生成してみました。

test_folder = #画像フォルダ

test_images = glob(test_folder + "/*.jpg")

m = len(test_images)

plt.figure(figsize=(24,48))

for i,image_path in enumerate(test_images):
    original_image = cv2.imread(image_path)
    original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
    plt.subplot(m,2,(i * 2) + 1)
    plt.imshow(original_image)

    translated_image = sample_img2img_model2(model, image_path, target_domain='mask')
    translate_image = translated_image.cpu().numpy()[0]
    translate_image = translate_image.transpose(1,2,0)
    plt.subplot(m,2,(i * 2) + 2)
    plt.imshow(translate_image)

plt.show()

モネ風とまで言えるかはわかりませんが、油絵風の画像は生成できているかと思います。
ダウンロード (1)のコピー.jpg

まとめ

MMGenerationを使って、CycleGANを使った学習まで実施することができました。
MMGenerationのドキュメントはまだあまり整備されていないようで、ドキュメントにも利用方法はなかなか見つからないため手探りになってしまう感はあります。

また、実際に試してみたい方は、記事でもご紹介しましたがkaggleのこちらのコンペなどで試してみると良いかと思います。

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1