Help us understand the problem. What is going on with this article?

NGBoostを試してみた

NGBoostって?

突如(?)現れた新たな勾配ブースティング系のアルゴリズム
NG → Natural Gradient

スタンフォード大学による論文実装が公開されている。
PyPIにもすでに公開されているのでpipコマンドで簡単にインストールできる。

pip3 install ngboost

なお、僕がインストールした際には

jaxlib==0.1.29

が入らなかったのでpip自体のversionを上げる必要があった。

pip install --upgrade pip

とりあえず動かしてみた

以下は公式のサンプルである。Boston Housingデータセットを利用しての動作確認。

from ngboost.ngboost import NGBoost
from ngboost.learners import default_tree_learner
from ngboost.scores import MLE
from ngboost.distns import Normal

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error


X, Y = load_boston(True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)

ngb = NGBoost(Base=default_tree_learner, Dist=Normal, Score=MLE(), natural_gradient=True,verbose=False)
ngb.fit(X_train, Y_train)
Y_preds = ngb.predict(X_test)
Y_dists = ngb.pred_dist(X_test)

# test Mean Squared Error
test_MSE = mean_squared_error(Y_preds, Y_test)
print('Test MSE', test_MSE)

#test Negative Log Likelihood
test_NLL = -Y_dists.logpdf(Y_test.flatten()).mean()
print('Test NLL', test_NLL)
実行結果
Test MSE 8.807989614179135
Test NLL 4.01077

Mean Squared Error(平均二乗誤差)の値が約8.8
次に並ぶNLL(Negative Log-Likelihood)では「予測の不確実性推定値」1を表しているとのこと。

動かしてみて気づいたこと

  • ngboostでは実行するたびに予測結果が変わってしまった、そのままでは再現性がなかったため冒頭でnumpyにより乱数シードを固定する。
import numpy as np
np.random.seed(42)

XGBoost,LightGBMとの比較

上記のBoston Housingデータセットを利用してXGBoost,LightGBMと比較してみる。
ハイパーパラメータチューニング等は行っていないため、この比較はあくまで参考。

from ngboost.learners import default_tree_learner
from ngboost.scores import MLE,CRPS #Maximum Likelihood Estimationl,Continuous Ranked Probability Score
from ngboost.distns import Normal
from ngboost.ngboost import NGBoost

import xgboost as xgb
import lightgbm as lgb

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

import numpy as np
np.random.seed(42)

# Split Dataset
X, Y = load_boston(True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2,random_state=42)

#NGBoost
ngb = NGBoost(Base=default_tree_learner, Dist=Normal, Score=MLE(), natural_gradient=True,verbose=False)
ngb.fit(X_train, Y_train)

Y_train_pred_ngb = ngb.predict(X_train)
Y_test_pred_ngb = ngb.predict(X_test)

# NLL
Y_dists = ngb.pred_dist(X_test)
test_NLL = -Y_dists.logpdf(Y_test.flatten()).mean()

# XGBoost
xgb = xgb.XGBRegressor()
xgb.fit(X_train, Y_train)

Y_train_pred_xgb = xgb.predict(X_train)
Y_test_pred_xgb = xgb.predict(X_test)

# LightGBM
lgb = lgb.LGBMRegressor()
lgb.fit(X_train, Y_train)

Y_train_pred_lgb = lgb.predict(X_train)
Y_test_pred_lgb = lgb.predict(X_test)

print("-----NGBoost-----")
print('Test NLL', test_NLL)
print('MSE train : %.3f, test : %.3f' % (mean_squared_error(Y_train, Y_train_pred_ngb), mean_squared_error(Y_test, Y_test_pred_ngb)) )

print("-----XGBoost-----")
print('MSE train : %.3f, test : %.3f' % (mean_squared_error(Y_train, Y_train_pred_xgb), mean_squared_error(Y_test, Y_test_pred_xgb)) )

print("-----lightGBM-----")
print('MSE train : %.3f, test : %.3f' % (mean_squared_error(Y_train, Y_train_pred_lgb), mean_squared_error(Y_test, Y_test_pred_lgb)) )

結果

上記の実行家結果は以下、今回はNGBoostの誤差が一番小さくなっているようだ。

結果
-----NGBoost-----
Test NLL 3.7677743
MSE train : 1.740, test : 6.719
-----XGBoost-----
MSE train : 2.256, test : 7.267
-----lightGBM-----
MSE train : 2.283, test : 8.339

実行時間は以下のとおりであった。NGBoostはだいぶ遅いようなので、大きいデータセットとの比較を検討したいところ。

実行時間
-----NGBoost-----
CPU times: user 4.29 s, sys: 643 ms, total: 4.93 s
Wall time: 3.31 s
-----XGBoost-----
CPU times: user 50.9 ms, sys: 182 ms, total: 233 ms
Wall time: 264 ms
-----lightGBM-----
CPU times: user 29.8 ms, sys: 0 ns, total: 29.8 ms
Wall time: 29.2 ms

ToDo

  • 理解できてないことだらけなので論文をしっかり読む
  • 公式のvisualize実装もあるので、そちらも動かしてみる


  1. NLLのような不確実さを推定する手法には「MC dropout」や「Deep Ensembles」があるらしい 

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away