Help us understand the problem. What is going on with this article?

sklearnのpipelineの中身を理解する

これは何か

sklearnのpipelineを時々使っていたのだが、pipeline.fit_transform(X, y)とした時にpipelineの中ではどのような処理をしているのか、気になったので公式ドキュメント1やソースコード2を読んで、整理してみることにした。

なお、自分の抱いていた問題意識は下記コードのコメントに記載している。
人によっては「当たり前でしょ!!」って思うかもしれないが、どうしても気になってしまったので調べてみた。

# 問題意識1: 変換器ではfit_transform, 推定器ではfitが呼ばれている??
# 問題意識2: このタイミングで変換器もしくは推定器にパラメータを渡したい時はどうすればいい??
# 問題意識3: 自作の推定器・変換器を入れたい場合に満たすべき要件は??
pipe.fit(X, y)

# 問題意識4: 変換器ではfit_transform, 推定器ではpredictが呼ばれている??
pipe.predict(X)

1. pipelineとは何か

機械学習プロジェクトにおいて分類や回帰などを行う推定器(estimator)を利用する際、変換器(transformer)も一緒に使われることが多い。データの変換から学習・推定までの処理を一つの推定器としてまとめることができる機能としてpipelineが提供されている。

1.1. pipelineの使用例

pipelineは(key, value)のタプルを要素に持つリストで構成される。keyには推定器・変換器の名前、valueには推定器・変換器のオブジェクトをstepsとしてpipelineに渡す。使用例を下記に示す。

from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn import datasets

# サンプルデータの用意
iris = datasets.load_iris()
X, y = iris.data, iris.target

# pipelineの作成
estimators = [('reduce_dim', PCA()), ('clf', SVC())]
pipe = Pipeline(steps=estimators)

# 学習
pipe.fit(X, y)

# 予測
pipe.predict(X)

2. 推定器・変換器の要件

pipelineに自作の推定器・変換器を入れたい場合がある。その際に満たしておくべき要件を記載する。
なお、要件はpipelineのstepsの最後(final_estimator)か、それ以外(not_final_estimator)で変わる。
* ドキュメント内では、not_final_estimatorをtransformと呼んでいるが、transformメソッドと被るため、本記事ではその呼び方を変更している。

  • final_estimator: fitメソッドを持っていること
  • not_final_estimator: fit及びtransformメソッドを持っている、もしくはfit_transformメソッドを持っていること

pipelineで呼び出すメソッドによって、要件は増えるが最低限満たしておくべき要件は以上である。

3. pipeline内での処理

1.1.のコードにあるように、pipeline.fitやpipeline.predictを呼び出した時のpipeline内での処理を確認してみた3
pipelineで頻繁に使用されるであろうメソッドに絞って、下記に整理している。
左から、pipelineのメソッド、それに渡すパラメータ、not_final_estimatorで呼ばれるメソッド、final_estimatorで呼ばれるメソッドである。

pipeline パラメータ not_final_estimator final_estimator
fit X, y=None, **fit_params fit_transform fit
fit_transform X, y=None, **fit_params fit_transform fit_transform
predict X, **predict_params transform predict
fit_predict X, y=None, **fit_params fit_transform fit_predict
score X, y=None, sample_weight=None transform score

なお、以下に注意すべき点を列挙しておく。

  • fit_transformメソッドを定義していない場合、fitメソッドとtransformメソッドが順番に実行される。
  • fit_transformメソッドと違って、fit_predictメソッドは定義していない場合、エラーとなる。
  • **fit_paramsは、対象stepの名前(タプルのkey部分)__パラメータ名で渡すことができる。
    • 例: pipeline.fit(X, y, key1__param1=True)
  • **fit_paramsと違って、**predict_paramsはfinal_estimatorで呼ばれるpredictメソッドにのみパラメータを渡せる。記述方法はpredictメソッドにあるパラメータ名をそのまま指定するだけでよい。
    • 例: pipeline.predict(X, param1=True)

余談であるが、sklearn準拠のモデルでは、fitメソッド実行時にパラメータを受け付けるような設計にすべきではないとしている。
そのため可能な限り、**fit_paramsを利用してパラメータを渡すのは避けたほうが良いと思われる。
sklearn準拠のモデルについてはこちらで詳しく触れている。

shota-imazeki
まだまだ未熟なデータサイエンティストですが日々精進しております。よろしくお願いします。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした