はじめに
以前に書いた記事sckit-learnのPiplineを使って、カスタム前処理をモデルの中に組み込むの続編。
前の記事では前処理関数は完全に固定したものでしたが、この関数にパラメータを含めて、Pipeline作成後にパラメータを変更できる作りにすることを目標にします。
実装サンプル
以下に実装コードを記載しながら解説をします。
データ準備
アイリスデータセットを題材にします。
from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target
前処理関数
前処理関数 func_customの定義をします。
説明を簡単にするため、入力値Xを一定値で割り算するという関数とします。
割る値はarg1
という後で変更可能なパラメータにします。
def func_custom(X, arg1):
return(X/arg1)
カスタムフィルターの定義
FunctionTransformerクラスと先ほど定義したfunc_customを使ってカスタムフィルターを定義します。
from sklearn.preprocessing import FunctionTransformer
ft = FunctionTransformer(func_custom, validate=True, kw_args={'arg1': 2})
インスタンス生成時のパラメータ指定kw_args={'arg1': 2}
がポイントです。
パラメータのキー値'arg1'
は、関数宣言での引数と名前をそろえます。
こうすることで関数宣言のパラメータ値を後で変更することが可能になります。
パイプラインの定義
パイプラインの定義に関しては、普通の方法で行います。
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
pipe = Pipeline([('ft', ft),('svm', SVC(gamma='auto'))])
学習・予測
いったん作ったPipelineで学習と予測を行ってみます。
ここも普通の方法ですので、解説は省略します。
pipe.fit(X, y)
pipe.predict(X)
Pinelineから前処理を抜き出す
named_steps属性を使って、今作ったPipelineから前処理関数を抜き出してみます。
pipe.named_steps['ft']
次のような結果がかえってくるはずです。
FunctionTransformer(accept_sparse=False, check_inverse=True,
func=<function func_custom at 0x10d904488>, inv_kw_args=None,
inverse_func=None, kw_args={'arg1': 2}, pass_y='deprecated',
validate=True)
前処理の挙動の確認
FunctionTransformerは、transform関数でフィルターの挙動を確認できます。現時点では以下のように入力データを2で割った答えが返ってきます。
X1 = X[:5,:]
print(X1)
[[5.1 3.5 1.4 0.2]
[4.9 3. 1.4 0.2]
[4.7 3.2 1.3 0.2]
[4.6 3.1 1.5 0.2]
[5. 3.6 1.4 0.2]]
print(pipe.named_steps['ft'].transform(X1))
[[2.55 1.75 0.7 0.1 ]
[2.45 1.5 0.7 0.1 ]
[2.35 1.6 0.65 0.1 ]
[2.3 1.55 0.75 0.1 ]
[2.5 1.8 0.7 0.1 ]]
フィルター関数内のパラメータ値変更
フィルター関数内のパラメータ値はset_params(kw_args={'arg1': 4})
のようにset_params関数呼び出して値を変更できます。
変更のための実行コードと、変更後のフィルター出力結果は以下の通りです。
想定通り、入力値を4で割った結果がかえってきていることがわかると思います。
パラメータ値変更
pipe.named_steps['ft'].set_params(kw_args={'arg1': 4})
FunctionTransformer(accept_sparse=False, check_inverse=True,
func=<function func_custom at 0x10d904488>, inv_kw_args=None,
inverse_func=None, kw_args={'arg1': 4}, pass_y='deprecated',
validate=True)
変更後のフィルターの挙動
print(pipe.named_steps['ft'].transform(X1))
[[1.275 0.875 0.35 0.05 ]
[1.225 0.75 0.35 0.05 ]
[1.175 0.8 0.325 0.05 ]
[1.15 0.775 0.375 0.05 ]
[1.25 0.9 0.35 0.05 ]]