はじめに
Ridge回帰、Lasso回帰、Elastic Net回帰についてまとめました(もっと追記していきます)。
正則化
これらの回帰を理解するにあたって、正則化というものがキーとなります。
正則化は、ペナルティを課すことでモデルの過学習を抑制する手法です。
学習データへのモデルの依存を減らした上で、最もフィッティングできるパラメータ(重み)の組み合わせをみつけることができます。
通常の線形回帰だと過学習によってバリアンスが大きくなるので、これを小さくするためにバイアスを増やすイメージです(バリアンスとバイアスのトレードオフ)。
機械学習の回帰問題に正則化を適用するとき、以下の3つのアプローチが代表的です。
各アプローチで、ペナルティ項がそれぞれ異なります。
- Lasso回帰(L1正則化)
- Ridge回帰(L2正則化)
- Elastic Net回(L1+L2正則化)
Ridge回帰
予測性能誤差をErrorとすると、Ridge回帰では、ペナルティ項に重みの平方和の合計を用います(L2ペナルティ)。
モデル学習時は、正則化パラメータλの値によってペナルティ項の強さを調節します。
cross-varidationで最適なλを決定します。
実践
Ridge回帰でcross-validationを行うには、RidgeCV
を使用します。
以下でコメントしていますが、複数のα(λ)の候補(0.1, 1.0, 10.0)でcross-validationを行って、一番精度のよいαをモデルとして返してくれます。
from sklearn.linear_model import RidgeCV
# デフォルトのcvはNone(Leave-One-Out cross-validation)
# 与えたalphas(λ)のうちNegative MAEが高い(高いほどいい)ものを返す
ridge_cv_model = RidgeCV(alphas=(0.1, 1.0, 10.0),scoring='neg_mean_absolute_error')
ridge_cv_model.fit(X_train,y_train)
# テストデータの予測
test_predictions = ridge_cv_model.predict(X_test)
# MAE, MSE, RMSEの導出
MAE = mean_absolute_error(y_test,test_predictions)
MSE = mean_squared_error(y_test,test_predictions)
RMSE = np.sqrt(MSE)
Lasso回帰
Lasso回帰では、ペナルティ項に重みの絶対値の合計を用います(L1ペナルティ)。
Lasso回帰の特徴として、重要でない説明変数の重みを0にする点があります。
必要な変数だけモデルに利用されるため、開発者がどの変数が重要かを認識しやすくなります。
ただ、予測性能はRidge回帰の方が高いため、こちらの方がよく使われます。
実践
sklearnのLassoCV
では、指定範囲内のαを探索するので、RidgeCV
のようにαの候補を与える必要はありません。
from sklearn.linear_model import LassoCV
# 1~100まで0.1刻みでαをかえる
lasso_cv_model = LassoCV(eps=0.1,n_alphas=100,cv=5)
lasso_cv_model.fit(X_train,y_train)
# テストデータの予測
test_predictions = lasso_cv_model.predict(X_test)
Elastic Net回帰
Elastic Net回帰はLassoのL1ペナルティとRidgeのL2ペナルティを両方加えたアプローチです。
Lasso回帰のように説明変数を取り除くとモデル性能を落とすことに繋がりかねない、という問題点をカバーするための方法となります。
パラメータとしては以下のものがあります。
- alpha: 正則化の強さ。0以上の値を取り0だと線形回帰と同じ。
- l1_ratio: L1正則化とL2正則化の比率。0~1の値をとる。
実践
from sklearn.linear_model import ElasticNetCV
# L1, L2ペナルティの比率をl1_ratioで指定する
elastic_model = ElasticNetCV(l1_ratio=[.1, .5, .7,.9, .95, .99, 1],tol=0.01)
elastic_model.fit(X_train,y_train)
# テストデータの予測
test_predictions = elastic_model.predict(X_test)
参考資料