8
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Target Encodingで精度向上させた例(Leave One Out)

Last updated at Posted at 2021-08-01

Target Encoding(Leave One Out)を使って分類器の精度向上させました。Kaggle Titanicで0.79665のScoreで約2500位/55000人程度なので、そこそこいいスコアかと思います。
Target Encodingを勉強したときは、どんなときに使うんだろうとぼんやりと考えていましたが、意外とすぐに使う機会ができました。Greedy Target Statisticsはリークしたので、Leave one-out Target Statisticsを使いました
※Target Encodingは記事「カテゴリ変数系特徴量の前処理(scikit-learnとcategory_encoders)」に書いています。

前提: Target Encoding

手前味噌ですが、私の記事に簡単に書いていて、以下の記事にはもっと詳しい内容があります。

チャレンジ内容: Titanic

Kaggle TitanicでTarget Encoding を使いました。タイタニック号乗客の生死予測です。

特徴量

特徴量詳細

与えられた特徴量には、以下のものがあります。

変数 内容 Value情報など
1 PassengerId 乗客IDでユニークキー
2 Survival 生死(目的変数) 0 = No(死亡), 1 = Yes(生存)。テストデータには存在しない列
3 Pclass 席等級 1 = 1st(Upper), 2 = 2nd(Middle), 3 = 3rd(Lower)
4 Name 姓名
5 Sex male/female
6 Age Age in years
7 Sibsp 同乗した姉妹兄弟と配偶者数
8 Parch 同乗した両親子どもの数
9 Ticket チケット番号
10 Fare 料金 親子で2席買ったら合計料金になるっぽい
11 Cabin キャビン番号 欠損値多い
12 Embarked 乗船港 C = Cherbourg, Q = Queenstown, S = Southampton

新たな特徴量とTarget Encoding

そもそも、以下の kernel に触発されて自分なりに特徴量エンジニアリングをしてみた結果、Target Encodingに行き着きました。

家族や友人などのグループは生死を道連れにする、という仮説を元にグループの特徴量を作り、グループに対してTarget Encodingを使います。グループ作成の基準は以下の2点。

  1. 同じ Ticket(婚約者でまだ姓が違う場合や友人など)
  2. 姓と料金が同じ(姓だけだと偶然の場合も多いので)

※もっといいグルーピングの基準があるかもしれませんが、時間をかけて精査していません。

プログラム

全般的に、結構雑な処理をしています。プログラム全体はこちらのカーネル。

1. パッケージとファイル読込

最初はパッケージとファイル読込です。

import pandas as pd
import matplotlib.pyplot as plt
import category_encoders as ce

from sklearn.preprocessing import OrdinalEncoder
from sklearn.ensemble import RandomForestClassifier

train_csv = pd.read_csv("/kaggle/input/titanic/train.csv")
test_csv = pd.read_csv("/kaggle/input/titanic/test.csv")

訓練データとテストデータの内容(info関数結果)を参考に載せておきます。

