0
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

tensorflow向け 画像拡張クラス作った話

Last updated at Posted at 2020-07-13

#背景
Tensorflowでは、画像のデータ拡張を行うために、KerasのImageDataGeneratorがよく使われていました。
これは、入力画像に対し、マルチプロセスで画像に様々なデータ拡張を適用しながら、簡単にtensorflowモデルの学習ができるため、広く使われていました。

しかし、tensorflow 2.0以降multiprocessingが非推奨になった事により、マルチプロセス処理によりデータ拡張を行いながらtensorflowで学習を行っていると、突然エラーも吐かずプログレスバーが止まります。特に痛いのは、エラーが発生しないため、時間単位で使用料が発生するサービスを使っている場合、学習は進まないのに無駄に課金することになります... :money_mouth:

私は複数のファイルに分割して書き込んだhdf5ファイルからマルチプロセスで読み込みながら学習するジェネレータを使っていたのですが、tensorflow<2.0でもmultiprocessingを使うと2日程度で学習が止まる事がありました。tensorflow>=2.0以降は、2時間程度でも止まるなど、更に頻繁に止まるようになりました。

#そこで...
ImageDataGeneratorのように、簡単にマルチプロセスで様々なデータ拡張を画像に適用しながら、tensorflowで学習できるクラスを作りました。

#方針
tensorflow推奨のデータ入力方法はtensorflow.data.Dataset(以下、tf.data.Datasetとする)を使ったものになります。これを使うことで、例えばこちらで言及されていているように、高速かつ、マルチプロセスのデータ入力処理を作成することが可能になります。
しかし、tf.dataはstack overflow等にも未だあまり情報がなく、各データ拡張それぞれのについて試している様な書き込みはありますが、ImageDataGeneratorのように入力画像に対し簡単に様々なデータ拡張しながら学習する方法が見つかりませんでした…

tf.dataを使えば高速なデータ入力処理を作れますが、公式ドキュメントを見るといくつか落とし穴があることがわかります。

###1. tf.data.Dataset.from_generatorではマルチプロセスでデータ拡張されない
tf.data.Dataset.from_generatorを使えば、pythonのジェネレータをラップして、tf.dataとしてfit()関数により学習できます。最初、ImageDataGeneratorをこの関数でラップすれば良いやん!と簡単に考えていました。
しかし、公式ドキュメント、from_generatorのNoteには次のような記載があります。

Note: The current implementation of Dataset.from_generator() uses tf.numpy_function and inherits the same constraints. In particular, it requires the Dataset- and Iterator-related operations to be placed on a device in the same process as the Python program that called Dataset.from_generator(). The body of generator will not be serialized in a GraphDef, and you should not use this method if you need to serialize your model and restore it in a different environment.

tf.numpy_functionを使っていることにより、マルチプロセスに対応していないという事で諦めました。
###2. 出来るだけtfのみで実装する
公式ドキュメントのtf.functionに記載されていますが、パフォーマンスを実現するために@tf.functionデコレータで囲うと、全てのコードがtfのコードに自動的に変換されます。その際に外部ライブラリやnumpy等を使っていると、tf.numpy_functionやtf.py_func等でラップすることになり、結局1.と同様の制限に引っかかることになります。
従って処理及びデータ型はなるべくtf.Tensor型を使い、そうではなくてもpython標準の型のみを使用するようにしました。


###3. ラベル画像も同時に拡張する
入力画像を回転等の変形を行った場合、ラベルの元となる画像も全く同じ変形をする必要がありませんか?
私はそうでしたので、(オプションの)ラベル画像に対して、入力画像と全く同じ変形を適用するようにしました。

#インストール方法
python -m pip install git+https://github.com/piyop/tfaug

#対応しているデータ拡張

  • random_rotation
  • random_flip_left_right
  • random_flip_up_down
  • random_shift
  • random_zoom
  • random_shear
  • random_brightness
  • random_saturation
  • random_hue
  • random_contrast
  • random_crop
  • そのほかに、幾何学変換時の補完手法選択(NEAREST or BILINER)

#実践的な使い方
###1.ファイルパスから画像を読み込み、データ拡張を適用する場合

import tensorflow as tf
from tensorflow.data.experimental import AUTOTUNE

