Help us understand the problem. What is going on with this article?

不均衡データにおけるsampling

More than 5 years have passed since last update.

動機

1:10,000以上の不均衡データを使用した分類器の学習を効果的に行いたいな、ってのがモチベーションです。
web系のCV分析などされている方はこの辺り悩まれているのではないかと思います。
僕もその一人ですw

このブログに書いたこと

  • 不均衡データにおけるサンプリング手法の概要
  • under-sampling, over-samplingの具体的な手法の概要
  • 参考コード

不均衡データにおけるサンプリング手法の概要

参考文献

上の論文の適当なまとめ

本当に適当にまとめているので詳しくは論文をご覧ください。。

不均衡データに対する対処法

- algorithm-level approaches
  不均衡を調整する係数をモデルに導入する(コスト関数を調整する)
- data-level approaches
  多数派データを減少させ、少数派データを増加させる手法。
  順序については、undersampling, oversamplingの順に行う。
今回は主にdata-level approachesについて記載します。

サンプリングの種類

大まかに分けてunder-sampling, over-sampling, Hybrid Methodsがあります。
- under-sampling
  ・ 多数派データを減少させる
  ・ Random undersamplingとその他の手法に分けられる
  ・ Random undersamplingでは有益なデータを削除してしまう可能性がある
   ⇒ クラスターベースの手法なら各クラスdistinctなデータ群となるため、
     一部の有益なデータのみを消す事は無い

- over-sampling
  ・ 少数派データを増加させる
  ・ Random oversamplingとその他の手法に分けられる
  ・ Random oversamplingでは過学習を引き起こしやすい
   ⇒ 既存のデータの複製ではなく周辺のデータ(にノイズを加えたデータ)を増やす事により解決

- Hybrid Methods
  under-sampling, over-samplingの両者を行う

エラーの種類
  • Intrisic noise: サンプルに内在する予測不能の誤差
  • Squared bias: 系統的誤差(分類器由来の誤差)
  • Variance: サンプリングによる誤差

式で見ると以下のようになります。
erro.png

サンプリングアルゴリズム

- under-sampling
  ・ 少数派データ/クラスタから最遠/最近のデータ/クラスタを残す/落とす
   (クラスタの場合は重心の距離などで判定)
  ・ k-meansで全データをクラスタリングし、各クラスタについてpositive、negative sampleの比によりnegative sampleの削減数を決定
   ← 今回はこの手法を使っています
- over-sampling
  ・ SMOTEという手法がデファクトスタンダードっぽい
    k-NNベースで少数派サンプルの近傍5個のうち1つについてノイズを付加してサンプリング

参考コード

under-sampling

まずはunder-samplingについて

def undersampling(imp_info, cv, m):
    # minority data
    minodata = imp_info[np.where(cv==1)[0]]

    # majority data
    majodata = imp_info[np.where(cv==0)[0]]

    # kmeans2でクラスタリング
    whitened = whiten(imp_info) # 正規化(各軸の分散を一致させる)
    centroid, label = kmeans2(whitened, k=3) # kmeans2
    C1 = []; C2 = []; C3 = []; # クラスタ保存用
    C1_cv = []; C2_cv = []; C3_cv = [] 
    for i in xrange(len(imp_info)):
        if label[i] == 0:
            C1 += [whitened[i]]
            C1_cv.append(cv[i])
        elif label[i] == 1:
            C2 += [whitened[i]]
            C2_cv.append(cv[i])
        elif label[i] == 2:
            C3 += [whitened[i]]
            C3_cv.append(cv[i])

    # numpy形式の方が扱いやすいため変換
    C1 = np.array(C1); C2 = np.array(C2); C3 = np.array(C3) 
    C1_cv = np.array(C1_cv); C2_cv = np.array(C2_cv); C3_cv = np.array(C3_cv);

    # 各クラスの少数派データの数
    C1_Nmajo = sum(1*(C1_cv==0)); C2_Nmajo = sum(1*(C2_cv==0)); C3_Nmajo = sum(1*(C3_cv==0)) 

    # 各クラスの多数派データの数
    C1_Nmino = sum(1*(C1_cv==1)); C2_Nmino = sum(1*(C2_cv==1)); C3_Nmino = sum(1*(C3_cv==1))
    t_Nmino = C1_Nmino + C2_Nmino + C3_Nmino

    # 分母に0が出る可能性があるので1をプラスしておく
    C1_MAperMI = float(C1_Nmajo)/(C1_Nmino+1); C2_MAperMI = float(C2_Nmajo)/(C2_Nmino+1); C3_MAperMI = float(C3_Nmajo)/(C3_Nmino+1);

    t_MAperMI = C1_MAperMI + C2_MAperMI + C3_MAperMI

    under_C1_Nmajo = int(m*t_Nmino*C1_MAperMI/t_MAperMI)
    under_C2_Nmajo = int(m*t_Nmino*C2_MAperMI/t_MAperMI)
    under_C3_Nmajo = int(m*t_Nmino*C3_MAperMI/t_MAperMI)
    t_under_Nmajo = under_C1_Nmajo + under_C2_Nmajo + under_C3_Nmajo

