LoginSignup
10
13

More than 5 years have passed since last update.

scikit-learn ちょっと高度なFunctionTransformerの使い方

Last updated at Posted at 2019-04-05

はじめに

以前に書いた記事sckit-learnのPiplineを使って、カスタム前処理をモデルの中に組み込むの続編。
前の記事では前処理関数は完全に固定したものでしたが、この関数にパラメータを含めて、Pipeline作成後にパラメータを変更できる作りにすることを目標にします。

実装サンプル

以下に実装コードを記載しながら解説をします。

データ準備

アイリスデータセットを題材にします。

from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target

前処理関数

前処理関数 func_customの定義をします。
説明を簡単にするため、入力値Xを一定値で割り算するという関数とします。
割る値はarg1という後で変更可能なパラメータにします。

def func_custom(X, arg1):
    return(X/arg1)

カスタムフィルターの定義

FunctionTransformerクラスと先ほど定義したfunc_customを使ってカスタムフィルターを定義します。

from sklearn.preprocessing import FunctionTransformer
ft = FunctionTransformer(func_custom, validate=True, kw_args={'arg1': 2})

インスタンス生成時のパラメータ指定kw_args={'arg1': 2}がポイントです。
パラメータのキー値'arg1'は、関数宣言での引数と名前をそろえます。
こうすることで関数宣言のパラメータ値を後で変更することが可能になります。

パイプラインの定義

パイプラインの定義に関しては、普通の方法で行います。

from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
pipe = Pipeline([('ft', ft),('svm', SVC(gamma='auto'))])

学習・予測

いったん作ったPipelineで学習と予測を行ってみます。
ここも普通の方法ですので、解説は省略します。

pipe.fit(X, y)
pipe.predict(X)

Pinelineから前処理を抜き出す

named_steps属性を使って、今作ったPipelineから前処理関数を抜き出してみます。

pipe.named_steps['ft']

次のような結果がかえってくるはずです。

FunctionTransformer(accept_sparse=False, check_inverse=True,
          func=<function func_custom at 0x10d904488>, inv_kw_args=None,
          inverse_func=None, kw_args={'arg1': 2}, pass_y='deprecated',
          validate=True)

前処理の挙動の確認

FunctionTransformerは、transform関数でフィルターの挙動を確認できます。現時点では以下のように入力データを2で割った答えが返ってきます。

X1 = X[:5,:]
print(X1)
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]]
print(pipe.named_steps['ft'].transform(X1))
[[2.55 1.75 0.7  0.1 ]
 [2.45 1.5  0.7  0.1 ]
 [2.35 1.6  0.65 0.1 ]
 [2.3  1.55 0.75 0.1 ]
 [2.5  1.8  0.7  0.1 ]]

フィルター関数内のパラメータ値変更

フィルター関数内のパラメータ値はset_params(kw_args={'arg1': 4})のようにset_params関数呼び出して値を変更できます。
変更のための実行コードと、変更後のフィルター出力結果は以下の通りです。
想定通り、入力値を4で割った結果がかえってきていることがわかると思います。

パラメータ値変更

pipe.named_steps['ft'].set_params(kw_args={'arg1': 4})
FunctionTransformer(accept_sparse=False, check_inverse=True,
          func=<function func_custom at 0x10d904488>, inv_kw_args=None,
          inverse_func=None, kw_args={'arg1': 4}, pass_y='deprecated',
          validate=True)

変更後のフィルターの挙動

print(pipe.named_steps['ft'].transform(X1))
[[1.275 0.875 0.35  0.05 ]
 [1.225 0.75  0.35  0.05 ]
 [1.175 0.8   0.325 0.05 ]
 [1.15  0.775 0.375 0.05 ]
 [1.25  0.9   0.35  0.05 ]]
10
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
13