scikit-learn の fit() / transform() / fit_transform()

  • 2
    いいね
  • 0
    コメント

scikit-learn の変換系クラス(StandardScalerNormalizerBinarizerOneHotEncoderPolynomialFeaturesImputer など) には、fit()transform()fit_transform()という関数がありますが、何を使ったらどうなるかわかりづらいので、まとめてみました。

関数でやること

fit()

渡されたデータの最大値、最小値、平均、標準偏差、傾き...などの統計を取得して、内部メモリに保存する。

transform()

fit()で取得した統計情報を使って、渡されたデータを実際に書き換える。

fit_transform()

fit()を実施した後に、同じデータに対してtransform()を実施する。

使い分け

  • トレーニングデータの場合は、それ自体の統計を基に正規化や欠損値処理を行っても問題ないので、fit_transform()を使って構わない。

  • テストデータの場合は、比較的データ数が少なく、トレーニングデータの統計を使って正規化や欠損値処理を行うべきなので、トレーニングデータに対するfit()の結果で、transform()を行う必要がある。

参考

StackOverflow - Difference between fit and fit_transform in scikit_learn models?
scikit-learnマニュアル - 4.3.1. Standardization, or mean removal and variance scaling