6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

機械学習向けデータのクラス間のサンプル数を揃える with Python

Last updated at Posted at 2016-11-28

機械学習をする上でクラス間のサンプル数が揃っていることが望ましいです.
しかし実際にはそんなきれいなデータばかりではなく, クラス間のサンプル数が異なるデータもしばしば.

今回, ラベルデータに記されたクラス間のサンプル数を揃える処理をPythonで実装したのでメモ.

やりたいこと

以下の様なデータ配列とそのラベルデータがあった際に

  • ラベルデータのサンプル数を揃える
  • 1番の処理に合わせてデータ配列の要素も取り除く
# データ配列
data = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
# ラベル配列
label = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])

###############
# データ処理...
###############

>>>data
[10 11 12 14 15 16]
>>>label
[0 0 1 1 2 2]

コード

詳細はコメントに書いてあります.
簡単に言うと, 最小サンプル数よりも多いサンプル数を持つクラスに対して以下の処理を行っています.

  1. そのクラスのデータ要素のインデックス配列を取得
  2. random.sample()を利用して, そのインデックス配列からランダムに削除する個数分の要素のインデックスを取得.
  3. 取得したインデックスのデータとラベルを削除
import numpy as np
import random

# データ配列
data = np.array(range(10,20))
print("data:", data)
# ラベル配列
label = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])
print("label:", label)
# 全クラスのサンプル数
sample_nums = np.array([])


print("\n各クラスのサンプル数を計算")
for i in range(max(label)+1):
    # 各クラスのサンプル数
    sample_num = np.sum(label == i)
    # サンプル数管理配列に追加
    sample_nums = np.append(sample_nums, sample_num)
print("sample_nums:", sample_nums)

# 全クラス内の最小サンプル数
min_num = np.min(sample_nums)
print("min_num:", min_num)


print("\n各クラスのサンプル数を揃える")
for i in range(len(sample_nums)):

    # 対象クラスのサンプル数と最小サンプル数の差
    diff_num = int(sample_nums[i] - min_num)
    print("クラス%d 削除サンプル数: %d (%0.2f%)" % (i, diff_num, (diff_num/sample_nums[i])*100))

    # 削除する必要が無い場合はスキップ
    if diff_num == 0:
        continue

    # 削除する要素のインデックス
    # タプルになっているのでlistに変換 (0番目のインデックスに配置されている)
    indexes = list(np.where(label == i)[0])
    print("\tindexes:", indexes)

    # 削除するデータのインデックス
    del_indexes = random.sample(indexes, diff_num)
    print("\tdel_indexes:", del_indexes)

    # データから削除
    data = np.delete(data, del_indexes)
    label = np.delete(label, del_indexes)


print("\ndata:", data)
print("label:", label)

実行結果

data: [10 11 12 13 14 15 16 17 18 19]
label: [0 0 1 1 1 2 2 2 2 2]

各クラスのサンプル数を計算
sample_nums: [ 2.  3.  5.]
min_num: 2.0

各クラスのサンプル数を揃える
クラス0 削除サンプル数: 0 (0.00%)
クラス1 削除サンプル数: 1 (33.33%)
	indexes: [2, 3, 4]
	del_indexes: [3]
クラス2 削除サンプル数: 3 (60.00%)
	indexes: [4, 5, 6, 7, 8]
	del_indexes: [7, 8, 6]

data: [10 11 12 14 15 16]
label: [0 0 1 1 2 2]

終わりに

Pythonに詳しい人ならもっと効率化出来るんだろうな

6
6
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?