LoginSignup
21
27

More than 5 years have passed since last update.

Chainerで画像分類するためのデータセットを自作して読み込む方法

Posted at

概要

Chainer入門者向けの記事です。チュートリアルから一歩進んで、世に公開されている画像データや自分で集めた画像などから画像分類の学習用データセットを作る方法をまとめました。

いざ自分でChainerで画像分類に取り掛かろうとした際に、最初に躓いたところであり、調べるのにも時間が掛かったので、備忘録のようなものになります。

環境

  • macOS High Sierra 10.13.6
  • python 3.6.5
  • Chainer 4.3.1

実行例

DeepLearningLabさんがGitHubに以前ハッカソンにて使用していた画像データを公開してくれていますので、そのデータをサンプルにしてChainer用のデータセットを作成していきます。データセットはこちら。

かなりのデータ量なので試す方はご注意ください。とは言っても3GBちょいですけどね。

フォルダ構成は以下のとおりです。

data
|-- train
|   |-- images
|   |   |-- XXXXXXXXX.jpg
|   |   |-- XXXXXXXXX.jpg
|   |   `-- ...
|   `-- train_labels.txt
|-- valid
|   |-- images
|   |   |-- XXXXXXXXX.jpg
|   |   `-- ...
|   `-- valid_labels.txt
`-- test
    |-- images
    |   |-- XXXXXXXXX.jpg
    |   `-- ...
    `-- test_labels.txt

このデータでは、それぞれのimagesフォルダの中に画像が入っており、同階層の.txtファイルに画像のファイル名と、対応するクラスIDがスペース区切りで入っています。

イメージはこんな感じです。

XXXXXXXXX01.jpg 0
XXXXXXXXX02.jpg 1
XXXXXXXXX03.jpg 2

では、これを読み込むにはどうすれば良いのでしょうか。

Chainerには非常に便利なLabeledImageDatasetというクラスがあり、これを使用します。コードはこの2行だけ。とっても簡単ですね。

from chainer.datasets import LabeledImageDataset

train = LabeledImageDataset('data/train/train_labels.txt', 'data/train/images')

ただし、これだけだと元々の画像のサイズがバラバラだとChainerにはそのまま突っ込めません。画像をリサイズするひと手間が必要になります。こちらもTransformDatasetクラスを使えば簡単にリサイズ出来ちゃいます。

from chainercv.transforms import resize
from chainer.datasets import TransformDataset

def transform(in_data):
    img, label = in_data
    img = resize(img, (224, 224))
    return img, label

train = TransformDataset(train, transform)

トレーニングデータと同様にバリデーションデータも作成します。

valid = LabeledImageDataset('data/valid/valid_labels.txt', 'data/valid/images')
valid = TransformDataset(valid, transform)

これで最低限の前処理は完了です。

実際にはここでグレースケールに変換したり、傾けてみたりなど色々と工夫をすることも可能ですが、それはまたの機会にまとめてみます。

ご参考

ChainerCVを使用しない方法

上記では画像のリサイズにChainerCVを使用しましたが、他にも方法はあります。代表的なものとしてはPillowライブラリを使用したものがあります。

以下のようにtransformクラスを書き換えます。

from PIL import Image

width, height = 224, 224 #ここは好きなサイズで構いません。

def transform(data):
    img, label = data
    img = img.astype(np.uint8)
    img = Image.fromarray(img.transpose(1, 2, 0))
    img = img.resize((width, height))
    img = np.asarray(img).transpose(2, 0, 1).astype(np.float32) / 255.
    return img, label

OpenCVを使用する方法

OpenCVを使い慣れていて、前処理もOpenCVでしてしまいたい方は、以下のようにすれば良いです。

import cv2

with open('data/train/train_labels.txt', 'r', encoding='utf-8') as f:
    train_labels = ''.join(f.readlines()).split('\n')[:-1]

x, t = [], []

for train_label in train_labels:
    filename, label = train_label.split(' ')
    filepath = 'data/train/images/{}'.format(filename)
    img = cv2.imread(filepath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.transpose(img, (2, 0, 1))

    label = np.array(label, 'i')

    x.append(img)
    t.append(label)

train = chainer.datasets.tuple_dataset.TupleDataset(x, t)

終わりに

普段コードを書き続けていないと、こういう読み込みなどのお作法はすぐに忘れてしまう。ライブラリが充実してきて、この辺りの処理が簡略化されていくことは大歓迎です。

21
27
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
27