#複数の機械学習ライブラリを一発で適応する方法です。
機械学習では、データをトレーニング用とテスト用の2つにわけ、精度を測りますよね。そこで、簡単にデータを分けられるのがKFoldです。その他にも、StatifiedKFoldやShuffleSplitなどもありますが、今回はKFoldを使います。
##まず交差検定の一つであるsklearnのKFoldについて
KFoldはデータをk個のデータセットに分け、例えば10個に分けたら、9個をトレーニング用のデータセットとして使い、残りの1個をテスト用として使います。ちなみに、この場合、分けられた10個のデータセットがテスト用として必ず1回使われるように10回検定します。defaultのコードは下記のようになっています。
from sklearn.cross_validation import KFold
KFold(n_splits=3, shuffle=False, random_state=None)
そして下記がKFold内のパラメータです。
-
n_split
データをいくつに分けるかを指定するもの。defaultは3。なお、検定はここで指定した数値の数だけ繰り返される。 -
shuffle
defaultはFalseだが、これをTrueにすることで、連続する数字の単純なグループ分けではなく、データセットの中からランダムに値を持ってきてグループを作ることができる。 -
random_state
乱数制御のパラメータでこれを数値にすることで、毎回同じデータセットが得られる。
こちら公式ドキュメント
##続いて cross_val_score についてです。
cross_val_scoreは、classifierと、トレーニング用データ、テスト用データを指定してその精度を割り出せる便利なツールです。下記がdefaultのコード。
from sklearn.model_selection import cross_val_score
cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None, n_jobs=1, verbose=0, fit_params=None, pre_dispatch='2*n_jobs')
下記が主要パラメータの説明です。
-
estimator
これでclassifierの指定をします。 -
X
トレーニング用のデータの指定 -
y
テスト用のデータの指定 -
scoring
スコアのつけ方の指定。精度のaccuracyの他にもaverage_precisionやf1などがあります。リンク -
cv
crossvalidation(交差検証)の略で、データのsplitの方法を指定できます。
こちら公式ドキュメント
##本題の複数の機械学習
メインの複数の機械学習の方法を1発でapplyしちゃう方法です。
今回使うのは、DecisionTreeClassifier, KNeighborsClassifier, SVCです。まずはコードをどうぞ。ちなみにここでは、トレーニングデータセットとテスト用データセットをすでに分けている程で進めていきます。
from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
#機械学習モデルをリストに格納
models = []
models.append(("KNC",KNeighborsClassifier()))
models.append(("DTC",DecisionTreeClassifier()))
models.append(("SVM",SVC()))
#複数のclassifier の適用
results = []
names = []
for name,model in models:
kfold = KFold(n_splits=10, random_state=42)
result = cross_val_score(model,X_train,Y_train, cv = kfold, scoring = "accuracy")
names.append(name)
results.append(result)
#適用したclassifierのスコア表示
for i in range(len(names)):
print(names[i],results[i].mean())
これで結果が
KNC 0.88
DTC 0.91
SVM 0.79
みたいな感じになると思います。