勾配ブースティング決定木はライブラリが3つあり
- XGBoost
- LightGBM
- Scikit-Learn
があります。
という訳で、この3つについてアヤメのデータセットで条件を同じにして100回検証し時間と精度で比較してみようと思います。
プログラムと出力
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import accuracy_score
import pandas as pd
import matplotlib.pyplot as plt
import time
df = pd.read_csv("iris.csv")
y = df["category"]
x = df.drop("category", axis=1)
print("XGBoost:", end="")
acc1= []
start = time.time()
for i in range(100):
x_train, x_test, y_train, y_test = tts(x, y, test_size=0.3, random_state=i)
model1 = XGBClassifier()
model1.fit(x_train, y_train)
y_pred1 = model1.predict(x_test)
acc1.append(accuracy_score(y_test, y_pred1))
print(time.time()-start)
print("LightGBM:", end="")
acc2 = []
start = time.time()
for i in range(100):
x_train, x_test, y_train, y_test = tts(x, y, test_size=0.3, random_state=i)
model2 = LGBMClassifier()
model2.fit(x_train, y_train)
y_pred2 = model2.predict(x_test)
acc2.append(accuracy_score(y_test, y_pred2))
print(time.time()-start)
print("Scikit-Learn:", end="")
acc3 = []
start = time.time()
for i in range(100):
x_train, x_test, y_train, y_test = tts(x, y, test_size=0.3, random_state=i)
model3 = GradientBoostingClassifier()
model3.fit(x_train, y_train)
y_pred3 = model3.predict(x_test)
acc3.append(accuracy_score(y_test, y_pred3))
print(time.time()-start)
df_acc = pd.DataFrame(acc1).describe()
df_acc = pd.concat([df_acc, pd.DataFrame(acc2).describe()], axis=1)
df_acc = pd.concat([df_acc, pd.DataFrame(acc3).describe()], axis=1)
df_acc.columns = ["XGBoost", "LightGBM", "Scikit-Learn"]
df_acc
出力結果はこうなりました
XGBoost:10.425119638442993
LightGBM:8.228994369506836
Scikit-Learn:27.98310923576355
次にこれをグラフにして精度を可視化します。
plt.boxplot([acc1, acc2, acc3], labels=df_acc.columns, positions=[0, 1, 2])
plt.scatter([0, 1, 2], [sum(acc1)/100, sum(acc2)/100, sum(acc3)/100], marker="x", color="#000000")
plt.show()
結果の解釈
これらの結果から考えられることとして(あくまでアヤメデータですが)
予測時間→LightGBM>XGBoost>>Scikit-Learn
精度の安定性→Scikit-Learn>XGBoost>LightGBM
精度の高さ→Scikit-Learn>LightGBM>XGBoost
まとめ
一長一短ですね