概要
言いたいことはタイトルに書いてある通りです(笑)
実際のサービスのデータを用いてデータ分析したり様々な学習モデルを作成していると、だいたいはデータに偏りがあることがほとんどです。
例えば、スパムメールの分類などは良い例なのですが、殆どのメールが正常である中で数少ないスパムの分類をすると、「全部正常だと分類したほうが正解率が上がってしまい表面上は良いモデル」が出来上がってしまい、実用に耐えないものが出来てしまいます。
今回はそういった偏りのあるデータセット(不均衡データといいます)に対してクラス分類モデルを作成する際のポイントを記述します。
分類を行うモデルはランダムフォレストを指定します。しかし、他の分類モデルでも応用できる考え方ですので、これを押さえておくと良いと思います。
scikit-learnのRandomForestClassifierのドキュメントによると、 class_weight
のパラメータを balanced
を指定するとクラスごとのサンプル数の重みを自動で付けてくれるとのこと。便利ですね。
中身の計算方法は n_samples / (n_classes * np.bincount(y))
です。
実際にデータを用いて確かめる
不均衡なデータに対してclass_weightを考慮しない分類と考慮する分類でどのような差異があるのか見てみましょう。
今回用いたpythonコード( jupyter notebook形式 ): https://github.com/kazuki-hayakawa/RandomForest_imbalanced/blob/master/RandomForest_notebook.ipynb
スパムの分類のように正しいクラスを正例、誤りなどに用いられるクラスを負例と言ったりしますが、今回はサンプル数の少ない方を負例と呼称します。
今回は正例を7000, 負例を100という不均衡データを用意しました。
# 不均衡データセットの作成
import numpy as np
np.random.seed(0)
X1 = np.random.rand(7000, 2)
X2 = np.random.rand(100, 2) *0.2 + 0.9
y1 = np.array([1 for _ in range(7000)])
y2 = np.array([2 for _ in range(100)])
正例のラベルを1, 負例のラベルを2としています。
計算のためにデータを7:3の割合で訓練データとテストデータに分割します。
from sklearn.cross_validation import train_test_split
X = np.concatenate([X1, X2])
y = np.concatenate([y1, y2])
# データを 7:3 で訓練用データとテストデータに分割する
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
不均衡データを考慮しない分類
ではまずは、不均衡データを考慮しないで普通に分類をしてみます。
(サンプルなのでグリッドサーチなどによるハイパーパラメータの調整は行っていません)
from sklearn.ensemble import RandomForestClassifier
# 不均衡データの考慮がない場合
clf = RandomForestClassifier()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
結果は以下の通りです。
(乱数生成なので人によっては違った結果が出てくることがありますが、デモなのでご容赦ください)
正解率: 0.993896713615
負例の予測数: 21
実際の負例の数: 30
負例を正しく識別できている割合: 0.7
不均衡データを考慮する分類
では次にデータの偏りを考慮してclass_weightを明示的に指定して分類を行います。
# 不均衡データの考慮をする場合
clf = RandomForestClassifier(class_weight='balanced')
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
結果は以下のとおりです。
正解率: 0.992018779343
負例の予測数: 25
実際の負例の数: 30
負例を正しく識別できている割合: 0.8333333333333334
まとめ
結果を見比べてみると
不均衡を考慮する | 不均衡を考慮しない | |
---|---|---|
正解率 | 0.9920 | 0.9938 |
負例の検出率 | 0.833 | 0.700 |
このように、不均衡を考慮しなくてもそれなりの正解率は出ていますが、これは「ほとんど1のラベルなら全部1とみなしてもほとんど正解してしまう」現象ですね。
しかし、不均衡を考慮することで、負例の検出率が上がっていることがわかると思います。
これはスパムを間違って正しいメールだと誤認識しないというような意味を持ちます。
実際のサービスにおいては負例の検出力が高いモデルが求められている場合がありますので、モデルの用途やデータセットに応じてこういった工夫が必要になるかと思われます。
参考にしたサイト
Scikit-learnによるランダムフォレスト
→ RandomForestClassifier の他のパラメータの意味が書かれていて勉強になりました。