概要
TabNetはニューラルネットワークをベースとしたモデルで、kaggleなどのテーブルデータ予測でよく用いられます。
TabNetの特徴を簡単にまとめると以下のようになります(参考サイト[1]より引用)。
- ディープラーニングをベースとしたモデル。
- 特徴量の選択や加工などの前処理が不要で、end-to-endで学習することができる。
- アテンション・メカニズムを使い、各決定ステップ(decision step)において使用する特徴量を選択する。アテンション・メカニズムにより解釈性が向上し、重要な特徴量をうまく学習することができる。
- 全サンプル共通ではなくサンプルごとに重要な特徴量を選択する。
- いくつかの特徴量をマスクし、それを予測するという事前学習を行う。
今回は、このTabNetをPythonライブラリ「pytorch-tabnet」で実装する方法を紹介します。
使用するデータ
サンプルデータとして、SIGNATE練習問題の国勢調査からの収入予測というコンペのデータを使用します。
回帰タスクではなく、二値分類タスク($50,000を超えるかどうか)です。
TabNetの実装
以下のような流れで実装方法を説明します。
- データの読み込み・前処理
- モデルの学習
- 予測
1. データの読み込み・前処理
データの読み込みと、通常の前処理を行います。
import pandas as pd
# tsvファイルの読み込み
train = pd.read_csv('./data/train.tsv', delimiter='\t')
test = pd.read_csv('./data/test.tsv', delimiter='\t')
# trainとtestを結合してdf_allとする
df_all = pd.concat([train,test])
# 前処理
df_all['Y'] = df_all['Y'].map({'<=50K':0, '>50K':1})
df_all['sex'] = df_all['sex'].map({'Male':0, 'Female':1})
category_cols = ['workclass','education','marital-status','occupation','relationship','race','native-country']
dummies = []
for col in category_cols:
dummies.append(pd.get_dummies(df_all[col], prefix=col))
df_all = pd.concat([df_all.drop(category_cols,axis=1)] + dummies, axis=1)
# df_allをtrain,testに分離
train = df_all[~df_all['Y'].isnull()].reset_index(drop=True)
test = df_all[df_all['Y'].isnull()].reset_index(drop=True)
ここで、データに不均衡がある(Y=0のデータが12288個なのに対しY=1のデータは3992個)ので、Y=0のデータからランダムに選んで4000個に減らす処理を行いました。
import numpy as np
# ダウンサンプリング(Y=0のデータから4000個ランダムに取る)
y0_idx = train[train['Y']==0].index
y1_idx = train[train['Y']==1].index
y0_choice = np.random.choice(y0_idx,size=4000)
train = train.iloc[np.concatenate([y0_choice,y1_idx])]
2. 学習
訓練用に与えられたデータを、学習用データ(X_tr
,y_tr
)と評価用データ(X_va
,y_va
)に分割します。
from sklearn.model_selection import train_test_split
X = train.drop(['id','Y'],axis=1).values
y = train['Y'].values
X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.2, random_state=42)
TabNetでは教師なし事前学習を行うと精度が上がるようです。以下は論文から引用した表です。
データの数が少ないほど、事前学習による精度向上の幅が大きいことがわかりますね。
pytorch-tabnetでは、TabNetPretrainerに事前学習用のクラスが実装されているのでそれを使って事前学習ができます。
参考サイト[3]のサンプルコードを参考にしました。
optimizer_fn
:最適化の手法(AdamとかSGDとか)。torchのやつが使える。
optimizer_params
:最適化パラメータ(learning rateとかmomentumとか)。
device_name
:使用する計算リソース。'cuda'とかにするとGPUを使える。
mask_type
:特徴量選択のために使われるマスク関数('sparsemax'or'entmax')。
mask_type
でどちらを選ぶかについては、参考サイト[4]が参考になります。
entmax
の方がソフトに特徴選択してくれるおかげか、比較的sparsemax
よりは過学習しにくいことが多かったです。
from pytorch_tabnet.pretraining import TabNetPretrainer
import torch
# 事前学習
unsupervised_model = TabNetPretrainer(
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
device_name = 'cpu',
mask_type='entmax' # "sparsemax"
)
unsupervised_model.fit(
X_tr,
eval_set=[X_va],
batch_size = 64,
pretraining_ratio=0.8,
)
本学習(ファインチューニング)では、分類タスクならTabNetClassifier
を、回帰タスクならTabNetRegressor
を使用します。
from pytorch_tabnet.tab_model import TabNetClassifier
# 本学習
model = TabNetClassifier(
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
device_name = 'cpu',
verbose = 1,
seed = 42
)
model.fit(
X_tr, y_tr,
eval_set = [(X_tr, y_tr),(X_va, y_va)],
eval_metric= ['accuracy','accuracy'],
eval_name = ['train', 'valid'],
batch_size = 64,
max_epochs = 10,
patience = 10, # early_stopping_roundに相当
from_unsupervised=unsupervised_model # 事前学習のモデル
)
3. 予測
学習済みモデルの予測は.predict()
で行うことができます。
X_test = test.drop(['id','Y'],axis=1).values
y_test = test['Y'].values
y_pred = model.predict(X_test)
実行結果
損失、正解率のエポック推移は次のグラフのとおりになりました。
評価用データに対する正解率は80.425%(リーダーボード評価:79.854%)でした。
また、TabNetではLightGBMなどのような特徴量重要度(feature importance)を算出することが可能です。特徴量重要度をプロットしたのが下のグラフです。
TabNetの特徴量重要度は下図(論文[2]より引用)の中Attentive Transformerからとってきているようです。(参考サイト[5])
したがって、TabNetの特徴量重要度は0以上1以下の値をとり、総和をとると1になります。
(参考)LightGBMとの比較
LightGBMでもモデルを構築し、精度を比較しました。
評価データに対する正解率は80.425%(リーダーボード評価:81.807%)でした。
今回の例では、TabNetはLightGBMに匹敵する精度を出すことができました。
なお、参考までにLightGBMの特徴量重要度を下に載せておきます。
参考サイト
[1] 【論文解説】TabNetを理解する
https://data-analytics.fun/2021/09/04/understanding-tabnet/
[2] TabNet: Attentive Interpretable Tabular Learning
https://arxiv.org/abs/1908.07442
[3] pytorch-tabnetドキュメント
https://dreamquark-ai.github.io/tabnet/generated_docs/README.html
[4] TabNetメインに使ってみての振り返り
https://www.guruguru.science/competitions/16/discussions/12b403d8-2106-4ae7-9294-383f080b87a7/
[5] TabNetを頑張って調べて見たりする遊び(1/2)
https://tanico-kazuyo.net/archives/1649