LoginSignup
8
8

More than 3 years have passed since last update.

sklearn準拠モデルについて

Last updated at Posted at 2020-04-25

これは何か

sklearn.pipelineに自作関数を組み込みたいと思い、色々調べていく中でsklearnに準拠したモデルとは何かという疑問が浮かんだ。以下にて公式ドキュメント1を整理した内容を記載していく。sklearn.pipelineやsklearn.model_selection.GridSearchCVなどを利用しようと考えている人の一助になればと思う。

1. オブジェクトの構成

sklearn準拠モデルのオブジェクトは下記のような構成の必要がある。順に見ていこう。
1. fit, set_paramsメソッドを持っている。
2. _estimator_type属性を持っている

1.1. fit, set_paramsメソッド

fitメソッドは、訓練データの学習などに使用するメソッドである。sklearnでは学習を行うメソッドの名前はfitに統一されている。これによって、pipelineやGridSearchCV側でも、sklearn準拠モデルのオブジェクトに対してfitメソッドを呼ぶことでモデルの学習が行えるようになる。
set_paramsメソッドも同様の考え方である。このメソッドは、GridSearchCVなどのパラメータチューニングを行う時に呼ばれる。

1.2. _estimator_type属性

sklearnが提供している機能の一部(例えば、GridSearchCVやcross_val_scoreなど)はモデルの種類によって挙動が変わるものがある。例えば、クラス分類器の学習を行う際はデータを層化サンプリングするなどである。以下に例を示す。
- 分類: classifier
- 回帰: regressor
- クラスタリング: clusterer

_estimator_type属性については、sklearn.baseにあるMixinクラス(例えば、ClassifierMixinクラス)を継承することで自動的に設定されるようになっている。また、sklearnでは、sklearn準拠モデルを作成する際は、sklearn.base.BaseEstimatorとそのモデルに適したMixinクラスの2つを継承することを推奨している。
- BaseEstimator: set_paramsメソッドなど、0から実装するとボイラープレートコードになってしまうようなメソッドなどが記述されている
- Mixin: 各_estimator_typeにて使われるであろうメソッドが記述されている。
* コードはgithub上で公開されているので、読んでみると更なる理解に繋がると思われる2

1.3. 実装例

1.1.及び1.2.の内容をもとにコードを書いてみる。なお、ここではset_paramsはBaseEstimatorクラスの中で用意されているため、記述しない。


from sklearn.base import BaseEstimator, ClassifierMixin

class Classifier(BaseEstimator, ClassifierMixin):

    def __init__(self):
        pass

    def fit(self, X, y):
        pass

2. __init__

インスタンス生成時に注意すべきはパラメータの受け取りである。下記に受け取り方について列挙する。
- ハイパーパラメータなどの学習に関わるパラメータは全て__init__メソッドで渡すこと(fitメソッドで渡すのはデータのみにする)。
- 受け取るパラメータは全て、デフォルト値を持つようにする。
- パラメータを持つ属性は全てパラメータと同じにすること。
- パラメータを受け取ったら、値のバリデーションなどは行わないこと(set_paramsでもパラメータの上書きを行うため、インスタンス生成時のバリデーションは避けるべきである)

2.1. 実装例


from sklearn.base import BaseEstimator, ClassifierMixin

class Classifier(BaseEstimator, ClassifierMixin):

    def __init__(self, params1=0, params2=None):
        self.params1 = params1
        self.params2 = params2

    def fit(self, X, y):
        pass

3. fit

fitにて注意すべき事項を下記に挙げる。
- データを引数として受け取り、パラメータは受け取らない
- データの学習は行っても、データ自体は保持しない
- y(正解データ)が不要であっても、第二引数に y=None という形で受け取る(pipelineなどで、教師なし学習による特徴量生成→教師あり学習を行えるようにしておくため)
- 戻り値はself
- データから推定された属性は末尾に下線を付ける(例えば、 coef_ )

3.1. 実装例


from sklearn.base import BaseEstimator, ClassifierMixin

class Classifier(BaseEstimator, ClassifierMixin):

    def __init__(self, params1=0, params2=None):
        self.params1 = params1
        self.params2 = params2

    def fit(self, X, y=None):
        print('ここにデータの学習を行う処理を記載する')
        return self

4. その他

上記以外で注意すべき事項を列挙する。
- X.shape[0]y.shape[0]は同じ(sklearn.utils.validation.check_X_yを使って確認する)。
- set_paramsは引数に辞書を受け取り、戻り値はselfである。
- get_paramsは引数を取らない。
- 分類器の場合、classes_属性にラベルのリストを持つ(sklearn.utils.multiclass.unique_labelsを使う)。

4.1. 実装例


from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y
from sklearn.utils.multiclass import unique_labels

class Classifier(BaseEstimator, ClassifierMixin):

    def __init__(self, params1=0, params2=None):
        self.params1 = params1
        self.params2 = params2

    def fit(self, X, y=None):
        X, y = check_X_y(X, y)
        self.classes_ = unique_labels(y)
        print('ここにデータの学習を行う処理を記載する')
        return self

    def get_params(self, deep=True):
        return {"params1": self.params1, "params1": self.params1}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

5. コーディング規約

基本的にPEP8に準拠しているが、それに加えてsklearn用のコーディング規約があるので記載しておく。
なお、これらはsklearnへのコントリビュートを考えていない場合は不要になる。
- クラス名以外は単語ごとにアンダースコアで区切る(例えば、 n_samples)
- 1行に複数のステートメントは記述しない(if文やfor文は改行する)
- sklearn内のモジュールのimportは相対パスで行う(テストコードでは、絶対パスで記述する)
- import * は使用しない
- docstringはnumpyスタイル3

6. sklearn準拠モデルになっているかの確認

sklearn準拠モデルになっているかを確認してくれるcheck_estimatorメソッドをsklearnは提供してくれている。_estimator_type属性によって変わるが、いくつかのテストを実施して準拠しているか確認してくれるようだ。fitメソッドを実装しなかった場合、AttributeError: 'Classifier' object has no attribute 'fit'というエラーが出る。また、github上でsklearn準拠モデルのテンプレートを用意してくれているので、そちらを参考に実装を行い、完成したらcheck_estimatorを実行して、確認するのが良いと思う。以下に実行例を示す。

from sklearn.utils.estimator_checks import check_estimator

# これまで上記で実装してきたコードでは、classifierとして必要なpredictメソッドを定義していないため、エラーが起きる。
# ClassifierMixinを継承せずに、TemplateなEstimatorとして実装するとエラーは発生しない。
class Estimator(BaseEstimator):

    def __init__(self, params1=0, params2=None):
        self.params1 = params1
        self.params2 = params2

    def fit(self, X, y=None):
        X, y = check_X_y(X, y)
        self.classes_ = unique_labels(y)
        self.is_fitted_ = True
        return self

    def get_params(self, deep=True):
        return {"params1": self.params1, "params1": self.params1}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self


check_estimator(Estimator)


  1. Developing scikit-learn estimators 

  2. base.py 

  3. numpydoc docstring guide、またこちらの記事が分かりやすくまとめられているのでおすすめしたい。 

8
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
8