LoginSignup
1
1

More than 1 year has passed since last update.

データフレーム出力可能なsklearnライクな変数選択器を実装してみる

Posted at

やりたいこと

自作のfeature selectorをパイプラインに組み込み、前処理から学習を一括管理したい。
その際、削除された変数を確認したいので、データフレームで出力できるようにする。

最低限の実装

今回は練習用に、同じ値の要素が閾値以上の列を削除するfeature selectorを実装してみる。
似たような機能を持つsklearnのVarianceThreshold1のソースコード2を参考に作成した。
データフレームでの出力方法以外は、すでに解説記事3があるのでそちらを参照して頂きたい。

まずは最低限の実装を示す。
こちらの実装では、fitメソッドの引数Xはデータフレームでなければならない。

from scipy.stats import mode
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin


class MySelector(SelectorMixin, BaseEstimator):
    def __init__(self, threshold=1):
        self.threshold = threshold

    def fit(self, X, y=None):
        self.feature_names_in_ = X.columns
        _, counts = mode(X, axis=0)
        self.freq_ratios = (counts / X.shape[0]).flatten()
        return self

    def _get_support_mask(self):
        return self.freq_ratios < self.threshold

_get_support_maskメソッドを実装すれば、SelectorMixinからtransformメソッドとinverse_transformメソッドが実装される。
またBaseEstimatorを継承しているので、インスタンス変数feature_names_in_をセットすることで、次のようにset_outputメソッドにtransform="pandas"を渡すことで出力をデータフレームにできる。

selector = MySelector(threshold=1).set_output(transform="pandas")
selector.fit(X)
X_transformed = selector.transform(X)

ひとまず目的は達成されたが、fitの引数としてデータフレームしか受け取れない問題が残っている。
(他にも色々あると思われるが・・・)
これは型チェック用の関数を導入することで解決できる。

型チェックを加えた実装

fitメソッドの中で_validate_dataメソッドを呼び出せば良い。内部ではcheck_array関数によって入力データが2次元配列であることをチェックしつつ、_check_feature_namesメソッドが呼び出され、インスタンスのfeature_names_in_属性に値をセットしてくれる。
また、check_arrayに渡すキーワード引数を_validate_dataメソッドに渡しておくことで、入力データの様々な型チェックが可能4。今回は出力のデータ型を維持するようにdtype=Noneを渡しておく。

ついでに、_get_support_maskメソッドの中でcheck_is_fittedメソッドを呼び出すことで、fitメソッドの実行前にtransformメソッドを実行しようとしたときに、 "fitメソッドが実行されていないよ" というエラーメッセージを返してくれる。

以上を加えた実装はこちら。

from scipy.stats import mode
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin
from sklearn.utils.validation import check_is_fitted


class MySelector(SelectorMixin, BaseEstimator):
    def __init__(self, threshold=1):
        self.threshold = threshold

    def fit(self, X, y=None):
        X = self._validate_data(X, dtype=None)
        # self.feature_names_in_ = X.columns
        _, counts = mode(X, axis=0)
        self.freq_ratios = (counts / X.shape[0]).flatten()
        return self

    def _get_support_mask(self):
        check_is_fitted(self)
        return self.freq_ratios < self.threshold

パイプライン化

上記のように実装したクラスはsklearnのパイプラインに組み込むことができる。
またパイプラインでは出力をデータフレームにする設定をまとめて行うことができる5
以下は、自作feature selectorで全て同じ値の変数を削除した後に標準化を行う例。

pipeline = Pipeline(
    [("my_selector", MySelector(threshold=1)), ("standard_scaler", StandardScaler())]
)
pipeline.set_output(transform="pandas")
pipeline.fit(X)
X_transformed = pipeline.transform(X)
  1. https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.VarianceThreshold.html

  2. https://github.com/scikit-learn/scikit-learn/blob/364c77e04/sklearn/feature_selection/_variance_threshold.py#L13

  3. https://qiita.com/roronya/items/fdf35d4f69ea62e1dd91

  4. https://scikit-learn.org/stable/modules/generated/sklearn.utils.check_array.html

  5. https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_set_output.html

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1