12
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Target Encodingのsmoothingの実装

Posted at

Target Encodingのsmoothingなるものを自分で使うために実装してみました。

元々はsmoothingというのは知らなかったのですが、こちらのページを参考にさせていただきました。

[Target Mean Encoding] (https://qiita.com/suaaa7/items/cfe9a9e516b5b784570f)

[TargetEncodingのスムーシング] (https://mikebird28.hatenablog.jp/entry/2018/06/14/172132)

細かい内容はそちらを見ていただきまして、早速コードに。

TargetEncoding_ws.py

import numpy as np
import pandas as pd

class TargetEncoding_ws(object):
    """
    DFと変換したいカラムリスト、targetを引数として、Target Encoding with Smoothingを行う
    引数
    dataframe : DF全体 (pd.DataFrame)
    target : 目的変数のカラム (np.ndarray or np.Series)
    list_cols : 変換したいカラムリスト (list[str])
    k : smoothingのハイパーパラメータ (int)
    impute : 未知のカテゴリに平均を入れるか (boolean)
    """
    def __init__(self, list_cols, k=100, impute=True):
        self.df = None
        self.target = None
        self.list_cols = list_cols
        self.k = k
        self.impute = impute
        self.target_map = {}
        self.target_mean = None
            
    def sigmoid(self, x, k):
        return 1 / (1 + np.exp(- x / k))
            
    def fit_univariate(self, target, col):
        """
        一つの変数に対するTarget_Encoding
        col : TargetEncodingしたい変数名
        """
        df = self.df.copy()
        k = self.k
        df["target"] = target
        n_i = df.groupby(col).count()["target"]
        
        lambda_n_i = self.sigmoid(n_i, k)
        uni_map = df.groupby(col).mean()["target"]
        
        return lambda_n_i * df.loc[:, "target"].mean() + (1 - lambda_n_i) * uni_map
        
    def fit(self, data, target):
        """
        複数カラムにも対応したTargetEncoding
        """
        self.df = data.copy()
        self.target = target
        
        if self.impute == True:
            self.target_mean = target.mean()
        
        #各カラムのmapを保存
        for col in list_cols:
            self.target_map[col] = self.fit_univariate(target, col)

    def transform(self, x):
        list_cols = self.list_cols
        x_d = x.copy()
        for col in list_cols:
            x_d.loc[:, col] = x_d.loc[:, col].map(self.target_map[col])
            
            #impute
            if self.impute == True:
                x_d.loc[:, col] = np.where(x_d.loc[:, col].isnull(), self.target_mean, x_d.loc[:, col])
                
        return x_d

予測の際に未知のカテゴリが入ってくる可能性があることを考慮し、未知のカテゴリには平均を入れるか入れないかを選べるようしてあります。

sample_data.py
#サンプルデータ
from sklearn.datasets import load_iris
data = load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)
df["cate1"] = data.target
df["cate2"] = df.cate1 + 1 #2カラム以上でも動くことの確認用カラム


data = df.drop('sepal length (cm)', axis=1)
y = df['sepal length (cm)']

X_train = data.iloc[:100, :]
X_test = data.iloc[100:, :]
y_train = y[:100]
y_test = y[100:]

みなさんご存知のirisデータセットですが、元々はtarget(上の例ではcate1)に1,2,3の数字が入っていますが、50行ずつになっていますので、X_trainのcate1には3が含まれない状態です。

execute1.py
list_cols = ["cate1", "cate2"]
te = TargetEncoding_ws(list_cols=list_cols, k=200, impute=False)
te.fit(X_train, y_train)
display(te.transform(X_train).head())
display(te.transform(X_test).head())
sepal width (cm) petal length (cm) petal width (cm) cate1 cate2
0 3.5 1.4 0.2 5.267412 5.267412
1 3.0 1.4 0.2 5.267412 5.267412
2 3.2 1.3 0.2 5.267412 5.267412
3 3.1 1.5 0.2 5.267412 5.267412
4 3.6 1.4 0.2 5.267412 5.267412

| |sepal width (cm) | petal length (cm) | petal width (cm) | cate1 | cate2 |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| 100 | 3.3 | 6.0 | 2.5 | NaN | NaN |
| 101 | 2.7 | 5.1 | 1.9 | NaN | NaN |
| 102 | 3.0 | 5.9 | 2.1 | NaN | NaN |
| 103 | 2.9 | 5.6 | 1.8 | NaN | NaN |
| 104 | 3.0 | 5.8 | 2.2 | NaN | NaN |

list_cols = ["cate1", "cate2"]
te = TargetEncoding_ws(list_cols=list_cols, k=200, impute=True)
te.fit(X_train, y_train)
display(te.transform(X_train).head())
display(te.transform(X_test).head())
sepal width (cm) petal length (cm) petal width (cm) cate1 cate2
0 3.5 1.4 0.2 5.267412 5.267412
1 3.0 1.4 0.2 5.267412 5.267412
2 3.2 1.3 0.2 5.267412 5.267412
3 3.1 1.5 0.2 5.267412 5.267412
4 3.6 1.4 0.2 5.267412 5.267412
sepal width (cm) petal length (cm) petal width (cm) cate1 cate2
100 3.3 6.0 2.5 5.471 5.471
101 2.7 5.1 1.9 5.471 5.471
102 3.0 5.9 2.1 5.471 5.471
103 2.9 5.6 1.8 5.471 5.471
104 3.0 5.8 2.2 5.471 5.471

以上

12
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?