Trying out various estimators with scikit-learn's all_estimators function

Posted at 2021-11-10

#Purpose

In machine learning, choosing an estimator is an important decision, but writing your own code to try out many different estimators is laborious.
scikit-learn provides the all_estimators function, which retrieves all of its estimators so that each can be tried in turn; this article summarizes how to use it.
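
For orientation, all_estimators returns a list of (name, class) tuples, so every estimator class can be iterated over by name. A minimal sketch (the type_filter argument also accepts "regressor", "cluster", and "transformer"):

```py
from sklearn.utils import all_estimators

# Each entry is a (name, class) pair, sorted alphabetically by name
for name, Est in all_estimators(type_filter="classifier")[:3]:
    print(name, Est)
```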

#Usage
Text data is used here, and classification is performed as supervised learning.
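
The script below assumes train.csv and test.csv each contain a text column and an integer label column (the file and column names are taken from the code; the rows here are hypothetical placeholders, with labels assumed binary):

```py
import pandas as pd

# Hypothetical rows illustrating the expected layout of train.csv
pd.DataFrame({
    "text": ["とても良い商品でした", "二度と買いません"],
    "label": [1, 0],
}).to_csv("./train.csv", index=False)
```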

```py
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import KFold
from transformers import BertJapaneseTokenizer
from sklearn.utils import all_estimators

# Word-segmentation (wakati) function built on BERT's Japanese tokenizer
bert_preprocess = BertJapaneseTokenizer.from_pretrained("daigo/bert-base-japanese-sentiment")
def tokenize(word):
    return bert_preprocess.tokenize(word)
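
# Illustrative usage (assumed example; actual tokens depend on the pretrained vocabulary):
#   tokenize("とても良い商品でした") -> ['とても', '良い', '商品', 'でし', 'た']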

# Load and shape the training and test data
train_df = pd.read_csv("./train.csv")
test_df = pd.read_csv("./test.csv")
train_text = train_df["text"].values.astype('U')
test_text = test_df["text"].values.astype('U')
y = train_df["label"].values.astype("int8")

# Vectorize the text
vectorizer = TfidfVectorizer(analyzer=tokenize)
vectorizer.fit(train_text)
X = vectorizer.transform(train_text)

allAlgorithms = all_estimators(type_filter="classifier")  # retrieve every supervised classification method

# Compare the accuracy of each classification method with 15-fold cross-validation
kf = KFold(n_splits=15)
for name, algorithm in allAlgorithms:
    # Some methods require detailed parameter settings, so allow for the cases that raise errors
    try:
        train_check = []
        for tr_idx, va_idx in kf.split(X, y):
            x_train, y_train = X[tr_idx], y[tr_idx]
            x_valid, y_valid = X[va_idx], y[va_idx]
            m = algorithm()
            m.fit(x_train, y_train)
            p = m.predict(x_valid)
            p = np.where(p < 0.5, 0, 1)  # clamp predictions to the binary labels {0, 1}
            train_check.append(accuracy_score(y_valid, p))
        print(f"{name:<30} accuracy = {np.mean(train_check)}")
    except Exception:
        pass
```

The accuracy of each classification method is printed:

```
AdaBoostClassifier             accuracy = 0.8816950383667025
BaggingClassifier              accuracy = 0.8799853826840333
BernoulliNB                    accuracy = 0.8782442112277192
CalibratedClassifierCV         accuracy = 0.9080909995452726
ComplementNB                   accuracy = 0.8924587255921589
DecisionTreeClassifier         accuracy = 0.8503017260138698
DummyClassifier                accuracy = 0.03308345827086457
ExtraTreeClassifier            accuracy = 0.8084926005465735
ExtraTreesClassifier           accuracy = 0.888767328047688
GradientBoostingClassifier     accuracy = 0.8717660689174932
KNeighborsClassifier           accuracy = 0.8424500962731846
LinearSVC                      accuracy = 0.9049886468177321
LogisticRegression             accuracy = 0.9007845926886406
LogisticRegressionCV           accuracy = 0.9066867467167318
MLPClassifier                  accuracy = 0.8885936911424166
MultinomialNB                  accuracy = 0.884156120138129
NearestCentroid                accuracy = 0.8685759222490856
NuSVC                          accuracy = 0.8632568100334217
PassiveAggressiveClassifier    accuracy = 0.8894950422686552
Perceptron                     accuracy = 0.8858933896415154
RandomForestClassifier         accuracy = 0.8833649241445343
RidgeClassifier                accuracy = 0.8981822902362633
RidgeClassifierCV              accuracy = 0.8978822900861884
SGDClassifier                  accuracy = 0.9068875472173823
SVC                            accuracy = 0.9033816425120772
```

The accuracy of each classification method can be checked this way.
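
As a possible next step (a sketch, not part of the original script), the best-scoring classifier here, CalibratedClassifierCV, can be refit on the full training set and applied to the test text that the script loads but never uses; default constructor arguments are assumed:

```py
from sklearn.calibration import CalibratedClassifierCV

# Refit the top classifier from the comparison on the full training set
best = CalibratedClassifierCV()  # defaults assumed; tune as needed
best.fit(X, y)

# Vectorize the held-out test text with the already-fitted vectorizer and predict
X_test = vectorizer.transform(test_text)
print(best.predict(X_test)[:10])
```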
