7
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Keras ImageDataGeneratorでnumpyアレイ型ファイルをデータ拡張

Last updated at Posted at 2020-03-28

Keras ImageDataGeneratorでnumpyアレイ型ファイルをデータ拡張

今回実施した事

  • npyファイルで保存された学習用データをKeras ImageDataGeneratorでデータ拡張
  • その後、拡張データ数に合わせて、ラベルファイルを作成し、npyファイルとして書き戻し

記事を書いた動機

  • 学習用データはnpyファイルで持っておくと扱いやすいが、一度、データ拡張を行うと画像データとして保存され、画像数を増やすとラベルファイルと整合性が取れなくなる。
  • しかし、既存のデータ拡張ツール(Keras ImageDataGeneratorや、Augmentor)では、上記の問題を解決してくれる機能は持っていなそうだった(自分が調べた限りでは)

今回は、Keras ImageDataGeneratorを使用して、以下の処理を実装した。

  1. Image DataGenaratorで、学習用データ(npy)ファイルを読み込み
  2. ラベル種別毎に、データ拡張を実行
  3. 拡張したデータ、及びラベルデータをnpyファイルに書き戻し

※データはMNISTの数字10文字想定


###実行環境

  • python 3.7.3
  • numpy 1.16.2
  • tensorflow 1.14.0
  • keras 2.2.4

外部ライブラリの読み込み

import numpy as np
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator,load_img,img_to_array
import os
import shutil
import glob
import matplotlib.pyplot as plt
from matplotlib import cm
from tqdm import tqdm

学習用データとラベルデータの読み込み

# 画像読み込み
data = np.load("/tmp/train_data.npy") 
label = np.load("/tmp/train_label.npy")

dir_list = glob.glob("/tmp/output_*")

#前回実行時のディレクトリが残っている場合、削除
for rdir in dir_list:
    shutil.rmtree(rdir)

データ拡張

# 文字種毎にデータ抽出、データ拡張
for i in range(10):
    index = np.where(label[:, i] == 1)

    data_c = data[index[0]]
    
    # 軸をN,H,W,Cに入れ替え
    data_c = data_c.transpose(0,2,3,1)

    # ImageDataGeneratorのオブジェクト生成
    datagen = ImageDataGenerator(
            shear_range=0.2, #シアー変換
            zoom_range=0.2, #ズーム
            rotation_range=15) #回転

    # 生成後枚数
    num_image = 4000

    # 保存先ディレクトリの作成    
    save_path = "/tmp/output_%s/"%i
    os.mkdir(save_path)
    
    # データ拡張
    g = datagen.flow(data_c, batch_size=1, save_to_dir=save_path, save_format='png', save_prefix='out_a_from_npy_')
    for k in range(num_image):
        batches = g.next()

データ拡張後の画像ファイルをnpyファイルに変換

IMG_SIZE=28
#True=Grayscale, False=RGB
COLOR=True
#Name to save
SAVE_FILE_NAME='SaveImages'
#shape File Name
if COLOR:
    SAVE_FILE_NAME=SAVE_FILE_NAME+'_'+str(IMG_SIZE)+'Gray'
else:
    SAVE_FILE_NAME=SAVE_FILE_NAME+'_'+str(IMG_SIZE)+'RGB'

merge_array = np.empty([0,28,28,1])    

for l in range(10):
    
    #Name to load images Folder
    DIR_NAME='./output_%s'%l
    
    #load madomagi images and reshape
    img_list=glob.glob(DIR_NAME+'/*.png')
    temp_img_array_list=[]
    for img in tqdm(img_list):
        temp_img=load_img(img,grayscale=COLOR,target_size=(IMG_SIZE,IMG_SIZE))
        temp_img_array=img_to_array(temp_img)
        temp_img_array_list.append(temp_img_array)

    temp_img_array_list=np.array(temp_img_array_list)
    
    print(temp_img_array_list.shape)

    merge_array = np.concatenate([merge_array, temp_img_array_list])
    
print(merge_array.shape)
merge_array = merge_array.transpose(0,3,1,2)
np.save("train_data_mr.npy",merge_array)

ラベルデータの作成

label_count = []

for i in range(10):
    path = "/tmp/output_%s"%i
    files = os.listdir(path)  
    count = len(files)
    label_count.append(count)

# 各ラベル作成
char_01=np.array([1,0,0,0,0,0,0,0,0,0,0,0,0,0,0])
char_02=np.array([0,1,0,0,0,0,0,0,0,0,0,0,0,0,0])
char_03=np.array([0,0,1,0,0,0,0,0,0,0,0,0,0,0,0])
char_04=np.array([0,0,0,1,0,0,0,0,0,0,0,0,0,0,0])
char_05=np.array([0,0,0,0,1,0,0,0,0,0,0,0,0,0,0])
char_06=np.array([0,0,0,0,0,1,0,0,0,0,0,0,0,0,0])
char_07=np.array([0,0,0,0,0,0,1,0,0,0,0,0,0,0,0])
char_08=np.array([0,0,0,0,0,0,0,1,0,0,0,0,0,0,0])
char_09=np.array([0,0,0,0,0,0,0,0,1,0,0,0,0,0,0])
char_10=np.array([0,0,0,0,0,0,0,0,0,1,0,0,0,0,0])
                         
for i in range(10):
    if i == 0:
        array_01 = np.tile(char_01,(label_count[i],1))
    if i == 1:
        array_02 = np.tile(char_02,(label_count[i],1))
    if i == 2:
        array_03 = np.tile(char_03,(label_count[i],1))
    if i == 3:
        array_04 = np.tile(char_04,(label_count[i],1))
    if i == 4:
        array_05 = np.tile(char_05,(label_count[i],1))
    if i == 5:
        array_06 = np.tile(char_06,(label_count[i],1))
    if i == 6:
        array_07 = np.tile(char_07,(label_count[i],1))
    if i == 7:
        array_08 = np.tile(char_08,(label_count[i],1))
    if i == 8:
        array_09 = np.tile(char_09,(label_count[i],1))
    if i == 9:
        array_10 = np.tile(char_10,(label_count[i],1))


label_data_mr = np.concatenate([array_01, array_02, array_03, array_04, array_05, array_06, array_07, array_08, array_09, array_10])

np.save("train_label_mr.npy",label_data_mr)

最後に

  • データ拡張のコードについては、Augmentorに置き換えれば流用できそう。Augmentorでしか行えないデータ拡張変換(Perspective skew、Elastic distortion等)も、今後試したい。
  • 少々冗長なコードが含まれていて、ラベルの種類が増えてきた場合は辛いかもしれない。
7
6
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?