71
78

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

scikit-learn で最低限の自作推定器(Estimator)を実装する

Last updated at Posted at 2016-02-25

やりたいこと

scikit-learn はPythonのほぼデファクトの機械学習ライブラリです.scikit-learnの利点としては多くのアルゴリズムが実装されていることもそうですが,一貫した形で設計されており様々なアルゴリズムを共通したかたちで扱えることです.scikit-learnにないアルゴリズムを新たに実装したり,他のライブラリを使用するときにsciki-learnの他の推定器と同様に扱えるよう実装すれば,もともと実装されている推定器同様にクロスバリデーションで性能を評価したりグリッドサーチでパラメータを最適化したりできます.ここでは最低限の推定器の実装を示します.ここでは識別器または回帰器をターゲットとして考えます(クラスタリングとか教師なし学習とかは考えない).

べたな実装

from sklearn.base import BaseEstimator

class MyEstimator(BaseEstimator):
    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2
    
    def fit(self, x, y):
        return self 
    
    def predict(self, x):
        return [1.0]*len(x) 
    
    def score(self, x, y):
        return 1
    
    def get_params(self, deep=True):
        return {'param1': self.param1, 'param2': self.param2}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self,parameter, value)
        return self

sklearn.base.BaseEstimatorを継承して推定器クラスを定義します.メソッドの中身は適宜書き換えて下さい.
##実行例
Cross validation:


x = [[2,3],[4,5],[6,1],[2,0]] 
y = [0.0,9.4,2.1,0.9]

estimator = MyEstimator()
cross_validation.cross_val_score(estimator,x,y,cv=3)

Result:

array([ 1.,  1.,  1.])

Grid search:

gs = grid_search.GridSearchCV(estimator, {'param1': [0,10], 'param2': (1, 1e-1, 1e-2)})
gs.fit(x,y)
gs.best_estimator_, gs.best_params_, gs.best_score_

Result:

(MyEstimator(), {'param1': 0, 'param2': 1}, 1.0)

cross_validation

cross_validationを行なう為には訓練データを学習するfitメソッドとテストデータを入力しそこから推定した値と正解の値を比較してスコアを出力するscoreメソッドが必要です.

fit(self, x, y)

入力xに対して出力がyとなるように学習する関数です.

predict(self, x)

入力xに対して出力がy_predを返す関数です.cross_validationをするだけならpredictは必要ありませんが,多くの場合score内部でpredictを呼ぶでしょう.sklearnbase.ClassifierMixinscikit-learn.base.RegressionMixinを多重継承することでpredictのみを実装することで,実装済みのscore関数が使えます.

score(self, x, y)

入力xに対して出力y_predを推定し,y_predと正解yとを比較してスコア(誤差とかラベルが一致しているかなど)を返す関数です.

grid_search

grid_searchを行う為には,上で定義したように学習してスコアを計算する他にパラメータを操作する必要があります.データに依存しないパラメータを取得するメソッドget_paramsとパラメータをセットするメソッドset_paramsを実装します.

get_params(self, deep=True)

get_paramsメソッドはパラメータのkeyがattribute名.valueが値であるような辞書を返すようにします.

set_params(self, **parameters)

パラメータのセッタです.get_params同様に辞書で渡します.

Mixinについて

識別モデルの場合sklearn.base.ClassifierMixin,回帰モデルの場合sklearn.base.RegressorMixinを多重継承することで実装済のメソッドが使える.
これらを継承すると

  • sklearn.base.ClassifierMixin
  • attribute_estimator_typeclassifierまたはregressorをセット
  • 定義済みscoreメソッドが使える.scoreメソッド内でpredictメソッドを呼ぶのでpredictメソッドは実装する必要がる.
  • スコアはaccuracy
  • sklearn.base.RegressorMixin
  • attribute_estimator_typeregressorをセット
  • 定義済みscoreメソッドが使える.scoreメソッド内でpredictメソッドを呼ぶのでpredictメソッドは実装する必要がる.
  • スコアは決定係数 $R^2$

releaseにむけて

sklearn.utils.estimator_checks.check_estimatorで自前のestimatorがsklearnに適合するかどうかがチェックできます.ちなみにこの記事で示したサンプルだと入力のvalidationができていないというエラーをはきます.自分が使うぶんには問題はないでしょうけど.

まとめ

  • 自前のestimatorクラスはsklearn.base.BaseEstimatorを継承してつくる
  • cross_validationするためにはfit,scoreメソッドが必要
  • grid_searchするためには,さらにget_params,set_paramsメソッドが必要
  • ClassifierMixinまたはRegressorMixinを定義すると自分で実装したpredictを使ってスコアを計算する,scoreメソッドが使える.

参考

ここに書いたものはほぼ
sklearn.baseモジュールのAPI Reference
公式サイトの開発者向け情報
を参考にしています.

71
78
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
71
78

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?