scikit-learnのtrain_test_split関数は非常に便利ですが、
画像データで使うと、一旦全ての画像をロードする形になってしまいます。
となるとメモリにデータが全部乗り切らない、KerasのImageDataGeneratorの使用ができない(水増しした画像を一旦保存するなどすればできなくはないとは思いますが)などの問題が出てきてしまいます。
今回の関数を使えば、ディレクトリをtrain、testに分けてくれるので、ImageDataGeneratorも使用することができます。
とりあえず使い方
ディレクトリ構造は下の想定です。下の例だと pet_dataの下にクラス(calpis等)のディレクトリがあり、その中に画像が入っています。
実行するとこうなります
例でいうと、mainのoriginal_dirにpet_dataのpathを指定、base_dirは好きに指定してください。
trainとvalとtestに分けたいときは一度分けたvalフォルダをもう一度分けてもらえればいけます。
train_test_spritなのにフォルダ名がvalidationなのは勘弁してください。。。。。
import os
import shutil
import random
def image_dir_train_test_sprit(original_dir, base_dir, train_size=0.8):
'''
画像データをトレインデータとテストデータにシャッフルして分割します。フォルダもなければ作成します。
parameter
------------
original_dir: str
オリジナルデータフォルダのパス その下に各クラスのフォルダがある
base_dir: str
分けたデータを格納するフォルダのパス そこにフォルダが作られます
train_size: float
トレインデータの割合
'''
try:
os.mkdir(base_dir)
except FileExistsError:
print(base_dir + "は作成済み")
#クラス分のフォルダ名の取得
dir_lists = os.listdir(original_dir)
dir_lists = [f for f in dir_lists if os.path.isdir(os.path.join(original_dir, f))]
original_dir_path = [os.path.join(original_dir, p) for p in dir_lists]
num_class = len(dir_lists)
# フォルダの作成(トレインとバリデーション)
try:
train_dir = os.path.join(base_dir, 'train')
os.mkdir(train_dir)
except FileExistsError:
print(train_dir + "は作成済み")
try:
validation_dir = os.path.join(base_dir, 'validation')
os.mkdir(validation_dir)
except FileExistsError:
print(validation_dir + "は作成済み")
#クラスフォルダの作成
train_dir_path_lists = []
val_dir_path_lists = []
for D in dir_lists:
train_class_dir_path = os.path.join(train_dir, D)
try:
os.mkdir(train_class_dir_path)
except FileExistsError:
print(train_class_dir_path + "は作成済み")
train_dir_path_lists += [train_class_dir_path]
val_class_dir_path = os.path.join(validation_dir, D)
try:
os.mkdir(val_class_dir_path)
except FileExistsError:
print(val_class_dir_path + "は作成済み")
val_dir_path_lists += [val_class_dir_path]
#元データをシャッフルしたものを上で作ったフォルダにコピーします。
#ファイル名を取得してシャッフル
for i,path in enumerate(original_dir_path):
files_class = os.listdir(path)
random.shuffle(files_class)
# 分割地点のインデックスを取得
num_bunkatu = int(len(files_class) * train_size)
#トレインへファイルをコピー
for fname in files_class[:num_bunkatu]:
src = os.path.join(path, fname)
dst = os.path.join(train_dir_path_lists[i], fname)
shutil.copyfile(src, dst)
#valへファイルをコピー
for fname in files_class[num_bunkatu:]:
src = os.path.join(path, fname)
dst = os.path.join(val_dir_path_lists[i], fname)
shutil.copyfile(src, dst)
print(path + "コピー完了")
print("分割終了")
def main():
original_dir = "pet_data"
base_dir = "sprit_data"
train_size = 0.8
image_dir_train_test_sprit(original_dir, base_dir, train_size)
if __name__ == "__main__":
main()
今回とは関係ないですが、同じDICの方のモデルの蒸留の記事 https://qiita.com/M_Hiro/items/0ba24788c78540046bcd#_reference-181e9e03e1859a41c2c3