背景
kerasだけには限らないことですが、学習させたいデータの数が不揃いな場合がほとんどだと思います。
データ数がちょっとの差しかない場合はあまり問題にはなりませんが、何倍もの差がある場合は数の多いデータに対してのみ予測精度の高い学習器が出来上がってしまいます。
今回はkerasで画像を学習させる際にある程度量の差を緩和する方法を書きたいと思います。
やり方
入力に重みを付けます。
データ数の多いラベルの物の重みを軽くし、データ数の少ないラベルの重みを重くします。
例えば、ラベルが3つあるデータの点数がそれぞれ、
ラベル | データ数 |
---|---|
データ1 | 100 |
データ2 | 200 |
データ3 | 300 |
だった場合、重みのデータは
ラベル | 重み |
---|---|
データ1 | 3 |
データ2 | 1.5 |
データ3 | 1 |
となります。 |
コード
fit_generator内に
fit_generator(
#----省略----
class_weight=cal_weight(data),
#----省略----
)
この項目を書き足します。
class_weightの部分はdictionary型です。
そのため、例に上げたような重みのデータは{0: 3.0, 1: 1.5, 2: 1.0}のような形で記述する必要があります。
次に重みを計算する部分ですが、ImageDateGeneratorデータの読み込み方(flowとflow_from_directory)によって変えなければならないです。
flowを使用する場合
使用する学習データのラベル配列が必要となります。
class_id_listはラベルのリストをkerasのto_categoricalで変換したものを指します。
例えば6枚の画像のデータセットに対して猫と犬というラベルが付けられている場合は
from keras.utils.np_utils import to_categorical
class_name = ['猫','犬']
class_list = ['猫','犬','猫','犬','猫','猫']
def make_id_list (label):
return class_name.index(label)
class_id_list = to_categorical(list(map(make_id_list, class_list)))
print(class_id_list)
#array([[ 1., 0.],
# [ 0., 1.],
# [ 1., 0.],
# [ 0., 1.],
# [ 1., 0.],
# [ 1., 0.]])
なぜこんなめんどくさいことをしているかというと、kerasではラベルをバイナリベクトルで扱う方が都合がいい場合が多いからです。
このバイナリベクトルを入力とすれば以下の関数から重みを取り出せます。
def cal_weight(class_id_list):
amounts_of_class_array = np.zeros(len(class_id_list[0]))
for class_id in class_id_list:
amounts_of_class_array = amounts_of_class_array + class_id
mx = np.max(amounts_of_class_array)
class_weights = {}
for q in range(0,len(amounts_of_class_array)):
class_weights[q] = round(float(math.pow(amounts_of_class_array[q]/mx, -1)),2)
return class_weights
flow_from_directoryを使用する場合
使用する学習データはクラス名が書かれた画像のフォルダを使用することを想定しています。
class_name_listに学習に使うクラスのlistを入力し、IN_DIRにクラス名が書かれた画像フォルダのディレクトリを入力してください。
def cal_weight(class_name_list,IN_DIR):
amounts_of_class_dict = {}
mx = 0
for class_name in class_name_list:
class_dir = IN_DIR + os.sep + class_name
file_list = os.listdir(class_dir)
amounts_of_class_dict[class_name] = len(file_list)
if mx < len(file_list):
mx = len(file_list)
class_weights = {}
count = 0
for class_name in class_name_list:
class_weights[count] = round(float(math.pow(amounts_of_class_dict[class_name]/mx, -1)),2) #重み=(データ数/最大値)の逆数
count += 1
return class_weights