
I built BalancedImageDataGenerator, which creates each batch in Keras while preserving the class balance

Posted at 2019-03-12

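While working from the article "KerasのGeneratorを自作する" (Writing your own Keras generator), I became curious about how training converges on class-imbalanced problems, so I wrote BalancedImageDataGenerator.flow_from_directory, which generates batches with a consistent class balance.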

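For example, if y contains 160 samples in total, split by class as [100 of class 0, 50 of class 1, 10 of class 2], then passing batch_size=32 produces batches of [20 of class 0, 10 of class 1, 2 of class 2].
Because each batch is built to preserve the class ratio, the batch_size you pass in is not necessarily the one that is actually used.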
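
The arithmetic in this example can be checked with a few lines of Python. This is a minimal sketch, not the generator itself: the function name per_class_counts is made up for illustration, and it assumes round-half-up rounding, which is what the round method of the class further down uses.

```python
from decimal import Decimal, ROUND_HALF_UP

def per_class_counts(class_counts, batch_size):
    """Scale each class count to batch_size and round half up (illustrative helper)."""
    total = sum(class_counts)
    raw = [c * batch_size / total for c in class_counts]
    return [int(Decimal(str(x)).quantize(Decimal('0'), rounding=ROUND_HALF_UP))
            for x in raw]

print(per_class_counts([100, 50, 10], 32))  # [20, 10, 2], 32 samples in total
print(per_class_counts([5, 5, 5], 32))      # [11, 11, 11], 33 samples in total
```

The second call shows why the effective batch size can differ from the requested one: each class is rounded independently, so the rounded counts do not have to add up to batch_size.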

mkdir data && cd data
wget http://pjreddie.com/media/files/cifar.tgz
tar zxvf cifar.tgz

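The train and test directories contain the images as PNG files. Each file name consists of a number and a class label joined by an underscore, as in 0_cat.png.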

cifar
├── labels.txt
├── test
│   ├── 0_cat.png
│   ├── 1000_dog.png
│   ├── 1001_airplane.png
│   ├── ...
├── train
│   ├── 0_cat.png
│   ├── 1000_dog.png
│   ├── 1001_airplane.png
│   ├── ...
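
The file names listed above carry the class label after the underscore. As a small sketch of how that label can be read back (label_from_path is a hypothetical helper name; the split expression is the same one the generator uses):

```python
import pathlib

def label_from_path(path):
    """Return the class label encoded in a '<number>_<label>.png' file name."""
    # "1000_dog.png" -> stem "1000_dog" -> ["1000", "dog"] -> "dog"
    return pathlib.Path(path).stem.split('_')[1]

print(label_from_path('cifar/train/1000_dog.png'))  # dog
```
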
import numpy as np
import pathlib
from decimal import Decimal, ROUND_HALF_UP
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from PIL import Image
from keras.applications.mobilenet import MobileNet

class BalancedImageDataGenerator(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.images = []
        self.labels = []

    def round(self, f):
        return Decimal(str(f)).quantize(Decimal('0'), rounding=ROUND_HALF_UP)

    def flow_from_directory(self, directory, classes, batch_size=32, seed=42, categorical=False, shuffle=True):
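        # Collect every file in the directory; the label is the part after "_" in the file name.
        pathlist = np.array(list(pathlib.Path(directory).iterdir()))
        y_str = [path.stem.split('_')[1] for path in pathlist]
        # Encode the label strings as integer class ids.
        le = LabelEncoder()
        le.fit(np.array(classes))
        y = le.transform(y_str)
        bincount = np.bincount(y)

        # Scale each class count to the requested batch size, rounding half up.
        onecount = bincount * batch_size / sum(bincount)
        onecount = [self.round(x) for x in onecount]
        onecount = np.array(onecount, dtype=int)
        # The effective batch size is the sum of the rounded per-class counts.
        one_batch = sum(onecount)
        print("batch_size:", onecount)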
        if shuffle:
            # Use the seed argument so the shuffled order is reproducible.
            r = np.random.RandomState(seed).permutation(len(y))
            pathlist2 = pathlist[r]
            y2 = y[r]
        else:
            pathlist2 = pathlist
            y2 = y
        while True:
            for i in range(0, len(y)):
                path = pathlist2[i]
                label = y2[i]
                with Image.open(path) as f:
                    self.images.append(np.asarray(f.convert('RGB'), dtype=np.float32))
                if categorical:
                    self.labels.append(to_categorical(label, len(classes)))
                else:
                    self.labels.append(label)

                # Emit a batch once it is full, or when the last file has been read.
                if len(self.images) == one_batch or i == len(y)-1:
                    inputs = np.asarray(self.images, dtype=np.float32)
                    targets = np.asarray(self.labels, dtype=np.float32)
                    self.reset()
                    yield inputs, targets
                    
train_dir = pathlib.Path('data/cifar/train/')
train_datagen = BalancedImageDataGenerator()
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
batch_size = 32
epoch = 10
input_shape = (32, 32, 3)

model = MobileNet(input_shape=input_shape, weights=None, classes=len(classes))
model.compile(optimizer="Adam", loss="sparse_categorical_crossentropy")
model.fit_generator(
    generator=train_datagen.flow_from_directory(train_dir, classes, batch_size=batch_size),
    steps_per_epoch=int(np.ceil(len(list(train_dir.iterdir())) / batch_size)),
    epochs=epoch,
    verbose=1)
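
As written, the generator above appears to fill each batch sequentially from the shuffled file list, so an individual batch follows the overall class ratio only on average. If a fixed composition per batch is wanted, one possible approach is to draw from a separate shuffled pool per class. The following is a sketch under assumptions, not the author's implementation: stratified_batch_indices and its policy of stopping when any class can no longer fill its quota are hypothetical, and it yields indices only, leaving image loading to the caller.

```python
import numpy as np

def stratified_batch_indices(y, per_class, seed=42):
    """Yield index arrays in which class c appears exactly per_class[c] times.

    y: 1-D array of integer class ids.
    per_class: number of samples to draw from each class per batch.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    # One shuffled pool of sample indices per class.
    pools = [rng.permutation(np.flatnonzero(y == c)) for c in range(len(per_class))]
    # Stop as soon as any class cannot fill its quota (assumed policy).
    n_batches = min((len(p) // k for p, k in zip(pools, per_class) if k > 0), default=0)
    for b in range(n_batches):
        batch = np.concatenate([p[b * k:(b + 1) * k] for p, k in zip(pools, per_class)])
        # Shuffle within the batch so the classes are not grouped together.
        yield rng.permutation(batch)

# The 160-sample example from the introduction: 100 / 50 / 10 samples per class.
y = np.array([0] * 100 + [1] * 50 + [2] * 10)
for idx in stratified_batch_indices(y, [20, 10, 2]):
    print(np.bincount(y[idx], minlength=3))  # [20 10  2] for every batch
```

With these numbers every class runs out at the same time, so all 160 samples are used in five batches. With other ratios, the leftover samples of the larger classes are simply dropped in this sketch.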