35
41

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Chainer Trainerを使って独自データセットの学習をしてみた

Posted at

Chainer1.11.0がリリースされ、Trainerという学習ループを抽象化する機能が追加されたようなので、独自のAV女優顔画像データセットを使って学習をしてみます。

顔画像の抽出やデータ拡張については、Qiita - chainerによるディープラーニングでAV女優の類似画像検索サービスをつくったノウハウを公開するを参照してください。
元記事ではnumpy形式に変換していますが、今回は学習時にディレクトリから直接画像を読み込むため、numpy形式に変換しません。

ここで使う顔画像は各女優につき1000枚ずつ画像があり、64×64のサイズにリサイズし、以下構成のディレクトリに分けられているものとします。

./root
    |
    |--- /actress1
    |        |--- image1.jpg
    |        |--- image2.jpg
    |        |--- image3.jpg
    |
    |--- /actress2
    |        .
    |        .
    |--- /actress3
    .
    .
    .

データを学習用と検証用に分ける

まず、顔画像データを学習用と検証用に分割します。学習の際のデータ読み込み時に学習用と検証用に分割しながら学習することもできますが、どんなデータが学習で使われていてどんなデータが検証で使われているのか分かりづらいため、事前に分割しておきます。

#!/usr/bin/env python
#-*- coding:utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import glob
import logging
import os
import random
import shutil

def separate_train_val(args):
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    if not os.path.exists(os.path.join(args.output_dir, 'train')):
        os.mkdir(os.path.join(args.output_dir, 'train'))

    if not os.path.exists(os.path.join(args.output_dir, 'val')):
        os.mkdir(os.path.join(args.output_dir, 'val'))

    directories = os.listdir(args.root)

    for dir_index, dir_name in enumerate(directories):
        files = glob.glob(os.path.join(args.root, dir_name, '*.jpg'))
        random.shuffle(files)
        if len(files) == 0: continue

        for file_index, file_path in enumerate(files):
            if file_index % args.val_freq != 0:
                target_dir = os.path.join(args.output_dir, 'train', dir_name)
                if not os.path.exists(target_dir):
                    os.mkdir(target_dir)
                shutil.copy(file_path, target_dir)
                logging.info('Copied {} => {}'.format(file_path, target_dir))
            else:
                target_dir = os.path.join(args.output_dir, 'val', dir_name)
                if not os.path.exists(target_dir):
                    os.mkdir(target_dir)
                shutil.copy(file_path, target_dir)
                logging.info('Copied {} => {}'.format(file_path, target_dir))

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')

    parser = argparse.ArgumentParser(description='converter')
    parser.add_argument('--root', default='.')
    parser.add_argument('--output_dir', default='.')
    parser.add_argument('--val_freq', type=int, default=10)
    args = parser.parse_args()

    separate_train_val(args)

分割したディレクトリは以下のような構成になります。

./train_val_root
    |
    |--- /train
    |       |--- actress1
    |       |       |--- image1.jpg
    |       |       |--- image2.jpg
    |       |       |--- image3.jpg
    |       |              ・
    |       |              ・
    |       |--- actress2
    |       |      ・
    |     |    ・         
    |
    |--- /val
    |       |--- actress1
    |       |
    |       |--- actress2
    .
    .

ディレクトリからデータを読み込むためのDatasetを作成

chainer.dataset.DatasetMixinを継承して指定のディレクトリからデータを読み込むためのクラスを定義します。認識の際に使用するクラス(0~9の数字)とラベル(ディレクトリ名)を出力するメソッドを定義(create_label_file)していますが、これは気持ち悪いので真似しないで下さい。

class DatasetFromDirectory(chainer.dataset.DatasetMixin):

    def __init__(self, root='.', label_out='', dtype=np.float32, label_dtype=np.int32):
        directories = os.listdir(root)
        label_table = []
        pairs = [] # tuple (filepath, label) list
        for dir_index, dir_name in enumerate(directories):
            label_table.append((dir_index, dir_name))
            file_paths = glob.glob(os.path.join(root, dir_name, '*.jpg'))
            for file_path in file_paths:
                pairs.append((file_path, dir_index))

        self._pairs = pairs
        self._root = root
        self._label_out = label_out
        self._label_table = label_table
        self._dtype = dtype
        self._label_dtype = label_dtype

        if label_out != '':
            self.create_label_file()

    def __len__(self):
        return len(self._pairs)

    def get_example(self, i):
        path, int_label = self._pairs[i]
        with Image.open(path) as f:
            image = np.asarray(f, dtype=self._dtype)
        image = image.transpose(2, 0, 1)
        label = np.array(int_label, dtype=self._label_dtype)
        return image, label

    def create_label_file(self):
        with open(self._label_out, "w") as f:
            for (label_index, label_name) in self._label_table:
                f.write('{},{}\n'.format(label_index, label_name))