#    draw(majodata, label)

    # 各グループで多数派と少数派が同数になるようにデータを削除
    C1 = C1[np.where(C1_cv==0),:][0]
    random.shuffle(C1)
    C1 = np.array(C1)
    C1 = C1[:under_C1_Nmajo,:]
    C2 = C2[np.where(C2_cv==0),:][0]
    random.shuffle(C2)
    C2 = np.array(C2)
    C2 = C2[:under_C2_Nmajo,:]
    C3 = C3[np.where(C3_cv==0),:][0]
    random.shuffle(C3)
    C3 = np.array(C3)
    C3 = C3[:under_C3_Nmajo,:]

    cv_0 = np.zeros(t_under_Nmajo); cv_1 = np.ones(len(minodata))
    cv_d = np.hstack((cv_0, cv_1))

    info = np.vstack((C1, C2, C3, minodata))

    return cv_d, info

over-sampling

続いてover-samplingについて

class SMOTE(object):
    def __init__(self, N):
        self.N = N
        self.T = 0

    def oversampling(self, smp, cv):
        mino_idx = np.where(cv==1)[0]
        mino_smp = smp[mino_idx,:]

        # kNNの実施
        mino_nn = []

        for idx in mino_idx:
            near_dist = np.array([])
            near_idx = np.zeros(nnk)
            for i in xrange(len(smp)):
                if idx != i:
                    dist = self.dist(smp[idx,:], smp[i,:])

                    if len(near_dist)<nnk: # 想定ご近所さん数まで到達していなければ問答無用でlistに追加
                        tmp = near_dist.tolist()
                        tmp.append(dist)
                        near_dist = np.array(tmp)
                    elif sum(near_dist[near_dist > dist])>0:
                        near_dist[near_dist==near_dist.max()] = dist
                        near_idx[near_dist==near_dist.max()] = i
            mino_nn.append(near_idx)
        return self.create_synth( smp, mino_smp, np.array(mino_nn, dtype=np.int) )

    def dist(self, smp_1, smp_2):
        return np.sqrt( np.sum((smp_1 - smp_2)**2) )

    def create_synth(self, smp, mino_smp, mino_nn):
        self.T = len(mino_smp)
        if self.N < 100:
            self.T = int(self.N*0.01*len(mino_smp))
            self.N = 100
        self.N = int(self.N*0.01)

        rs = np.floor( np.random.uniform(size=self.T)*len(mino_smp) )

        synth = []
        for n in xrange(self.N):
            for i in rs:
                nn = int(np.random.uniform(size=1)[0]*nnk)
                dif = smp[mino_nn[i,nn],:] - mino_smp[i,:]
                gap = np.random.uniform(size=len(mino_smp[0]))
                tmp = mino_smp[i,:] + np.floor(gap*dif)
                tmp[tmp<0]=0
                synth.append(tmp)
        return synth   

実験してみた感想

  • 適当にスパースな2,000次元くらいのデータを発生させて実験
  • algorithm-level approachesでは98%程度の正確性, 100%の再現性
  • under-sampling+over-samplingでは11%程度の正確性, 100%の再現性

data-level approachesよりもalgorithm-level approachesの方が汚いデータに対してロバストかなと。
もしかしたらコードが間違ってるかもですが。。。
不備があればご教授下さいm(_ _)m

まとめ

data-level approachesではバッチ処理が対象だし計算回数が逆に多くなる可能性も高いため、結局algorithm-level approachesで調整する方が実用的ではなかろうかと感じました。
algorithm-level approachesでは少数派データサンプルでコストを高くしてweight調整の勾配を大きくすればよいだけなので、計算時間にほぼ影響ありませんし。。
他に不均衡データに対する良い対処方法がありましたらコメントいただきたく存じます。

shima_x
分析とか雑用とかやってます
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした