Writing a helper that builds a tensorflow2 tf.data.Dataset as easily as flow_from_directory()

Posted at 2020-12-02

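Introduction

Note: These are personal notes, so use them at your own risk.
Note: This runs only on tensorflow 2.3.0 and uses experimental functions that are not yet part of the official implementation.

I define a class that builds a tf.data.Dataset pipeline as easily as flow_from_directory().
Augmentation is built in as well (flip, rotation, and zoom only).

Please also see my previous post.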

python3
import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE
import glob
import os
import numpy as np
import random

class creat_dataset_from_directory():
    def __init__(
    self,
    train_dir_path,
    valid_dir_path,
    img_height=150,
    img_width=150,
    batch_size=64,
    augmentation=False):
        self.train_dir = train_dir_path
        self.valid_dir = valid_dir_path
        self.height=img_height
        self.width=img_width
        self.batch_size=batch_size
        self.aug=augmentation

    def preprocess_image(self, image):
        image = tf.image.decode_jpeg(image, channels=3)  # channels=1 gives grayscale, channels=3 gives RGB
        image = tf.image.resize(image, [self.height, self.width])
        image /= 255.0  # normalize to [0,1] range
        return image

    def load_and_preprocess_image(self, path):
        image = tf.io.read_file(path)
        return self.preprocess_image(image)

    def return_dataset(self):
        label_names = [os.path.basename(i) for i in glob.glob(os.path.join(self.train_dir, '*')) if os.path.isdir(i)]
        label_names.sort()
        print('dataset has following {} classes : '.format(len(label_names)))
        label_to_index = dict((name, index) for index,name in enumerate(label_names))
        print(label_to_index)
        train_image_paths = glob.glob(os.path.join(self.train_dir, '*', '*jpg'))
        valid_image_paths = glob.glob(os.path.join(self.valid_dir, '*', '*jpg'))
        train_image_count = len(train_image_paths)
        valid_image_count = len(valid_image_paths)
        print('number of training image = ', train_image_count)
        print('number of validation image = ', valid_image_count)
        train_image_labels = [label_to_index[os.path.basename(os.path.dirname(path))]
                    for path in train_image_paths]
        valid_image_labels = [label_to_index[os.path.basename(os.path.dirname(path))]
                            for path in valid_image_paths]
        train_image_labels = tf.one_hot(train_image_labels, depth=len(label_names))
        valid_image_labels = tf.one_hot(valid_image_labels, depth=len(label_names))

        # Training pipeline: file paths -> decoded images, zipped with int32 one-hot labels
        path_ds = tf.data.Dataset.from_tensor_slices(train_image_paths)
        image_ds_train = path_ds.map(self.load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
        label_ds_train = tf.data.Dataset.from_tensor_slices(tf.cast(train_image_labels, tf.int32))
        image_label_ds_train = tf.data.Dataset.zip((image_ds_train, label_ds_train))

        if not self.aug:        
            ds = image_label_ds_train
            ds = ds.repeat()
            ds = ds.shuffle(buffer_size=train_image_count)  # careful with where shuffle sits in the chain!
            ds = ds.batch(self.batch_size)
            ds = ds.prefetch(buffer_size=AUTOTUNE)
        else:
            data_augmentation = tf.keras.Sequential([
                                  tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
                                  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2,fill_mode = 'reflect'),
                                  tf.keras.layers.experimental.preprocessing.RandomZoom(height_factor=0.2, fill_mode =  'reflect'),
                                    #tf.keras.layers.experimental.preprocessing.RandomCrop(140,140),
                                    #tf.keras.layers.experimental.preprocessing.RandomContrast(0.3),
                                    #tf.keras.layers.experimental.preprocessing.RandomTranslation(0.2,0.2),
                                ])
            ds = image_label_ds_train
            ds = ds.repeat()
            ds = ds.shuffle(buffer_size=train_image_count)  # careful with where shuffle sits in the chain!
            ds = ds.batch(self.batch_size)
            ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y))

            ds = ds.prefetch(buffer_size=AUTOTUNE)

        path_ds_valid = tf.data.Dataset.from_tensor_slices(valid_image_paths)
        image_ds_valid = path_ds_valid.map(self.load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
        label_ds_valid = tf.data.Dataset.from_tensor_slices(tf.cast(valid_image_labels, tf.int32))
        image_label_ds_valid = tf.data.Dataset.zip((image_ds_valid, label_ds_valid))
        ds_valid = image_label_ds_valid
        ds_valid = ds_valid.repeat()
        ds_valid = ds_valid.batch(self.batch_size)
        ds_valid = ds_valid.prefetch(buffer_size=AUTOTUNE)       

        # Returns the training and validation datasets, plus the number of training and validation images.
        return ds, ds_valid, train_image_count, valid_image_count
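
The "careful with where shuffle sits" comment deserves a small illustration. tf.data's shuffle keeps a buffer of buffer_size elements and samples from it, refilling the freed slot from upstream. With repeat() placed before shuffle(), as in the class above, the buffer can therefore hold elements from neighbouring epochs at the same time. The sketch below is a pure-Python simulation of that buffer idea, not tf.data itself; `buffered_shuffle`, `repeat_then_shuffle` and `shuffle_then_repeat` are names I made up for the illustration.

```python
import random
from itertools import chain, islice, repeat as repeat_forever


def buffered_shuffle(stream, buffer_size, rng):
    """Simulated shuffle buffer: emit a random buffered item, refill from the stream."""
    stream = iter(stream)
    buffer = list(islice(stream, buffer_size))
    for item in stream:
        i = rng.randrange(len(buffer))
        yield buffer[i]
        buffer[i] = item
    # Stream exhausted: drain the remaining items in random order.
    rng.shuffle(buffer)
    yield from buffer


def repeat_then_shuffle(items, n_out, seed=0):
    """Same order as the class above: repeat first, then shuffle."""
    rng = random.Random(seed)
    endless = chain.from_iterable(repeat_forever(items))
    return list(islice(buffered_shuffle(endless, len(items), rng), n_out))


def shuffle_then_repeat(items, n_out, seed=0):
    """Alternative order: reshuffle one full pass at a time, then repeat."""
    rng = random.Random(seed)

    def epochs():
        while True:
            yield from buffered_shuffle(items, len(items), rng)

    return list(islice(epochs(), n_out))


items = list(range(6))
mixed = repeat_then_shuffle(items, 12)
strict = shuffle_then_repeat(items, 12)
print('repeat -> shuffle :', mixed[:6], mixed[6:])
print('shuffle -> repeat :', strict[:6], strict[6:])
```

In the second variant every block of six outputs is a permutation of the six items, so each image is seen exactly once per epoch. In the first variant an item may show up twice before another has appeared at all. Either can be fine for training; it is simply worth knowing which one you have.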

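Usage

Split the images into a training folder and a validation folder.
Inside each of them, create one folder per label and put the images there.
This is the same layout that Keras's flow_from_directory expects.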
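
As a rough sketch of that layout, the snippet below creates a tiny throwaway tree in a temporary directory and derives the label mapping from the folder names, the same way return_dataset() does. The folder names `dataset` and `validation` follow the usage code in this post; the two class names, the empty `example.jpg` files and the helper `list_label_names` are only placeholders for the illustration.

```python
import glob
import os
import tempfile


def list_label_names(root_dir):
    """Return the sorted sub-folder names of root_dir (one folder per label)."""
    return sorted(
        os.path.basename(p)
        for p in glob.glob(os.path.join(root_dir, '*'))
        if os.path.isdir(p)
    )


with tempfile.TemporaryDirectory() as tmp:
    # <tmp>/dataset/<label>/*.jpg and <tmp>/validation/<label>/*.jpg
    for split in ('dataset', 'validation'):
        for label in ('daisy', 'roses'):
            class_dir = os.path.join(tmp, split, label)
            os.makedirs(class_dir)
            # Empty placeholder file; a real dataset would hold JPEG images here.
            open(os.path.join(class_dir, 'example.jpg'), 'wb').close()

    label_names = list_label_names(os.path.join(tmp, 'dataset'))
    label_to_index = {name: index for index, name in enumerate(label_names)}
    n_train = len(glob.glob(os.path.join(tmp, 'dataset', '*', '*jpg')))

print(label_to_index)
print('number of training image = ', n_train)
```

Note that the label list is taken from the training folder only, so the validation folder must not contain a label that is missing from training.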

Screenshot from 2020-12-02 21-38-48.png

python3
h,w = 150,150
batch_size = 64
creat_ds = creat_dataset_from_directory(train_dir_path='./dataset',
                             valid_dir_path='./validation',
                             img_height=h,img_width=w,
                             batch_size=batch_size,
                             augmentation=True)
ds_train, ds_valid,train_image_count, valid_image_count = creat_ds.return_dataset()
Output
dataset has following 5 classes : 
{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
number of training image =  3301
number of validation image =  369
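
The class indices printed above are what return_dataset() feeds into tf.one_hot, so each label reaches the model as a vector with a single 1. A minimal plain-Python sketch of that encoding, using the mapping from the output above (`one_hot` here is a stand-in I wrote for illustration, not the TensorFlow function):

```python
label_to_index = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}


def one_hot(index, depth):
    """Return a list of length depth with a 1 at position index and 0 elsewhere."""
    return [1 if i == index else 0 for i in range(depth)]


vector = one_hot(label_to_index['roses'], depth=len(label_to_index))
print(vector)  # [0, 0, 1, 0, 0]
```

Because the labels are one-hot vectors rather than plain integers, the matching loss is categorical_crossentropy, and the final Dense layer needs as many units as there are classes.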

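Build a model and train it

return_dataset() returns the training and validation datasets, so they can be passed straight to fit(). (The model is just something I threw together.)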

python3
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D, BatchNormalization

epochs = 20
model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu', input_shape=(h, w,3)),
    BatchNormalization(),
    MaxPooling2D(),
    Conv2D(32, 3, padding='same', activation='relu'),
    BatchNormalization(),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    BatchNormalization(),
    MaxPooling2D(),
    Dropout(0.5),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(5, activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(ds_train, 
        epochs=epochs, 
        steps_per_epoch=train_image_count//batch_size,
        validation_data=ds_valid,
        validation_steps=valid_image_count//batch_size, 
        validation_batch_size=batch_size,
        max_queue_size=120, workers=30, use_multiprocessing=True,
        )
Output
Epoch 1/20
 1/51 [..............................] - ETA: 0s - loss: 2.6099 - accuracy: 0.2969WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0165s vs `on_train_batch_end` time: 0.0273s). Check your callbacks.
51/51 [==============================] - 3s 55ms/step - loss: 2.8166 - accuracy: 0.4920 - val_loss: 9.0111 - val_accuracy: 0.2937
Epoch 2/20
51/51 [==============================] - 2s 47ms/step - loss: 1.0819 - accuracy: 0.5882 - val_loss: 17.7906 - val_accuracy: 0.2875
Epoch 3/20
51/51 [==============================] - 2s 47ms/step - loss: 0.9926 - accuracy: 0.6091 - val_loss: 17.1720 - val_accuracy: 0.2906
Epoch 4/20
51/51 [==============================] - 2s 46ms/step - loss: 1.0255 - accuracy: 0.6032 - val_loss: 10.6826 - val_accuracy: 0.2875
Epoch 5/20
51/51 [==============================] - 2s 47ms/step - loss: 0.9339 - accuracy: 0.6376 - val_loss: 7.2258 - val_accuracy: 0.3063
Epoch 6/20
51/51 [==============================] - 2s 47ms/step - loss: 0.9263 - accuracy: 0.6458 - val_loss: 2.4138 - val_accuracy: 0.3625
Epoch 7/20
51/51 [==============================] - 2s 46ms/step - loss: 0.8773 - accuracy: 0.6575 - val_loss: 1.8898 - val_accuracy: 0.4313
Epoch 8/20
51/51 [==============================] - 2s 46ms/step - loss: 0.8190 - accuracy: 0.6866 - val_loss: 1.3397 - val_accuracy: 0.4969
Epoch 9/20
51/51 [==============================] - 2s 47ms/step - loss: 0.8381 - accuracy: 0.6673 - val_loss: 1.2199 - val_accuracy: 0.5219
Epoch 10/20
51/51 [==============================] - 2s 46ms/step - loss: 0.8357 - accuracy: 0.6682 - val_loss: 0.8355 - val_accuracy: 0.6938
Epoch 11/20
51/51 [==============================] - 2s 47ms/step - loss: 0.8006 - accuracy: 0.6933 - val_loss: 0.8150 - val_accuracy: 0.7156
Epoch 12/20
51/51 [==============================] - 2s 47ms/step - loss: 0.7640 - accuracy: 0.6994 - val_loss: 0.7955 - val_accuracy: 0.6969
Epoch 13/20
51/51 [==============================] - 2s 47ms/step - loss: 0.7794 - accuracy: 0.6958 - val_loss: 0.7804 - val_accuracy: 0.7375
Epoch 14/20
51/51 [==============================] - 2s 47ms/step - loss: 0.7668 - accuracy: 0.7071 - val_loss: 0.7873 - val_accuracy: 0.7500
Epoch 15/20
51/51 [==============================] - 2s 47ms/step - loss: 0.7615 - accuracy: 0.6952 - val_loss: 0.6884 - val_accuracy: 0.7688
Epoch 16/20
51/51 [==============================] - 2s 47ms/step - loss: 0.7184 - accuracy: 0.7178 - val_loss: 0.6732 - val_accuracy: 0.7781
Epoch 17/20
51/51 [==============================] - 2s 47ms/step - loss: 0.7346 - accuracy: 0.7111 - val_loss: 0.7682 - val_accuracy: 0.7406
Epoch 18/20
51/51 [==============================] - 2s 47ms/step - loss: 0.6868 - accuracy: 0.7289 - val_loss: 0.6970 - val_accuracy: 0.7688
Epoch 19/20
51/51 [==============================] - 2s 47ms/step - loss: 0.7411 - accuracy: 0.7148 - val_loss: 0.8185 - val_accuracy: 0.7125
Epoch 20/20
51/51 [==============================] - 2s 47ms/step - loss: 0.7106 - accuracy: 0.7292 - val_loss: 0.7481 - val_accuracy: 0.7375
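
The 51/51 in the log comes straight from the integer division passed to fit(). Both datasets call repeat(), so they never run out and the number of batches per epoch has to be given explicitly. A small worked example with the image counts from this post; the ceil variant is only an alternative I am suggesting, not what the run above used:

```python
import math

train_image_count = 3301
valid_image_count = 369
batch_size = 64

# Floor division, as used in the fit() call above.
steps_per_epoch = train_image_count // batch_size
validation_steps = valid_image_count // batch_size

# Images actually drawn per epoch / per validation pass.
train_seen = steps_per_epoch * batch_size
valid_seen = validation_steps * batch_size

# Rounding up instead covers at least one full pass over the data.
steps_per_epoch_ceil = math.ceil(train_image_count / batch_size)
validation_steps_ceil = math.ceil(valid_image_count / batch_size)

print(steps_per_epoch, validation_steps)            # 51 5
print(train_seen, valid_seen)                       # 3264 320
print(steps_per_epoch_ceil, validation_steps_ceil)  # 52 6
```

With floor division each validation pass scores 320 images, which is why values such as 0.2937 (94/320) appear in val_accuracy.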

Conclusion

I wrote a class that creates a dataset in the style of Keras's flow_from_directory. Hopefully this makes things a little more convenient.