from tfaug import augment_img
import test_tfaug_tool as tool

DATADIR = r'testdata\tfaug'+os.sep

def aug_from_filepath(self):

    batch_size = 2
    filepaths = [DATADIR+'Lenna.png'] * 10

    # define tf.data.Dataset
    ds = tf.data.Dataset.from_tensor_slices(tf.range(10)).repeat().batch(batch_size)

    # construct preprocessing function
    dtype = tf.int32
    class tf_img_preproc():            
        def __init__(self, filepaths):
            self.filepaths = filepaths            

        def preproc(self, image_nos):

            return tf.convert_to_tensor(tool.read_imgs([self.filepaths[no] 
                                                   for no 
                                                   in image_nos]),
                                        dtype=dtype)

    preproc_obj = tf_img_preproc(filepaths)        
    func = lambda x:tf.py_function(preproc_obj.preproc, [x], dtype)

    # define augmentation
    aug_fun=augment_img(standardize=False,random_rotation=90,training=True)
    # map augmentation
    ds_aug = ds.map(func).map(aug_fun, num_parallel_calls=AUTOTUNE)

    # check augmented image
    fig, axs=plt.subplots(batch_size, 10, figsize=(10, batch_size), dpi=300)
    for i, imgs in enumerate(iter(ds_aug.take(10))):            
        axs[0,i].axis("off")
        axs[0,i].imshow(imgs[0])
        axs[1,i].axis("off")
        axs[1,i].imshow(imgs[1])

    plt.savefig(DATADIR+'aug_from_filepath.png')
            

上記コードでは90゜までのランダム回転のみ適用しています。これにより、次の画像が得られます。
aug_from_filepath.png

###2.tfrecordから画像と教師を読み込み、データ拡張を適用する場合

import tensorflow as tf
from tensorflow.data.experimental import AUTOTUNE

from tfaug import augment_img
import test_tfaug_tool as tool

DATADIR = r'testdata\tfaug'+os.sep

batch_size = 2

#test file
filepaths = [DATADIR+'Lenna.png'] * 10

#function for generate tfExample      
def image_example(iimg, imsk):                    
    feature = {'image': tool._bytes_feature(tool.np_to_pngstr(iimg)),
               'msk': tool._bytes_feature(tool.np_to_pngstr(imsk))}            
    return tf.train.Example(features=tf.train.Features(feature=feature))

path_tfrecord = DATADIR+r'sample.tfrecords'
#save tfrecord
with tf.io.TFRecordWriter(path_tfrecord) as writer:
    for filepath in filepaths:                      
        img = np.array(Image.open(filepath).convert('RGB') )     
        #use same image as msk                
        writer.write(image_example(img, img).SerializeToString())

# construct preprocessing function
dtype = tf.uint8

def preproc(tfexamples):                
    return (tf.map_fn(tf.image.decode_png,tfexamples['image'], dtype=dtype),
            tf.map_fn(tf.image.decode_png,tfexamples['msk'], dtype=dtype))

# define augmentation
aug_fun=augment_img(standardize=False,random_zoom=[0.2,0.8],training=True)

#define dataset
ds_aug = (tf.data.TFRecordDataset([path_tfrecord]).repeat().batch(batch_size)
              .apply(tf.data.experimental.parse_example_dataset(tool.tfexample_format))
              .map(preproc,num_parallel_calls=AUTOTUNE)
              .map(aug_fun,num_parallel_calls=AUTOTUNE))

# check augmented image
fig, axs=plt.subplots(batch_size*2, 10, figsize=(10, batch_size*2), dpi=300)
for i, (imgs, msks) in enumerate(iter(ds_aug.take(10))):       
    for row in range(batch_size):
        axs[row*2,i].axis("off")
        axs[row*2,i].imshow(imgs[row])
        axs[row*2+1,i].axis("off")
        axs[row*2+1,i].imshow(msks[row])

plt.savefig(DATADIR+'aug_from_tfrecord.png')

# to learn a model
# model.fit(ds_aug)

上記コードではy方向±20%、x方向±80%までのズームを適用しています。これにより、次の画像が得られます。
aug_from_tfrecord.png

詳細な使用方法例はtest参照。
(https://github.com/piyop/tfaug)

0
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?