52
40

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

sklearnのpipelineの中身を理解する

Last updated at Posted at 2020-04-26

これは何か

sklearnのpipelineを時々使っていたのだが、pipeline.fit_transform(X, y)とした時にpipelineの中ではどのような処理をしているのか、気になったので公式ドキュメント1やソースコード2を読んで、整理してみることにした。

なお、自分の抱いていた問題意識は下記コードのコメントに記載している。
人によっては「当たり前でしょ!!」って思うかもしれないが、どうしても気になってしまったので調べてみた。

# 問題意識1: 変換器ではfit_transform, 推定器ではfitが呼ばれている??
# 問題意識2: このタイミングで変換器もしくは推定器にパラメータを渡したい時はどうすればいい??
# 問題意識3: 自作の推定器・変換器を入れたい場合に満たすべき要件は??
pipe.fit(X, y)

# 問題意識4: 変換器ではfit_transform, 推定器ではpredictが呼ばれている??
pipe.predict(X)

1. pipelineとは何か

機械学習プロジェクトにおいて分類や回帰などを行う推定器(estimator)を利用する際、変換器(transformer)も一緒に使われることが多い。データの変換から学習・推定までの処理を一つの推定器としてまとめることができる機能としてpipelineが提供されている。

1.1. pipelineの使用例

pipelineは(key, value)のタプルを要素に持つリストで構成される。keyには推定器・変換器の名前、valueには推定器・変換器のオブジェクトをstepsとしてpipelineに渡す。使用例を下記に示す。

from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn import datasets

# サンプルデータの用意
iris = datasets.load_iris()
X, y = iris.data, iris.target

# pipelineの作成
estimators = [('reduce_dim', PCA()), ('clf', SVC())]
pipe = Pipeline(steps=estimators)

# 学習
pipe.fit(X, y)

# 予測
pipe.predict(X)

2. 推定器・変換器の要件

pipelineに自作の推定器・変換器を入れたい場合がある。その際に満たしておくべき要件を記載する。
なお、要件はpipelineのstepsの最後(final_estimator)か、それ以外(not_final_estimator)で変わる。

  • ドキュメント内では、not_final_estimatorをtransformと呼んでいるが、transformメソッドと被るため、本記事ではその呼び方を変更している。
  • final_estimator: fitメソッドを持っていること
  • not_final_estimator: fit及びtransformメソッドを持っている、もしくはfit_transformメソッドを持っていること

pipelineで呼び出すメソッドによって、要件は増えるが最低限満たしておくべき要件は以上である。

3. pipeline内での処理

1.1.のコードにあるように、pipeline.fitやpipeline.predictを呼び出した時のpipeline内での処理を確認してみた3
pipelineで頻繁に使用されるであろうメソッドに絞って、下記に整理している。
左から、pipelineのメソッド、それに渡すパラメータ、not_final_estimatorで呼ばれるメソッド、final_estimatorで呼ばれるメソッドである。

pipeline パラメータ not_final_estimator final_estimator
fit X, y=None, **fit_params fit_transform fit
fit_transform X, y=None, **fit_params fit_transform fit_transform
predict X, **predict_params transform predict
fit_predict X, y=None, **fit_params fit_transform fit_predict
score X, y=None, sample_weight=None transform score

なお、以下に注意すべき点を列挙しておく。

  • fit_transformメソッドを定義していない場合、fitメソッドとtransformメソッドが順番に実行される。
  • fit_transformメソッドと違って、fit_predictメソッドは定義していない場合、エラーとなる。
  • **fit_paramsは、対象stepの名前(タプルのkey部分)__パラメータ名で渡すことができる。
    • 例: pipeline.fit(X, y, key1__param1=True)
  • **fit_paramsと違って、**predict_paramsはfinal_estimatorで呼ばれるpredictメソッドにのみパラメータを渡せる。記述方法はpredictメソッドにあるパラメータ名をそのまま指定するだけでよい。
    • 例: pipeline.predict(X, param1=True)

余談であるが、sklearn準拠のモデルでは、fitメソッド実行時にパラメータを受け付けるような設計にすべきではないとしている。
そのため可能な限り、**fit_paramsを利用してパラメータを渡すのは避けたほうが良いと思われる。
sklearn準拠のモデルについてはこちらで詳しく触れている。

  1. ユーザーガイド

  2. ソースコード

  3. pipelineのドキュメント

52
40
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
52
40

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?