Help us understand the problem. What is going on with this article?

sklearn準拠モデルの作り方

More than 1 year has passed since last update.

この記事は リクルートライフスタイル Advent Calendar 2017 1日目の記事です。

こんにちは!データエンジニアリンググループでエンジニアをやっている@roronyaです。CETというプロジェクトで施策で使う機械学習のモデルを作ったり基盤を作ったりしています。

今日は自前の機械学習モデルを作る際にscikit-learn準拠のモデルを書くメリットと実際のやり方をまとめました。

sklearn準拠モデルとは?

自作の機械学習モデルでも、sklearnのライブラリに実装されている各手法と同じように扱えるモデルのことです。これによりfit()predict()といった、sklearnでお馴染みの関数を利用できるようになります。

なんでsklearn準拠にするの?

自作の機械学習モデルもsklearnの各手法と同じように扱えると、便利なことがたくさんあるからです。

  • sklearn.model_selectionのGridSearchやCrossValidationなどを使えるようになる。
    • 自分で実装しなくてもOK!
    • 多くの場合これが最大のモチベ
  • 自前のモデルをsklearnのインタフェースを合わせれば、既に使っているsklearnのモデルを簡単に入れ替えられる
  • たくさん使われているライブラリのクラスと同じように使えるので学習コストが低い。
    • 使ってもらいやすい!
  • まだ本家で実装されていない手法なら、sklearnに直接コントリビュートもできるかも!!

それでは具体的にどうやって実装するのか紹介します。

1. クラス設計

やることは3つです。

  1. BaseEstimatorを継承
  2. 回帰ならRegressorMixin、分類ならClassifierMixinを継承
  3. fit()predict()を実装

例1: 自作LinearRegression

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin


# BaseEstimatorとRegressorMixinを継承する
class LinearRegression(BaseEstimator, RegressorMixin):
    # fit()を実装
    def fit(self, X, y):
        self.coef_ = np.linalg.solve(
            np.dot(X.T, X), np.dot(X.T, y)
        )
        # fit は self を返す
        return self

    # predict()を実装
    def predict(self, X):
        return np.dot(X, self.coef_)

1-1. 回帰と分類以外の手法は?

sklearn.baseのAPI Referenceを見て、対応したMixinを選んで継承します。
RegressorMixinとClassifierMixin以外にも、ClusterMixinやTransformerMixinなど他の手法のMixinも用意されています。

ref: sklearn.baseのAPI Reference

1-2. 回帰にも分類にも使える手法はどうすればいいの?

それぞれ回帰用のクラス、分類用のクラスを実装します。例2のようにRegressorMixinとClassifierMixinの両方を継承しても回帰と分類どちらにも使えるモデルにはなりません

例2: (ダメな例)RegressorMixinとClassifierMixinの両方を継承する

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin

# RegressorMixinとClassifierMixinの両方を同時に継承する
class LinearModel(BaseEstimator, RegressorMixin, ClassifierMixin):
    def fit(self, X, y):
        self.coef_ = np.linalg.solve(
            np.dot(X.T, X), np.dot(X.T, y)
        )
        return self

    def predict(self, X):
        return np.dot(X, self.coef_)

なぜどちらも使えるモデルにならないかについて説明します。
各Mixinクラスには

  • それぞれの手法に適したscore()メソッド
  • 自身がどういう手法かを示す_estimator_typeプロパティ

が実装されています。

ref: sklearnのRegressorMixinのソースコード

そのため、例2の場合、RegressorMixinClassifierMixinのそれぞれの実装が競合してしまいます。
Pythonの多重継承の振る舞い的には左側が優先なので、この場合はRegressorMixinが優先されていて、ClassifierMixinを継承した意味は無くなっています。

そこで、それぞれのMixinを継承したクラスが必要になります。

例3: それぞれのMixinを継承する

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin


# fit()だけ実装した抽象クラス
class LinearModel(BaseEstimator):
    def fit(self, X, y):
        self.coef_ = np.linalg.solve(
            np.dot(X.T, X), np.dot(X.T, y)
        )
        return self

    def predict(self, X):
        return np.dot(X, self.coef_)


# RegressorMixinを継承
class LinearRegressor(LinearModel, RegressorMixin):
    pass


# Classifierを継承
class LinearClassifier(LinearModel, ClassifierMixin):
    pass

例3のようにfit()だけを実装した抽象クラスを作ってから、RegressorMixinとClassifierMixinをそれぞれ継承したクラスを作ります。

1-3. クラス設計がsklearn準拠になっているかどうか調べるには?

sklearn.utils.estimator_checkscheck_estimator()という関数があり、これを使うとクラス設計がsklearnのルールに従っているかチェックすることが出来ます。

例4: fit()をコメントアウトした自前LinearRegressionにcheck_estimator()をすると怒られる

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.estimator_checks import check_estimator


class LinearRegression(BaseEstimator, RegressorMixin):
#    fit()をコメントアウト
#    def fit(self, X, y):
#        self.coef_ = np.linalg.solve(
#            np.dot(X.T, X), np.dot(X.T, y)
#        )
#        return self

    def predict(self, X):
        return np.dot(X, self.coef_)

