LoginSignup
34
28

More than 5 years have passed since last update.

[keras]各学習データの量が均一ではない場合の対処方

Last updated at Posted at 2018-01-31

背景

kerasだけには限らないことですが、学習させたいデータの数が不揃いな場合がほとんどだと思います。
データ数がちょっとの差しかない場合はあまり問題にはなりませんが、何倍もの差がある場合は数の多いデータに対してのみ予測精度の高い学習器が出来上がってしまいます。
今回はkerasで画像を学習させる際にある程度量の差を緩和する方法を書きたいと思います。

やり方

入力に重みを付けます。
データ数の多いラベルの物の重みを軽くし、データ数の少ないラベルの重みを重くします。
例えば、ラベルが3つあるデータの点数がそれぞれ、

ラベル データ数
データ1 100
データ2 200
データ3 300

だった場合、重みのデータは

ラベル 重み
データ1 3
データ2 1.5
データ3 1

となります。

コード

fit_generator内に

fit_generator(
#----省略----
class_weight=cal_weight(data), 
#----省略----
)

この項目を書き足します。
class_weightの部分はdictionary型です。
そのため、例に上げたような重みのデータは{0: 3.0, 1: 1.5, 2: 1.0}のような形で記述する必要があります。
次に重みを計算する部分ですが、ImageDateGeneratorデータの読み込み方(flowとflow_from_directory)によって変えなければならないです。

flowを使用する場合

使用する学習データのラベル配列が必要となります。

class_id_listはラベルのリストをkerasのto_categoricalで変換したものを指します。
例えば6枚の画像のデータセットに対して猫と犬というラベルが付けられている場合は

from keras.utils.np_utils import to_categorical
class_name = ['猫','犬']
class_list = ['猫','犬','猫','犬','猫','猫']

def make_id_list (label):
    return class_name.index(label)

class_id_list = to_categorical(list(map(make_id_list, class_list)))
print(class_id_list)
#array([[ 1.,  0.],
#       [ 0.,  1.],
#       [ 1.,  0.],
#       [ 0.,  1.],
#       [ 1.,  0.],
#       [ 1.,  0.]])

なぜこんなめんどくさいことをしているかというと、kerasではラベルをバイナリベクトルで扱う方が都合がいい場合が多いからです。
このバイナリベクトルを入力とすれば以下の関数から重みを取り出せます。

def cal_weight(class_id_list):
    amounts_of_class_array = np.zeros(len(class_id_list[0]))
    for class_id in class_id_list:
        amounts_of_class_array = amounts_of_class_array + class_id
    mx = np.max(amounts_of_class_array)
    class_weights = {}
    for q in range(0,len(amounts_of_class_array)):
        class_weights[q] = round(float(math.pow(amounts_of_class_array[q]/mx, -1)),2)
    return class_weights

flow_from_directoryを使用する場合

使用する学習データはクラス名が書かれた画像のフォルダを使用することを想定しています。
class_name_listに学習に使うクラスのlistを入力し、IN_DIRにクラス名が書かれた画像フォルダのディレクトリを入力してください。

def cal_weight(class_name_list,IN_DIR):
    amounts_of_class_dict = {}
    mx = 0
    for class_name in class_name_list:
        class_dir = IN_DIR + os.sep + class_name
        file_list = os.listdir(class_dir)
        amounts_of_class_dict[class_name] = len(file_list)
        if mx < len(file_list):
            mx = len(file_list)
    class_weights = {}
    count = 0
    for class_name in class_name_list:
        class_weights[count] = round(float(math.pow(amounts_of_class_dict[class_name]/mx, -1)),2) #重み=(データ数/最大値)の逆数
        count += 1
    return class_weights
34
28
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
34
28