4
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

[ディープラーニング] データセット作成 : Augmentaion → Train / Testデータ分割 →Train / Testデータをlist形式で取得。を効率よく。

Last updated at Posted at 2019-06-14

やったこと、なぜやったか

やったこと:

  • 限られた枚数しかないデータセットでAugmentationを行い、

  • 増やした画像をTrain_data / Test_dataに分割、

  • フォルダソート(GUI確認用)

  • 同時にlist形式で学習用のTrain_data / Test_dataの画像データ、正解ラベルデータも手に入れる。

その過程を効率的におこなった(と思っている)。

なぜやったか:

限られた枚数のデータセットを用いて精度の良い分類(2クラス)を行うため、Augmentationが必要。
そしてAugmentationを行なった画像を効率良くTrain / Testデータフォルダに整理したかった。

そして以下のモチベーションがあった。

⑴ Augmentationした画像を、自分でフォルダ移動させるのがめんどくさい(Pythonでやりたい)笑
⑵ Train / Testデータ格納フォルダを作るのがめんどくさい(Pythonでやりたい)笑
⑶ 同時にlist形式で学習用のTrain_data / Test_dataの画像データ、正解ラベルデータも手に入れたい 笑

参考にしたサイト等:

多数。笑(後ほど整理)

#ファイル構造

  • 作業ディレクトリ
    • pyファイル / 作業ファイル
    • image_folder
      • images / 画像データ
      • train / 作成するもの
        • class1
        • class2
      • test / 作成するもの
        • class1
        • class2

Augmentation

データセットが格納されているフォルダパスを定義

my_path = 'image_folder/images/'

このpathにはcat, dogの画像が格納されている。
各ファイルは以下のようなファイル名で定義されており、ファイル名で画像クラスが分かる状態。

dog1.jpg
cat2.jpg
 .
 .
 .

今回は少数のdog, catの画像をAugmentation。(2クラス分類を行う)
Augmentationの関数を定義↓

from os import listdir
from os.path import isfile, join
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img


def augmentation(dir_path, initial_letter_of_file='d', augment_num=3):
 
    """
    note : 指定ディレクトリ内の,指定頭文字で始まるファイルを指定枚数augmentationする
    ----------
    dir_path : フォルダパス
    initial_letter : augmentしたいファイル名の頭文字
    aument_num : augmentしたい枚数
    ----------
    """

    # 各画像ファイル名の抽出
    files_name = [f for f in listdir(my_path) if isfile(join(my_path, f))]
    files_name.remove('.DS_Store')
    
    
    #Augmentの設定
    datagen = ImageDataGenerator(rotation_range=40,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             shear_range=0.2,
                             zoom_range=0.2,
                             horizontal_flip=True,
                             fill_mode='nearest'
    )
    
    #指定された頭文字のファイルを指定枚数増やすloop
    for i, file in enumerate(files_name):
        img = load_img(my_path + file)
        x = img_to_array(img)
        x = x.reshape((1,) + x.shape) 

        if file[0] == initial_letter_of_file:
            i = 0
            for batch in datagen.flow(x, save_to_dir=dir_path, save_prefix=initial_letter_of_file, save_format="jpg"):
                i += 1
                if i > augment_num:
                    break

        else:
            pass

関数を動かす。

augmentation(my_path, 'd')
augmentation(my_path, 'c')
'image_folder/images/'

↑フォルダに画像が増えてることを確認。

Train / Testデータの格納フォルダの作成

後ほど、Augmentationした画像を各クラスごとにTrain, Testにソートするが、
その格納ファイルを作ってあげる。

import os

class1_dir_train = 'image_folder/train/class1/'
class1_dir_val = 'image_folder/test/class1/'
class2_dir_train = 'image_folder/train/class2/'
class2_dir_val = 'image_folder/test/class2/'

def make_dir(directory):
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)
    
make_dir(class1_dir_train)
make_dir(class1_dir_val)
make_dir(class2_dir_train)
make_dir(class2_dir_val)

#本題 : データのふるい分け

import cv2
import numpy as np
import sys
import shutil


def train_test_split(files_name, class_1='d', class_2='c', train_size=0.8):

    """
    note : 画像フォルダから、指定クラスを、指定割合でtrain_sprit
    ----------
    class_1 : クラス名(今回はdog)
    class_2 : クラス名(今回はcat)
    train_size : 分割したい割合
    ----------
    """
    
    # クラスごと何枚ずつソートされているか
    class_1_count = 0
    class_2_count = 0
    
    # 各クラスの画像数は、imagesフォルダの合計枚数の半数ずつ
    each_class_size = len(files_name) // 2
    
    # train, testの画像数を定義
    train_size = each_class_size * train_size
    test_size = each_class_size - train_size

    # 画像、それぞれのクラスラベル格納用array
    training_images = []
    training_labels = []
    test_images = []
    test_labels = []

    # 画像サイズの指定
    size = 200
    
   
    for i, file in enumerate(files_name):

        # ソートする指定枚数(train_size)を超えるまで、trainにソート
        # それを超えたらtestにソート
        # 他方のクラスも同様の動作

        if files_name[i][0] == class_1:
            class_1_count += 1
            image = cv2.imread(my_path + file)
            image = cv2.resize(image, (size, size), interpolation = cv2.INTER_AREA)
            if class_1_count <= train_size:
                training_images.append(image)
                training_labels.append(1)
                cv2.imwrite(class1_dir_train + class_1 + str(class_1_count) + '.jpg', image)
            if class_1_count > train_size and class_1_count <= train_size + test_size:
                test_images.append(image)
                test_labels.append(1)
                cv2.imwrite(class1_dir_val + class_1 + str(class_1_count) + '_' + '.jpg', image)

        if files_name[i][0] == class_2:
            class_2_count += 1
            image = cv2.imread(my_path + file)
            image = cv2.resize(image, (size, size), interpolation = cv2.INTER_AREA)
            if class_2_count <= train_size:
                training_images.append(image)
                training_labels.append(0)
                cv2.imwrite(class2_dir_train + class_2 + str(class_2_count) + '.jpg', image)
            if class_2_count > train_size and class_2_count <= train_size + test_size:
                test_images.append(image)
                test_labels.append(0)
                cv2.imwrite(class2_dir_val + class_2 + str(class_2_count) + '_' + '.jpg', image)
                
    return training_images, training_labels, test_images, test_labels

動かす。

training_images, training_labels, test_images, test_labels = train_test_split(updated_files_name)
training_images
[array([[[122, 154, 183],
         [144, 174, 200],
         [157, 182, 203],
         ...,
         [121, 159, 179],
         [128, 170, 189],
         [139, 180, 202]],
 
        [[112, 144, 173],
         [143, 173, 198],
         [150, 175, 195],
         ...,
          .
          .
          .
          .
training_labels
[1,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 0
 .
 .
 .

Augmentationした画像はうまくフォルダソートされ、
Train / Test データを listで取得することができた。

4
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?