
Performance comparison of gradient boosting decision trees

Posted at 2024-06-02

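This article looks at three libraries that implement gradient boosting decision trees:

  • XGBoost
  • LightGBM
  • Scikit-Learn

I compare the three on the iris dataset under identical conditions, running each one 100 times and comparing elapsed time and accuracy.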

Program and output

from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import accuracy_score
import pandas as pd
import matplotlib.pyplot as plt
import time

# Load the iris dataset; "category" is the target column.
df = pd.read_csv("iris.csv")

y = df["category"]
x = df.drop("category", axis=1)

# XGBoost: 100 train/test splits (random_state 0..99), default hyperparameters.
# The timer covers the split, training, and prediction together.
print("XGBoost:", end="")
acc1 = []
start = time.time()
for i in range(100):
    x_train, x_test, y_train, y_test = tts(x, y, test_size=0.3, random_state=i)
    model1 = XGBClassifier()
    model1.fit(x_train, y_train)
    y_pred1 = model1.predict(x_test)
    acc1.append(accuracy_score(y_test, y_pred1))
print(time.time()-start)

# LightGBM: same splits and the same measurement.
print("LightGBM:", end="")
acc2 = []
start = time.time()
for i in range(100):
    x_train, x_test, y_train, y_test = tts(x, y, test_size=0.3, random_state=i)
    model2 = LGBMClassifier()
    model2.fit(x_train, y_train)
    y_pred2 = model2.predict(x_test)
    acc2.append(accuracy_score(y_test, y_pred2))
print(time.time()-start)

# Scikit-Learn: same splits and the same measurement.
print("Scikit-Learn:", end="")
acc3 = []
start = time.time()
for i in range(100):
    x_train, x_test, y_train, y_test = tts(x, y, test_size=0.3, random_state=i)
    model3 = GradientBoostingClassifier()
    model3.fit(x_train, y_train)
    y_pred3 = model3.predict(x_test)
    acc3.append(accuracy_score(y_test, y_pred3))
print(time.time()-start)

# Summary statistics of the 100 accuracies, one column per library.
df_acc = pd.DataFrame(acc1).describe()
df_acc = pd.concat([df_acc, pd.DataFrame(acc2).describe()], axis=1)
df_acc = pd.concat([df_acc, pd.DataFrame(acc3).describe()], axis=1)
df_acc.columns = ["XGBoost", "LightGBM", "Scikit-Learn"]
df_acc  # in a notebook, this displays the summary table

The output was as follows (elapsed time in seconds):

XGBoost:10.425119638442993
LightGBM:8.228994369506836
Scikit-Learn:27.98310923576355

[Figure: image.png (table of summary statistics, df_acc)]
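One caveat when running the program above: recent XGBoost releases expect class labels to be the integers 0, 1, ..., n_classes - 1, so if the `category` column in `iris.csv` holds strings (the article does not say), `XGBClassifier.fit` may reject them. Below is a minimal sketch of such a mapping using only the standard library. `encode_labels` is a hypothetical helper name, not part of any library, and the labels are toy values.

```python
def encode_labels(labels):
    """Map arbitrary class labels to integers 0..n_classes-1.

    Hypothetical helper for illustration; returns the encoded list and
    the sorted list of original classes so predictions can be decoded.
    """
    classes = sorted(set(labels))
    index = {c: i for i, c in enumerate(classes)}
    return [index[c] for c in labels], classes


# Toy labels standing in for the "category" column (assumed to be strings).
encoded, classes = encode_labels(["setosa", "virginica", "setosa", "versicolor"])
print(encoded)   # [0, 2, 0, 1]
print(classes)   # ['setosa', 'versicolor', 'virginica']
```

scikit-learn's `LabelEncoder` provides the same kind of mapping and may be more convenient when the labels live in a pandas column.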

Next, I plot the accuracies from the 100 runs to visualize them.

# Boxplot of the 100 accuracies per library.
plt.boxplot([acc1, acc2, acc3], labels=df_acc.columns, positions=[0, 1, 2])
# Mark each library's mean accuracy with a black "x".
plt.scatter([0, 1, 2], [sum(acc1)/100, sum(acc2)/100, sum(acc3)/100], marker="x", color="#000000")
plt.show()

[Figure: Untitled.png (boxplot of the accuracies for the three libraries, with each mean marked by ×)]
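The spread shown in the boxplot can also be put into numbers. The sketch below summarizes a list of accuracies by its mean, standard deviation, and worst case using the standard library. `summarize` is a hypothetical helper, and the sample values are toy numbers, not the measurements from this article.

```python
import statistics


def summarize(accs):
    """Return mean, sample standard deviation, and minimum of accuracies."""
    return {
        "mean": statistics.mean(accs),
        "std": statistics.stdev(accs),  # smaller std = more stable accuracy
        "min": min(accs),
    }


# Toy accuracies (NOT the article's results), just to show the call.
toy = {"model_a": [0.90, 0.95, 1.00], "model_b": [0.94, 0.95, 0.96]}
for name, accs in toy.items():
    s = summarize(accs)
    print(f"{name}: mean={s['mean']:.3f} std={s['std']:.3f} min={s['min']:.2f}")
```

In the article's program this would be called as `summarize(acc1)`, `summarize(acc2)`, and `summarize(acc3)`. `statistics.stdev` uses the sample (n - 1) definition, which matches the `std` row of `DataFrame.describe()`.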

Interpreting the results

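From these results (on the iris data only, to be clear), the following rankings can be drawn, best first:
Speed (total time for splitting, training, and prediction over 100 runs)→LightGBM > XGBoost >> Scikit-Learn
Stability of accuracy→Scikit-Learn > XGBoost > LightGBM
Accuracy→Scikit-Learn > LightGBM > XGBoost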
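The timings in the program cover the data split, training, and prediction together. To see where the time goes, training and prediction can be timed separately. Below is a minimal sketch, assuming only that the model exposes `fit` and `predict` as the three classifiers do. `time_fit_predict` and `MajorityClassifier` are hypothetical names; the latter is a stand-in so that the snippet runs without any machine learning library.

```python
import time
from collections import Counter


def time_fit_predict(model, x_train, y_train, x_test):
    """Time fit() and predict() separately; return (fit_s, predict_s, preds)."""
    t0 = time.perf_counter()
    model.fit(x_train, y_train)
    t1 = time.perf_counter()
    preds = model.predict(x_test)
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, preds


class MajorityClassifier:
    """Hypothetical stand-in model: always predicts the most common label."""

    def fit(self, x, y):
        self.label_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, x):
        return [self.label_ for _ in x]


fit_s, predict_s, preds = time_fit_predict(
    MajorityClassifier(), [[1], [2], [3]], [0, 1, 1], [[4], [5]]
)
print(f"fit: {fit_s:.6f}s, predict: {predict_s:.6f}s, preds: {preds}")
```

Passing `XGBClassifier()`, `LGBMClassifier()`, or `GradientBoostingClassifier()` in place of the stand-in and summing the two durations over the 100 splits would give separate training and prediction totals. `time.perf_counter()` is used here because it is a monotonic, high-resolution clock intended for measuring intervals.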

Summary

Each library has its strengths and weaknesses.
