31
23

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

画像ディレクトリをtrain_test_splitする関数

Last updated at Posted at 2019-04-16

scikit-learnのtrain_test_split関数は非常に便利ですが、
画像データで使うと、一旦全ての画像をロードする形になってしまいます。
となるとメモリにデータが全部乗り切らない、KerasのImageDataGeneratorの使用ができない(水増しした画像を一旦保存するなどすればできなくはないとは思いますが)などの問題が出てきてしまいます。
今回の関数を使えば、ディレクトリをtrain、testに分けてくれるので、ImageDataGeneratorも使用することができます。

とりあえず使い方

ディレクトリ構造は下の想定です。下の例だと pet_dataの下にクラス(calpis等)のディレクトリがあり、その中に画像が入っています。
スクリーンショット 2019-04-16 10.50.28.png
実行するとこうなります
スクリーンショット 2019-04-16 11.11.20.png

例でいうと、mainのoriginal_dirにpet_dataのpathを指定、base_dirは好きに指定してください。
trainとvalとtestに分けたいときは一度分けたvalフォルダをもう一度分けてもらえればいけます。
train_test_spritなのにフォルダ名がvalidationなのは勘弁してください。。。。。

import os
import shutil
import random


def image_dir_train_test_sprit(original_dir, base_dir, train_size=0.8):
    '''
    画像データをトレインデータとテストデータにシャッフルして分割します。フォルダもなければ作成します。

    parameter
    ------------
    original_dir: str
      オリジナルデータフォルダのパス その下に各クラスのフォルダがある
    base_dir: str
      分けたデータを格納するフォルダのパス そこにフォルダが作られます
    train_size: float
      トレインデータの割合
    '''
    try:
        os.mkdir(base_dir)
    except FileExistsError:
        print(base_dir + "は作成済み")

    #クラス分のフォルダ名の取得
    dir_lists = os.listdir(original_dir)
    dir_lists = [f for f in dir_lists if os.path.isdir(os.path.join(original_dir, f))]
    original_dir_path = [os.path.join(original_dir, p) for p in dir_lists]

    num_class = len(dir_lists)

    # フォルダの作成(トレインとバリデーション)
    try:
        train_dir = os.path.join(base_dir, 'train')
        os.mkdir(train_dir)
    except FileExistsError:
        print(train_dir + "は作成済み")

    try:
        validation_dir = os.path.join(base_dir, 'validation')
        os.mkdir(validation_dir)
    except FileExistsError:
        print(validation_dir + "は作成済み")

    #クラスフォルダの作成
    train_dir_path_lists = []
    val_dir_path_lists = []
    for D in dir_lists:
        train_class_dir_path = os.path.join(train_dir, D)
        try:
            os.mkdir(train_class_dir_path)
        except FileExistsError:
            print(train_class_dir_path + "は作成済み")
        train_dir_path_lists += [train_class_dir_path]
        val_class_dir_path = os.path.join(validation_dir, D)
        try:
            os.mkdir(val_class_dir_path)
        except FileExistsError:
            print(val_class_dir_path + "は作成済み")
        val_dir_path_lists += [val_class_dir_path]


    #元データをシャッフルしたものを上で作ったフォルダにコピーします。
    #ファイル名を取得してシャッフル
    for i,path in enumerate(original_dir_path):
        files_class = os.listdir(path)
        random.shuffle(files_class)
        # 分割地点のインデックスを取得
        num_bunkatu = int(len(files_class) * train_size)
        #トレインへファイルをコピー
        for fname in files_class[:num_bunkatu]:
            src = os.path.join(path, fname)
            dst = os.path.join(train_dir_path_lists[i], fname)
            shutil.copyfile(src, dst)
        #valへファイルをコピー
        for fname in files_class[num_bunkatu:]:
            src = os.path.join(path, fname)
            dst = os.path.join(val_dir_path_lists[i], fname)
            shutil.copyfile(src, dst)
        print(path + "コピー完了")

    print("分割終了")


def main():
    original_dir = "pet_data"
    base_dir = "sprit_data"
    train_size = 0.8
    image_dir_train_test_sprit(original_dir, base_dir, train_size)
    

if __name__ == "__main__":
    main()

今回とは関係ないですが、同じDICの方のモデルの蒸留の記事 https://qiita.com/M_Hiro/items/0ba24788c78540046bcd#_reference-181e9e03e1859a41c2c3

31
23
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
31
23

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?