2
0

交差検証のモデル保存と精度一覧を出す関数を作った

Posted at

交差検証はScikit-Learnだと「sklearn.metrics.cross_val_score」にあります。
しかしそこでできるのは精度の計測で、一番良かった時のモデルが保存されているわけではありません(もしかしたら自分が知らないだけかも)。
そこで今回は交差検証をしてなおかつ詳細な精度の分かる関数を作ってみました。

関数

import pandas as pd
from sklearn.metrics import classification_report
def closs_val_model_accuracy(model, x, y, cv=50):
    ylist = list(set(y.values.astype("str")))
    models = []
    acc = []
    for i in range(cv):
        x_test = x.loc[int(i*len(x)/cv):int((i+1)*len(x)/cv)]
        y_test = y.loc[int(i*len(x)/cv):int((i+1)*len(x)/cv)]
        x_train = pd.concat([x.loc[0:int((i)*len(x)/cv)], x.loc[int((i+2)*len(x)/cv):]])
        y_train = pd.concat([y.loc[0:int((i)*len(x)/cv)], y.loc[int((i+2)*len(x)/cv):]])
        model = model
        model.fit(x_train, y_train)
        y_pred = model.predict(x_test)
        rep = classification_report(y_test, y_pred, output_dict=True)
        tmp = []
        for i in range(len(ylist)):
            try:
                tmp.append(rep[str(ylist[i])]["precision"])
                tmp.append(rep[str(ylist[i])]["recall"])
                tmp.append(rep[str(ylist[i])]["f1-score"])
                tmp.append(rep[str(ylist[i])]["support"])
            except:
                tmp.append(None)
                tmp.append(None)
                tmp.append(None)
                tmp.append(None)
                print(rep)
        tmp.append(rep["accuracy"])
        acc.append(tmp)
        models.append([model, rep["accuracy"]])
    columns = []
    for i in range(len(ylist)):
        for col in ["precision", "recall", "f1-score", "support"]:
            columns.append(ylist[i]+"_"+col)
    columns.append("accuracy")
    df_acc = pd.DataFrame(acc)
    df_acc.columns = columns
    return df_acc.describe(), models, df_acc

使用例

ここではScikit-Learnにある乳がんデータを使用します。

from lightgbm import LGBMClassifier
import numpy as np
import pandas as pd

df = pd.read_csv("breast_cancer.csv")

y = df["y"]
x = df.drop("y", axis=1)

model = LGBMClassifier()
df_dcb, models, df_acc = closs_val_model_accuracy(model, x, y, cv=50)
{'0.0': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 13}, 'accuracy': 1.0, 'macro avg': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 13}, 'weighted avg': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 13}}
{'1.0': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 12}, 'accuracy': 1.0, 'macro avg': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 12}, 'weighted avg': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 12}}
{'1.0': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 12}, 'accuracy': 1.0, 'macro avg': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 12}, 'weighted avg': {'precision': 1.0, 'recall': 1.0, 'f1-score': 1.0, 'support': 12}}

この出力は片方のラベルしかテストデータに無かった時の物です。

df_dcb

image.png

このように詳細な精度の分布が分かります。
count(サンプル数)が異なるのは先ほどの片方しかラベルが無い場合です。

これで最適なモデルはsorted関数を使って一番スコアの高かったモデルを使えば最適なモデルになります。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0