4
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ImageDataGeneratorを使ってみた

Last updated at Posted at 2021-09-13

はじめに

最近、深層学習の実装ではPytorchに浮気している自分ですが、TensorFlowの中のImageDataGeneratorについて改めて勉強したので、その記録です。

使い方がメインなので、モデル構築や学習は行いません。

環境

Anacondaの仮想環境を使用。

python 3.9.5
Tensorflow 2.5.0

使うデータ

Kaggle DatasetのIntel Image Classificationを使います。

{'buildings' -> 0,
'forest' -> 1,
'glacier' -> 2,
'mountain' -> 3,
'sea' -> 4,
'street' -> 5 }

の対応になっている6クラス分類です。

基本的な使い方

データのディレクトリ構造をとりあえず以下のように設定します。

data/
├ seg_train/
│   ├ buildings/
│   ├ forest/
│   ├ glacier/
│     省略
├ seg_test/
│   ├ buildings/
│   ├ forest/
│   ├ glacier/
│     省略
└ seg_pred/
   ├ 10004.jpg
   ├ 10005.jpg
    省略

seg_trainとseg_testはサブディレクトリがクラスごとに分かれていて、その中に画像が
seg_predは予測してほしい画像なので、クラス関係なくごっちゃになって画像がディレクトリ直下に配置されています。

from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_dir = './data/seg_train'
valid_dir = './data/seg_test'
test_dir = './data/seg_pred'

train_datagen = ImageDataGenerator(rescale=1./255, # 255で割ることで正規化
                                   zoom_range=0.2, # ランダムにズーム
                                   horizontal_flip=True, # 水平反転
                                   rotation_range=40, # ランダムに回転
                                   vertical_flip=True) # 垂直反転

valid_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

Generatorをそれぞれtrain用、valid用、test用と用意します。trainは水増しを行い、valid,testは水増しはせず正規化だけします。
ImageDataGeneratorで行える水増し処理一覧は公式ドキュメント参照。

train_generator = train_datagen.flow_from_directory(
    train_dir, target_size=(64, 64), batch_size=64, class_mode='categorical', shuffle=True)

valid_generator = valid_datagen.flow_from_directory(
    valid_dir, target_size=(64, 64), batch_size=64, class_mode='categorical', shuffle=True)

generatorに対して、flow_from_directoryを使用して、画像データを読み取らせます。
この時、上記のように読み取らせたいディレクトリの中にクラスごとに分かれたディレクトリが存在していないとうまく読み取ってくれないので注意。

target_sizeに指定した大きさにリサイズします。
class_modeは今回多クラス分類なのでcategoricalを指定。

うまくいけばこのような表示がされるはず。

Found 14037 images belonging to 6 classes.
Found 3000 images belonging to 6 classes.

上がtrainデータ、下がvalidデータについての実行結果です。画像枚数とクラス数が認識できています。

じゃあtestデータに対してもやってみましょう。

# 指定したディレクトリ直下に直接画像があるので認識してくれない
test_generator = test_datagen.flow_from_directory(
    test_dir, target_size=(64,64), batch_size=64, class_mode=None, shuffle=False)

予測データのため、classは何クラスあるかわからないので、class_modeはNoneにします。

ただ、これを実行すると以下のような表示が。

Found 0 images belonging to 1 classes.

先ほども言ったようにtest_dirにはディレクトリ直下にフォルダがなく直接画像データがたくさん置かれています。flow_from_directoryは指定したディレクトリにあるフォルダの数をクラス数として認識するので、フォルダが1つもない場合、画像を正しく読み取ってくれません。

なので何か適当なフォルダを作成して、そこに画像を全部入れてあげればいいです。

# seg_pred/predというディレクトリを作成してその中に画像を格納する
# shutilはファイルやディレクトリの移動・削除を扱える
import shutil

os.makedirs('./data/seg_pred/pred', exist_ok=True)
for image_path in glob.glob(test_dir + '/*'):
    shutil.move(image_path, test_dir + '/pred')

seg_predの中にpredというフォルダを作ってそこに画像データを全部移動させています。

shutil.move(移動させたい画像のパス, 移動先のパス)で画像を移動できます。

これでtestデータに対してflow_from_directoryを行えば

Found 7301 images belonging to 1 classes.

しっかりと7301枚の画像データを認識しています。

ちなみに.class_indicesでどうラベル付けしたか確認できます。

train_generator.class_indices
# -> 
# {'buildings': 0,
#  'forest': 1,
#  'glacier': 2,
#  'mountain': 3,
#  'sea': 4,
#  'street': 5}

どんな画像を生成したのかみてみる

水増しした画像がどういうものか確認することで、意味のない水増しを防げます。

# 1バッチ分取り出す(64個の画像)
items = next(iter(train_generator))
        
plt.figure(figsize=(12,12))
for i, image in enumerate(items[0][:25], 1):
    plt.subplot(5,5,i)
    plt.imshow(image)
    plt.axis('off')

download.png

ランダムに回転してたり、ちゃんと水増しが適用されていることがわかります。

items[0] が64個の画像データ
items[1] が64個のone-hot化されたラベルデータになっています。

なのでラベルも変換して表示できます。

# indexからラベル名に戻すための辞書を定義
index2label_dict = {
    0 : 'buildings',
    1 : 'forest',
    2 : 'glacier',
    3 : 'mountain',
    4 : 'sea',
    5 : 'street'}

items = next(iter(train_generator))
        
plt.figure(figsize=(12,12))
for i, image in enumerate(items[0][:25], 1):
    label_index = np.argmax(items[1][i-1])
    label_name = index2label_dict[label_index]
    plt.subplot(5,5,i)
    plt.imshow(image)
    plt.title(label_name)
    plt.axis('off')

例えばitems[1][1]の表示は以下のようになっています。

array([0., 0., 1., 0., 0., 0.], dtype=float32)

この1になっているのが何番目か知れればラベル名を返せるのでnp.argmaxで取得しています。

ここで水増し画像が訓練データとして適切か確認したらモデル構築・学習へと進む形になります。

まとめ

フォルダ構成を整えてあげることが大事だが、そこさえしっかりやれば水増し、前処理までやってくれるので便利だと感じた。

あとはバッチごとにデータを読み込むので、 もっと大量の画像を扱うことになった時にメモリオーバーにならなくていいと思う。

4
9
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
9

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?