train_csv()
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
test_csv.info()
RangeIndex: 418 entries, 0 to 417
Data columns (total 11 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  418 non-null    int64  
 1   Pclass       418 non-null    int64  
 2   Name         418 non-null    object 
 3   Sex          418 non-null    object 
 4   Age          332 non-null    float64
 5   SibSp        418 non-null    int64  
 6   Parch        418 non-null    int64  
 7   Ticket       418 non-null    object 
 8   Fare         417 non-null    float64
 9   Cabin        91 non-null     object 
 10  Embarked     418 non-null    object 
dtypes: float64(2), int64(4), object(5)
memory usage: 36.0+ KB

2. 特徴量生成

2.1. グループキーとなる特徴量生成

訓練データとテストデータを結合して、以下の2処理を行います。本当は訓練データとテストデータの結合って嫌いなのですが、今回は無視しています。

  1. Nameから姓の抽出
  2. 姓とFareの結合(次で処理しやすくするため)
both = pd.concat([train_csv, test_csv], ignore_index=True)

# Last Name作成
both['Last_Name'] = both['Name'].apply(lambda x: str.split(x, ",")[0])

# FareがNullのレコードは文字列"nan"となる
both['Name_Fare'] = both['Last_Name'] + both['Fare'].astype('str')

2.2. グルーピング

同じTicketでグルーピング→同じ「姓」とFareでグルーピングを再帰処理で繰り返します。今回、一番コーディングが面倒だった処理です(筆者のスキル不足が原因)。新しくGroup列に数値を連番で振っていきます。

# 姓とFareでグルーピング
def process_name(i, name_fare):
    tickets = both.loc[(both['Name_Fare'] == name_fare) & (both['Group'].isnull()),'Ticket'].unique().tolist()
    both.loc[(both['Name_Fare'] == name_fare) & (both['Group'].isnull()), 'Group'] = i
    for ticket in tickets:
        process_ticket(i, ticket)

# チケットでグルーピング        
def process_ticket(i, ticket):
    name_fares = both.loc[(both['Ticket'] == ticket) & (both['Group'].isnull()),'Name_Fare'].unique().tolist()
    both.loc[(both['Ticket'] == ticket) & (both['Group'].isnull()), 'Group'] = i
    for name_fare in name_fares:
        process_name(i, name_fare)

both['Group'] = None

# チケットでグルーピング(再帰処理で姓とFare、チケットでのグルーピングを繰り返す)
[process_ticket(i, ticket) for i, ticket in enumerate(both['Ticket'].unique().tolist())]

2.3. 訓練とテストに分離

再び、訓練とテストデータセットに分離します。

train = both[:891].copy()
test = both[891:].copy()

2.4. グループのカウント

Count Encoderを使ってグループ内人数のカウントを実施。Group_Countという列に値を入れます。
テストデータセットには、自分自身のデータ数をカウントとして+1します。これは、例えば訓練データに「山田グループ」があり2人いたとして、テストデータに同じく「山田グループ」に所属する人がいたら(何人いても)テストデータを3と設定しています。訓練データとテストデータをまとめてカウントエンコーディングすればいいような気もしますが、そうすると訓練内容にも大きく影響与えそうなのでやっていません。

count_encoder = ce.CountEncoder(cols=['Group'], handle_unknown=0, return_df=True)
train['Group_Count'] = count_encoder.fit_transform(train['Group'])
test['Group_Count'] = count_encoder.transform(test['Group']).astype('int') + 1  #test自身のデータを1追加(本当はもっとあるかもしれないが雑に計算)

2.5. Target Encoding

Target Encoding としてLeave One Outを使っています。自分自身を抜いて目的変数Survivedの平均の特徴量をGroup_Targetに入れます。

te = ce.LeaveOneOutEncoder(cols=['Group'])
train['Group_Target'] = te.fit_transform(train['Group'], train['Survived'])
test['Group_Target'] = te.transform(test['Group'])

2.6. Target Encoding 微調整

訓練データに1人のみいたグループの場合、その1人の生死(Survivedの値)をそのままテストデータのGroup_Targetに入れます。
また、訓練データに1人もいなかったグループの場合は、平均生存率である0.38を入れておきます。

for index, row in test.query('Group_Count == 2').iterrows():
    test.at[index, 'Group_Target'] = train[train['Group']==row['Group']]['Survived']
test.loc[test['Group_Count'] == 0, 'Group_Target'] = 0.38

2.7. 他カテゴリ変数のEncoding

Pclass(席等級)とSex(性)をEncodingします。

oe = OrdinalEncoder()
train.loc[:,['PclassEncoded', 'SexEncoded']] = oe.fit_transform(train[['Pclass', 'Sex']])
test.loc[:,['PclassEncoded', 'SexEncoded']] = oe.transform(test[['Pclass', 'Sex']])

3. グルーピングの確認

グルーピング結果をグラフで確認します。

_, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 4))
train['Group_Count'].value_counts().plot.pie(ax=axes[0], title='Group Count Pie Chart', autopct="%1.1f%%")
train['Group_Count'].plot.hist(ax=axes[1], title='Group Count Histgram')
train.query('Group_Count > 2')['Group_Target'].plot.hist(ax=axes[2], title='Survival Rate Histgram')
plt.show()

image.png
グラフの内容は以下通りです。

  • 左(円グラフ): グループ内人数(Group_Count)の割合
  • 中央(ヒストグラム): グループ内人数(Group_Count)のヒストグラム
  • 右(ヒストグラム): 3人以上のグループの生存率ヒストグラム

訓練データの半分弱がグループ(3人以上)です。右のグラフ(3人以上のグループの生存率ヒストグラム)では、ある程度左右に偏ってくれています。
2人グループを含めなかった理由は、2人グループはLeave One Outなので必ず0(死亡)か1(生存)になるからです。
グループ内人数で個別に確認しています。

_, ax = plt.subplots(figsize=(12, 10))
train.query('Group_Count > 2').Group_Target.hist(ax=ax, by=train['Group_Count'], range=(0, 1))
plt.show()

3 or 4人グループでは少しバラけています。一方で5人以上のグループでは、道連れ死パターンが多いですね。グループ人数情報と併せて決定木系の分類器であれば、判断してくれるのでは、と期待します。
image.png

4. 訓練実施

特徴量で使いやすいものだけピックアップして、使いやすいランダムフォレストで訓練します。Scoreは0.853でした。
※Greedy Target Statisticsを使ったときはリークがひどく、0.98くらいでした(ウル覚え)・・・

features = ["Group_Count", "Group_Target", 'PclassEncoded', 'SexEncoded', "SibSp", "Parch"]
X = train[features]
y = train["Survived"].astype('int')

tmp_sr = test['PassengerId']
X_test = test[features]

model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=1)
model.fit(X, y)
print(model.score(X, y))

特徴量の重要性を確認。

importances = pd.DataFrame({'Importance': model.feature_importances_}, index=X_test.columns)
importances.sort_values('Importance', ascending=False).head(10).sort_values('Importance', ascending=True).plot.barh(grid=True)

Group_Trargetはいい感じの高さになってくれています(2人以上のグループは多くないのでこんなもの)。
image.png

5. 予測

モデルを使って予測して、CSVファイルに出力します。

predictions = model.predict(X_test)

output = pd.DataFrame({'PassengerId': tmp_sr, 'Survived': predictions})
output.to_csv('my_submission.csv', index=False)
8
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?