LoginSignup
13
14

More than 5 years have passed since last update.

Augmentorを使った画像の水増し

Posted at

画像データの水増し

Augmentor

機械学習用の学習用画像データを、良い感じに水増しするライブラリ。
話題になってから2ヶ月たちますが、特に記事が見当たらないので投稿。
Githubに詳しい解説があるので、使い方はそちらを見て下さい。

ソース

  • 画像Datasetを格納したフォルダを指定する
  • 上記フォルダ配下の、Subフォルダ内画像枚数を、全て水増しして統一
    • 画像のOpenDataは大抵フォルダ名がラベルなので
memo.txt
 C:test
 └ dataset ←これを指定
  ├ a ←これを5000枚に統一
  ├ b ←これを5000枚に統一
  └ c ←これを5000枚に統一

augment.py
import os
import argparse
import Augmentor
import shutil
import time


# https://github.com/mdbloice/Augmentor
# pip install Augmentor
# 機械学習用にイメージデータを水増しするライブラリ

# 再帰フォルダ探索のPATH格納
IMG_LIST_PATH = []
# 増やしたい画像枚数
TEST_DATA_NUM = 5000


def augment_data(img_dic):
    # Augmentorの処理
    for path, num in img_dic.items():
        make_test_img_num = TEST_DATA_NUM - num

        # 画像フォルダ
        p = Augmentor.Pipeline(path)

        # キャンバスの歪み
        p.skew_tilt(probability=0.3, magnitude=0.5)

        # 中心の歪み
        p.random_distortion(probability=0.3, grid_width=2, grid_height=2, magnitude=2)

        # 回転
        p.rotate90(probability=0.3)
        p.rotate270(probability=0.3)
        p.rotate(probability=0.3, max_left_rotation=10, max_right_rotation=10)

        # 反転
        p.flip_left_right(probability=0.3)
        p.flip_top_bottom(probability=0.3)

        # ずらし
        p.crop_random(probability=1, percentage_area=0.3)

        p.resize(probability=1.0, width=64, height=64)
        p.sample(make_test_img_num)


def traverse_dir(path):
    for file_or_dir in os.listdir(path):
        abs_path = os.path.abspath(os.path.join(path, file_or_dir))
        if os.path.isdir(abs_path):
            traverse_dir(abs_path)
        else:
            # 画像を見つけたら、親フォルダのPathを格納しておく
            img_directory = os.path.dirname(abs_path)
            if img_directory not in IMG_LIST_PATH:
                IMG_LIST_PATH.append(img_directory)


def get_folder_list(path):
    global IMG_LIST_PATH

    # 再帰でフォルダ内のDatasetを取得
    traverse_dir(path)
    imglist_num = {}
    # DatasetのFolderPathと、Folder内の画像数をDictionaryに入れておく
    for path in IMG_LIST_PATH:
        imglist_num[path] = len(os.listdir(path))

    IMG_LIST_PATH = {}
    return imglist_num


def copy_original_data(imgs_dic):
    for path in imgs_dic.keys():
        # Augmentorの処理まち
        while True:
            check_foleder = os.path.join(path, "output")
            if os.path.exists(check_foleder):
                time.sleep(1)
                break

        # label直下に持ってくる
        test_data_path = os.path.join(path, "output")
        for file_or_dir in os.listdir(test_data_path):
            abs_path = os.path.abspath(os.path.join(test_data_path, file_or_dir))
            if os.path.isdir(abs_path):
                pass
            else:
                dest_path= os.path.join(path, file_or_dir)
                shutil.move(abs_path, dest_path)

        del_folder = os.path.join(path, "output")
        shutil.rmtree(del_folder)


def main():
    # 一括でAugmentorを行いたいFolderを指定する
    target_path = "C:\\test\\dataset"

    # 上記指定フォルダ配下のフォルダに存在するもの全部を取得
    imgs_dic = get_folder_list(target_path)
    # Augumentorを一括で実施
    augment_data(imgs_dic)
    # Augmentorがoutput folderを作るので、同一フォルダにコピーしてくる
    copy_original_data(imgs_dic)

if __name__ == "__main__":
    main()
13
14
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
14