LoginSignup
26
16

More than 3 years have passed since last update.

LogisticRegressionのsolverはデフォルト値がlbfgsに変わったので注意

Last updated at Posted at 2019-12-15

はじめに

scikit-learnライブラリのロジスティック回帰(LogisticRegression)は、バージョン0.22においてsolverのデフォルト値がliblinearからlbfgsに変更されました。

この変更により、同じコードでも過去とは実行結果が異なる、あるいはエラーが出力されることが想定されるのでメモ。

たとえばこんな事象

L1正規化を行うような、下記のコードでエラーが発生します。

lr_l1 = LogisticRegression(C=C, penalty='l1').fit(X_train, y_train)

以下はエラー内容。

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/devp/linear_model.py in 
      1 for C, marker in zip([0.001, 1, 100], ['o', '^', 'v']):
----> 2     lr_l1 = LogisticRegression(C=C, penalty='l1').fit(X_train, y_train)
      3     print('Training accuracy of l1 logreg with C={:.f3}: {:.2f}'.format(C, lr_l1.score(X_train, y_train)))
      4     print('Test accuracy of l1 logreg with C={:.f3}: {:.2f}'.format(C, lr_l1.score(X_test, y_test)))
      5     plt.plot(lr_l1.coef_.T, marker, label='C={:.3f}'.format(C))

/usr/local/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py in fit(self, X, y, sample_weight)
   1484         The SAGA solver supports both float64 and float32 bit arrays.
   1485         """
-> 1486         solver = _check_solver(self.solver, self.penalty, self.dual)
   1487 
   1488         if not isinstance(self.C, numbers.Number) or self.C < 0:

/usr/local/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py in _check_solver(solver, penalty, dual)
    443     if solver not in ['liblinear', 'saga'] and penalty not in ('l2', 'none'):
    444         raise ValueError("Solver %s supports only 'l2' or 'none' penalties, "
--> 445                          "got %s penalty." % (solver, penalty))
    446     if solver != 'liblinear' and dual:
    447         raise ValueError("Solver %s supports only "

ValueError: Solver lbfgs supports only 'l2' or 'none' penalties, got l1 penalty.

solverlbfgsl2noneしかサポートしてないよ、っていうエラーです。

エラー解消のためにはsolverを記述

下記のように記述します。

lr_l1 = LogisticRegression(C=C, penalty='l1', solver='liblinear').fit(X_train, y_train)

原因についてちょっとだけ詳細

ここに書いてある通り、LogisticRegressionsolverのデフォルト値がアップデートにより変更されたせいです。

Changed in version 0.22: The default solver changed from ‘liblinear’ to ‘lbfgs’ in 0.22.

これにより本来solverのデフォルト値をliblinearだと想定して、L1正規化を行いつつsolverの設定を省略しているソースコードが軒並み影響を受けてエラーを吐くと思われます。

その他にも、solverを省略したせいで「デフォルト値がliblinearだと思って実行したら、lbfgsで実行されちゃった」ケースなんかは、エラーにはならなくても出力結果が過去とは異なる、という事象になります。

ベタな結論

こういう「ライブラリ側が原因で動作が変わる、あるいは動作しなくなる」ケースって界隈あるあるなので、よく利用するライブラリのバージョンアップは欠かさずチェックしておこうね、という話。

26
16
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
26
16