Edited at

機械学習向けデータのクラス間のサンプル数を揃える with Python

More than 1 year has passed since last update.

機械学習をする上でクラス間のサンプル数が揃っていることが望ましいです.

しかし実際にはそんなきれいなデータばかりではなく, クラス間のサンプル数が異なるデータもしばしば.

今回, ラベルデータに記されたクラス間のサンプル数を揃える処理をPythonで実装したのでメモ.


やりたいこと

以下の様なデータ配列とそのラベルデータがあった際に


  • ラベルデータのサンプル数を揃える

  • 1番の処理に合わせてデータ配列の要素も取り除く

# データ配列

data = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
# ラベル配列
label = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])

###############
# データ処理...
###############

>>>data
[10 11 12 14 15 16]
>>>label
[0 0 1 1 2 2]


コード

詳細はコメントに書いてあります.

簡単に言うと, 最小サンプル数よりも多いサンプル数を持つクラスに対して以下の処理を行っています.


  1. そのクラスのデータ要素のインデックス配列を取得

  2. random.sample()を利用して, そのインデックス配列からランダムに削除する個数分の要素のインデックスを取得.

  3. 取得したインデックスのデータとラベルを削除

import numpy as np

import random

# データ配列
data = np.array(range(10,20))
print("data:", data)
# ラベル配列
label = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
print("label:", label)
# 全クラスのサンプル数
sample_nums = np.array([])

print("\n各クラスのサンプル数を計算")
for i in range(max(label)+1):
# 各クラスのサンプル数
sample_num = np.sum(label == i)
# サンプル数管理配列に追加
sample_nums = np.append(sample_nums, sample_num)
print("sample_nums:", sample_nums)

# 全クラス内の最小サンプル数
min_num = np.min(sample_nums)
print("min_num:", min_num)

print("\n各クラスのサンプル数を揃える")
for i in range(len(sample_nums)):

# 対象クラスのサンプル数と最小サンプル数の差
diff_num = int(sample_nums[i] - min_num)
print("クラス%d 削除サンプル数: %d (%0.2f%)" % (i, diff_num, (diff_num/sample_nums[i])*100))

# 削除する必要が無い場合はスキップ
if diff_num == 0:
continue

# 削除する要素のインデックス
# タプルになっているのでlistに変換 (0番目のインデックスに配置されている)
indexes = list(np.where(label == i)[0])
print("\tindexes:", indexes)

# 削除するデータのインデックス
del_indexes = random.sample(indexes, diff_num)
print("\tdel_indexes:", del_indexes)

# データから削除
data = np.delete(data, del_indexes)
label = np.delete(label, del_indexes)

print("\ndata:", data)
print("label:", label)


実行結果

data: [10 11 12 13 14 15 16 17 18 19]

label: [0 0 1 1 1 2 2 2 2 2]

各クラスのサンプル数を計算
sample_nums: [ 2. 3. 5.]
min_num: 2.0

各クラスのサンプル数を揃える
クラス0 削除サンプル数: 0 (0.00%)
クラス1 削除サンプル数: 1 (33.33%)
indexes: [2, 3, 4]
del_indexes: [3]
クラス2 削除サンプル数: 3 (60.00%)
indexes: [4, 5, 6, 7, 8]
del_indexes: [7, 8, 6]

data: [10 11 12 14 15 16]
label: [0 0 1 1 2 2]


終わりに

Pythonに詳しい人ならもっと効率化出来るんだろうな