機械学習で使うデータは、多くの場合欠損値を含みます。scikit-learnではすべての項目に値が設定されていると仮定されているので、欠損値と相性が悪いです。
あまりにも欠損値が多い列は列ごと捨ててしまうこともありますが、一部だけ欠損している場合だともったいないですよね。なんらかの値で補完する場合の備忘録として、sklearn.impute.SimpleImputerの使い方をまとめました。
SimpleImputerの使い方
デフォルト
欠損値 np.nan を 平均値で補完します
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
imputer = SimpleImputer()
# すべての列を補完する場合
# この方法だと列名が消えてしまうので、元に戻してあげる必要があります
df2 = pd.DataFrame(imputer.fit_transform(df))
df2.columns = df.columns
# 特定の列を補完する場合
df[['col']] = imputer.fit_transform(df[['col']])
欠損値の種類
np.nan 以外にも様々な値を欠損値として指定することができます。
# int
imputer = SimpleImputer(missing_values=-1)
# float
imputer = SimpleImputer(missing_values=0.0)
# str
imputer = SimpleImputer(missing_values='NaN')
# np.nan(デフォルト)
imputer = SimpleImputer(missing_values=np.nan)
# pandas.NA
imputer = SimpleImputer(missing_values=pd.NA)
補完方法
平均値以外にも様々な補完方法が用意されています。
# 平均値(デフォルト)
imputer = SimpleImputer(strategy='mean')
# 中央値
imputer = SimpleImputer(strategy='median')
# 最頻値
imputer = SimpleImputer(strategy='most_frequent')
# 固定値
imputer = SimpleImputer(strategy='constant', fill_value=0.0)
メソッド
fit, transform, fit_transform がよく使われるメソッドです。
ユーザガイドを見ても使い方がよく分からなかったので、自分で試してみた上での理解を記載します。もし違っていたら指摘をお願いします。
# fitで設定した引数から統計量(例えば平均値)を取得します
imputer = SimpleImputer()
imputer.fit([[1, 2, 3], [4, np.nan, 6], [7, 8, 9]])
# transformでは、fitで取得した統計量を欠損値に補完します
# transformに設定した引数の統計値ではありません
X = [[1,1,1],[1,1,1],[1,1,np.nan]]
X = imp_mean.transform(X)
print(X)
[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 6.]]
# fit_transformでは、引数Xから取得した統計量でXの欠損値を補完します。
# これが一番直感的かもしれません。
X = [[1,1,1],[1,1,1],[1,1,np.nan]]
X = imp_mean.fit_transform(X)
print(X)
[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]]
# 取得済みの統計量を使って別のデータを補完する場合は、続けてtransformを使います。
X2 = [[10,10,10],[10,10,np.nan],[np.nan,np.nan,np.nan]]
X2 = imp_mean.transform(X2)
print(X2)
[[10. 10. 10.]
[10. 10. 1.]
[ 1. 1. 1.]]
欠損値の確認方法
欠損値の数を確認します。
もしほとんどの行が欠損値だったら、列ごと捨てたほうがいいかもしれません。
# まず行数を確認します
print(X.shape)
# 欠損値を含むカラムごとの欠損値数を確認します
missing_val_count_by_column = (X.isnull().sum())
print(missing_val_count_by_column[missing_val_count_by_column > 0])
次に欠損値を含む項目の概要を確認し、どの方法で補完するかを検討します。
# 「col」の部分には、欠損値を含むカラム名を指定してください
# 概要
print(df.col.describe())
# 先頭データ
print(df[df.col.isna()].head())