公式のimagenetのサンプルを見ると、作成したデータセットクラスをベースにして学習時にデータを加工することもできます。学習時にランダムに画像を少し回転させたり、画像を少しずらしたりすることで、完全に同じデータから学習することが少なくなるため、汎化性能の向上が期待できます。

Trainerで独自データセットの学習

実際に用意したデータセットの学習をします。Chainer Trainerを使った実装にすることで、元コードの半分程度の量で実装することができます。

class CNN(chainer.Chain):
    """
    CNN (CCPCCPCP)
    """
    def __init__(self, n_classes):
        super(CNN, self).__init__(
            conv1_1=L.Convolution2D(3, 32, 3, pad=1),
            bn1_1=L.BatchNormalization(32),
            conv1_2=L.Convolution2D(32, 32, 3, pad=1),
            bn1_2=L.BatchNormalization(32),

            conv2_1=L.Convolution2D(32, 64, 3, pad=1),
            bn2_1=L.BatchNormalization(64),
            conv2_2=L.Convolution2D(64, 64, 3, pad=1),
            bn2_2=L.BatchNormalization(64),

            conv3_1=L.Convolution2D(64, 128, 3, pad=1),
            bn3_1=L.BatchNormalization(128),

            fc4=L.Linear(8192, 1024),
            fc5=L.Linear(1024, n_classes),
        )
        self.train = True

    def __call__(self, x, t):
        h = F.relu(self.bn1_1(self.conv1_1(x), test=not self.train))
        h = F.relu(self.bn1_2(self.conv1_2(h), test=not self.train))
        h = F.max_pooling_2d(h, 2, 2)

        h = F.relu(self.bn2_1(self.conv2_1(h), test=not self.train))
        h = F.relu(self.bn2_2(self.conv2_2(h), test=not self.train))
        h = F.max_pooling_2d(h, 2, 2)

        h = F.relu(self.bn3_1(self.conv3_1(h), test=not self.train))
        h = F.max_pooling_2d(h, 2, 2)

        h = F.dropout(F.relu(self.fc4(h)), ratio=0.3, train=self.train)
        h = self.fc5(h)

        loss = F.softmax_cross_entropy(h, t)
        chainer.report({'loss': loss, 'accuracy': F.accuracy(h, t)}, self)
        return loss
model = CNN(10)
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

mean = np.load(args.mean)
train_data = datasets.DatasetFromDirectory(args.train_root, label_out=label_file)
val_data = datasets.DatasetFromDirectory(args.val_root)

train_iter = chainer.iterators.SerialIterator(train_data, args.batch_size)
val_iter = chainer.iterators.SerialIterator(val_data, args.batch_size, repeat=False, shuffle=False)

# Set up a trainer
updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
trainer = training.Trainer(updater, (args.n_epoch, 'epoch'), out=args.output_dir)

snapshot_interval = (args.snapshot_interval, 'iteration')

# Copy the chain with shared parameters to flip 'train' flag only in test
eval_model = model.copy()
eval_model.train = False

trainer.extend(extensions.Evaluator(val_iter, eval_model, device=args.gpu))
trainer.extend(extensions.dump_graph('main/loss'))
trainer.extend(extensions.snapshot(), trigger=snapshot_interval)
trainer.extend(extensions.snapshot_object(
    model, 'model_iter_{.updater.iteration}'), trigger=snapshot_interval)
trainer.extend(extensions.snapshot_object(
    optimizer, 'optimizer_iter_{.updater.iteration}'), trigger=snapshot_interval)
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss',
     'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar(update_interval=10))

if args.resume:
    if not os.path.exists(args.resume):
        raise IOError('Resume file is not exists.')
    logging.info('Load optimizer state from {}'.format(args.resume))
    chainer.serializers.load_npz(args.resume, trainer)

trainer.run()

# Save the trained model
chainer.serializers.save_npz(os.path.join(args.output_dir, 'model_final'), model)
chainer.serializers.save_npz(os.path.join(args.output_dir, 'optimizer_final'), optimizer)

print()
logging.info('Saved the model and the optimizer')
logging.info('Training is finished!')

extensions.snapshot()で保存されるオブジェクトはtrainer用のもののため、実際に予測する際に読み込むmodeloptimizerextensions.snapshot_object()で別途保存する必要があります。

まとめ

Chainer Trainerを使い独自のデータセットの学習をしてみました。
Trainerを使った感想としては、予想通りKerasに近いという印象です。初めてChainerを使ってみたときはミニバッチ毎に読み込む箇所で手間がかかってしまった記憶があるため、そのような部分が抽象化されているTrainerはわかりやすい実装だと感じました。

しかし、KerasではImageDataGeneratorクラスのflow_from_directoryを使うことでDatasetクラスの実装をせずにディレクトリからデータを読み込むこともできるため、より簡単に作ることもできます。

最後に宣伝になりますが、CNNを使ってAV女優の類似画像検索をしたサイトを作っているので、よかったら見てみてください。

Babelink - 類似AV女優検索サービス
※アダルトサイトのため、閲覧には十分注意をしてください。

35
41
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
35
41

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?