
I built BalancedImageDataGenerator, which creates every Keras batch while preserving the class balance

More than 1 year has passed since last update.

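While following the article "KerasのGeneratorを自作する" (Writing your own Generator in Keras), I began to wonder how training converges on class-imbalanced data. So I wrote BalancedImageDataGenerator.flow_from_directory, which generates batches that all share the same class balance.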

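For example, if y contains 160 samples in total, split into [100 of class 0, 50 of class 1, 10 of class 2], then passing batch_size=32 produces batches of [20 of class 0, 10 of class 1, 2 of class 2].
Each class's per-batch count is its share of the data multiplied by batch_size and rounded half up. The batch size actually used is the sum of these rounded counts, so it is not necessarily the batch_size you pass in.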
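The per-class arithmetic can be checked on its own. The sketch below mirrors the rounding used by the generator's `round` method. The helper names `round_half_up` and `per_class_counts` are hypothetical and exist only for this illustration. The sketch reproduces the example above and shows how the effective batch size can drift away from batch_size:

```python
from decimal import Decimal, ROUND_HALF_UP


def round_half_up(x):
    # Same rounding as BalancedImageDataGenerator.round: 2.5 -> 3
    # (Python's built-in round() uses banker's rounding, so round(2.5) == 2)
    return int(Decimal(str(x)).quantize(Decimal("0"), rounding=ROUND_HALF_UP))


def per_class_counts(class_sizes, batch_size):
    """Hypothetical helper: samples of each class per batch, proportional to class size."""
    total = sum(class_sizes)
    return [round_half_up(n * batch_size / total) for n in class_sizes]


counts = per_class_counts([100, 50, 10], 32)
print(counts, sum(counts))  # [20, 10, 2] 32

# Ten equally sized classes: 32 / 10 = 3.2 rounds down to 3 per class
counts = per_class_counts([5000] * 10, 32)
print(counts, sum(counts))  # [3, 3, 3, 3, 3, 3, 3, 3, 3, 3] 30
```

With ten equally sized classes, each batch therefore holds 30 images rather than 32. The CIFAR-10 training split has 5000 images per class, so this is the situation in the training example below, assuming the downloaded archive contains the full split.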

Download and extract the CIFAR-10 images:

```bash
mkdir data && cd data
wget http://pjreddie.com/media/files/cifar.tgz
tar zxvf cifar.tgz
```

The train and test directories contain PNG images. Each file name is a sequence number and the class name joined by an underscore (for example 0_cat.png):

```text
cifar
├── labels.txt
├── test
│   ├── 0_cat.png
│   ├── 1000_dog.png
│   ├── 1001_airplane.png
│   ├── ...
├── train
│   ├── 0_cat.png
│   ├── 1000_dog.png
│   ├── 1001_airplane.png
│   ├── ...
```
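Getting the label from a file name means taking the part after the underscore and mapping it to an integer index. scikit-learn's LabelEncoder assigns indices in sorted order of the class names. The standard-library sketch below reproduces that mapping with a plain dict. `label_from_filename` is a hypothetical helper. Its `split('_', 1)` also keeps class names that themselves contain underscores intact:

```python
from pathlib import PurePath

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Sorted order, as LabelEncoder uses; this list happens to be sorted already
class_to_index = {name: i for i, name in enumerate(sorted(classes))}


def label_from_filename(filename):
    # "1000_dog.png" -> stem "1000_dog" -> class name "dog" -> index 5
    name = PurePath(filename).stem.split('_', 1)[1]
    return class_to_index[name]


for f in ['0_cat.png', '1000_dog.png', '1001_airplane.png']:
    print(f, label_from_filename(f))
# 0_cat.png 3
# 1000_dog.png 5
# 1001_airplane.png 0
```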
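```python
import pathlib
from decimal import Decimal, ROUND_HALF_UP

import numpy as np
from keras.applications.mobilenet import MobileNet
from keras.utils import to_categorical
from PIL import Image
from sklearn.preprocessing import LabelEncoder


class BalancedImageDataGenerator(object):
    def __init__(self):
        self.reset()

    def reset(self):
        # Clear the buffers for the batch being assembled
        self.images = []
        self.labels = []

    def round(self, f):
        # Round half up (the built-in round() rounds 2.5 down to 2)
        return Decimal(str(f)).quantize(Decimal('0'), rounding=ROUND_HALF_UP)

    def flow_from_directory(self, directory, classes, batch_size=32, seed=42, categorical=False, shuffle=True):
        # The class name is the part of the file name after "_": "0_cat.png" -> "cat"
        pathlist = np.array(list(pathlib.Path(directory).iterdir()))
        y_str = [path.stem.split('_')[1] for path in pathlist]
        le = LabelEncoder()
        le.fit(np.array(classes))
        y = le.transform(y_str)
        bincount = np.bincount(y, minlength=len(classes))

        # Samples of each class per batch, proportional to the class frequencies
        onecount = bincount * batch_size / sum(bincount)
        onecount = [self.round(x) for x in onecount]
        onecount = np.array(onecount, dtype=int)
        one_batch = sum(onecount)  # effective batch size, may differ from batch_size
        print("per-class counts:", onecount, "batch_size:", one_batch)

        rng = np.random.RandomState(seed)
        while True:
            # Reshuffle at the start of every pass over the data
            if shuffle:
                r = rng.permutation(len(y))
                pathlist2 = pathlist[r]
                y2 = y[r]
            else:
                pathlist2 = pathlist
                y2 = y
            # Batches are consecutive runs of the shuffled list, so with shuffle=True
            # each batch follows onecount on average rather than exactly
            for i in range(0, len(y)):
                path = pathlist2[i]
                label = y2[i]
                with Image.open(path) as f:
                    self.images.append(np.asarray(f.convert('RGB'), dtype=np.float32))
                if categorical:
                    self.labels.append(to_categorical(label, len(classes)))
                else:
                    self.labels.append(label)

                # Emit a batch when it is full or when the pass ends
                if len(self.images) == one_batch or i == len(y) - 1:
                    inputs = np.asarray(self.images, dtype=np.float32)
                    targets = np.asarray(self.labels, dtype=np.float32)
                    self.reset()
                    yield inputs, targets


train_dir = pathlib.Path('data/cifar/train/')
train_datagen = BalancedImageDataGenerator()
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
batch_size = 32
epoch = 10
input_shape = (32, 32, 3)

# Recompute the generator's effective batch size so that one epoch is one pass over the data
# (classes is already sorted, so classes.index matches LabelEncoder's encoding)
train_files = list(train_dir.iterdir())
train_counts = np.bincount([classes.index(p.stem.split('_')[1]) for p in train_files], minlength=len(classes))
one_batch = int(sum(train_datagen.round(c * batch_size / len(train_files)) for c in train_counts))

model = MobileNet(input_shape=input_shape, weights=None, classes=len(classes))
# Integer labels (categorical=False) go with sparse_categorical_crossentropy
model.compile(optimizer="Adam", loss="sparse_categorical_crossentropy")
# In TensorFlow 2.x, fit_generator is deprecated and model.fit accepts generators directly
model.fit_generator(
    generator=train_datagen.flow_from_directory(train_dir, classes, batch_size=batch_size),
    steps_per_epoch=int(np.ceil(len(train_files) / one_batch)),
    epochs=epoch,
    verbose=1)
```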
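The loop in flow_from_directory slices batches out of a single shuffled list. Because of that, a batch matches the per-class counts only on average: a given batch of 32 may contain no class-2 samples, or several. Below is a minimal sketch of a sampler that enforces the counts exactly by drawing from each class separately. It is an alternative technique, not the original code. The function name `balanced_batch_indices` is hypothetical. Skipping leftover samples, instead of oversampling small classes, is a design choice of this sketch.

```python
import numpy as np


def balanced_batch_indices(y, counts, seed=42):
    """Hypothetical sketch: endlessly yield index arrays in which every batch
    holds exactly counts[c] samples of class c."""
    y = np.asarray(y)
    counts = np.asarray(counts, dtype=int)
    rng = np.random.RandomState(seed)
    used = [c for c in range(len(counts)) if counts[c] > 0]
    # Complete batches one pass can fill; leftover samples are skipped this
    # pass and may be drawn after the next reshuffle
    n_batches = min(int(np.sum(y == c)) // int(counts[c]) for c in used)
    if n_batches == 0:
        raise ValueError("some class has fewer samples than its per-batch count")
    while True:
        # Reshuffle the indices of each class at the start of every pass
        per_class = {c: rng.permutation(np.flatnonzero(y == c)) for c in used}
        for b in range(n_batches):
            parts = [per_class[c][b * counts[c]:(b + 1) * counts[c]] for c in used]
            # Mix the classes inside the batch before yielding it
            yield rng.permutation(np.concatenate(parts))


y = np.array([0] * 100 + [1] * 50 + [2] * 10)
gen = balanced_batch_indices(y, [20, 10, 2])
batch = next(gen)
print(len(batch), np.bincount(y[batch]).tolist())  # 32 [20, 10, 2]
```

To use this inside the generator, the loop over i would load the images for each yielded index array, and steps_per_epoch would become the number of complete batches per pass (5 in this toy example).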