if __name__ == '__main__':
    check_estimator(LinearRegression)

実行結果:

/Users/roronya/.pyenv/versions/3.6.3/bin/python /Users/roronya/Develop/2017adventcalendar/ex4.py
Traceback (most recent call last):
  File "/Users/roronya/Develop/2017adventcalendar/ex4.py", line 18, in <module>
    check_estimator(LinearRegression)
  File "/Users/roronya/.pyenv/versions/3.6.3/lib/python3.6/site-packages/sklearn/utils/estimator_checks.py", line 265, in check_estimator
    check(name, estimator)
  File "/Users/roronya/.pyenv/versions/3.6.3/lib/python3.6/site-packages/sklearn/utils/testing.py", line 291, in wrapper
    return fn(*args, **kwargs)
  File "/Users/roronya/.pyenv/versions/3.6.3/lib/python3.6/site-packages/sklearn/utils/estimator_checks.py", line 841, in check_estimators_dtypes
    estimator.fit(X_train, y)
AttributeError: 'LinearRegression' object has no attribute 'fit'

Process finished with exit code 1

AttributeError: 'LinearRegression' object has no attribute 'fit'というエラーメッセージが出ていて、fit()が無いと教えてくれます。

2. 命名規則とかあるの?

学習した結果など、fit() した後に値が確定するような変数には、特別なルールがあります。

つまりfit()した後に値が確定する変数はコンストラクタでは束縛せずfit()の中変数名にサフィックスとして_を付けて宣言します。

その他の命名規則は明確に決まっていませんが、慣習はあります。sklearnのドキュメントを見て慣習に従うのが良いと思います。

例えば線形モデルなら

  • 係数: coef_
  • バイアス項: intercept_

が使われることが多いです。私は実装する前に似た手法がsklearnでどのように実装するか確認しています。

3. fit()する前のpredict()の挙動はどうすればいいの?

sklearn.utils.validation.check_is_fitted()というfit()しているか否かを確かめる関数を使います。この関数はfit()されていなければsklearn.exceptions.NotFittedErrorを返します。

ref: check_is_fitted()のAPI Reference

例5: check_is_fitted()を使う例

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_is_fitted


class LinearRegression(BaseEstimator, RegressorMixin):
    def fit(self, X, y):
        self.coef_ = np.linalg.solve(
            np.dot(X.T, X), np.dot(X.T, y)
        )
        return self

    def predict(self, X):
        check_is_fitted(self, 'coef_')
        return np.dot(X, self.coef_)

if __name__ == '__main__':
    X = np.array([
        [1,1,1,1],
        [2,2,2,2]
    ])
    model = LinearRegression()
    model.predict(X)  # fit() する前に predict() してみる

check_is_fitted()にはselffit()したあとに値が確定する変数名を渡します。

実行結果:

/Users/roronya/.pyenv/versions/3.6.3/bin/python /Users/roronya/Develop/2017adventcalendar/ex5.py
Traceback (most recent call last):
  File "/Users/roronya/Develop/2017adventcalendar/ex5.py", line 23, in <module>
    model.predict(X)
  File "/Users/roronya/Develop/2017adventcalendar/ex5.py", line 14, in predict
    check_is_fitted(self, 'coef_')
  File "/Users/roronya/.pyenv/versions/3.6.3/lib/python3.6/site-packages/sklearn/utils/validation.py", line 768, in check_is_fitted
    raise NotFittedError(msg % {'name': type(estimator).__name__})
sklearn.exceptions.NotFittedError: This LinearRegression instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.

Process finished with exit code 1

sklearn.exceptions.NotFittedError: ThisLinearRegression instance is not fitted yet.と教えてくれます。

check_is_fitted()を使わない場合、fit()する前にpredict()するとcoef_は宣言されてないため、組み込み例外のAttributeErrorが吐かれますが、check_is_fitted()を使えば、発生する状況が絞られた例外に出来るので、エラーハンドリングしやすいはずです。

まとめ

  • 自前のモデルをsklearnのインタフェースを合わせると
    • sklearn.model_selectionのGridSearchやCrossValidationなどを使えるようになる
    • 既に使っているsklearnのモデルを簡単に入れ替えられる
  • BaseEstimatorと手法に適したMixinクラスを継承する
  • sklearn.utils.estimator_checks.check_estimator()でsklearn準拠になっているか確認する
  • fit()したあとに値が確定する変数は、変数名のsuffixに_を付けてfit()で宣言する。コンストラクタで初期化しない
  • 変数名は慣習に従う
  • fit()した後に呼ばれることが前提のメソッドはsklearn.utils.validation.check_is_fitted()を使う

マサカリ募集中です!コメントで指摘してください!

参考

scikit-learn 0.19.1 documentation
scikit-learn github repository

roronya
recruitlifestyle
飲食・美容・旅行領域の情報サイトや『Airレジ』などの業務支援サービスなど、日常消費領域に関わるサービスの提供するリクルートグループの中核企業
http://www.recruit-lifestyle.co.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした