LoginSignup
1
4

【深層学習】TabNetの使い方

Posted at

概要

TabNetはニューラルネットワークをベースとしたモデルで、kaggleなどのテーブルデータ予測でよく用いられます。
TabNetの特徴を簡単にまとめると以下のようになります(参考サイト[1]より引用)。

  • ディープラーニングをベースとしたモデル。
  • 特徴量の選択や加工などの前処理が不要で、end-to-endで学習することができる。
  • アテンション・メカニズムを使い、各決定ステップ(decision step)において使用する特徴量を選択する。アテンション・メカニズムにより解釈性が向上し、重要な特徴量をうまく学習することができる。
  • 全サンプル共通ではなくサンプルごとに重要な特徴量を選択する。
  • いくつかの特徴量をマスクし、それを予測するという事前学習を行う

今回は、このTabNetをPythonライブラリ「pytorch-tabnet」で実装する方法を紹介します。

使用するデータ

サンプルデータとして、SIGNATE練習問題の国勢調査からの収入予測というコンペのデータを使用します。
回帰タスクではなく、二値分類タスク($50,000を超えるかどうか)です。

TabNetの実装

以下のような流れで実装方法を説明します。

  1. データの読み込み・前処理
  2. モデルの学習
  3. 予測

1. データの読み込み・前処理

データの読み込みと、通常の前処理を行います。

main.py
import pandas as pd

# tsvファイルの読み込み
train = pd.read_csv('./data/train.tsv', delimiter='\t')
test = pd.read_csv('./data/test.tsv', delimiter='\t')
# trainとtestを結合してdf_allとする
df_all = pd.concat([train,test])

# 前処理
df_all['Y'] = df_all['Y'].map({'<=50K':0, '>50K':1})
df_all['sex'] = df_all['sex'].map({'Male':0, 'Female':1})

category_cols = ['workclass','education','marital-status','occupation','relationship','race','native-country']
dummies = []
for col in category_cols:
    dummies.append(pd.get_dummies(df_all[col], prefix=col))
df_all = pd.concat([df_all.drop(category_cols,axis=1)] + dummies, axis=1)

# df_allをtrain,testに分離
train = df_all[~df_all['Y'].isnull()].reset_index(drop=True)
test = df_all[df_all['Y'].isnull()].reset_index(drop=True)

ここで、データに不均衡がある(Y=0のデータが12288個なのに対しY=1のデータは3992個)ので、Y=0のデータからランダムに選んで4000個に減らす処理を行いました。

main.py
import numpy as np

# ダウンサンプリング(Y=0のデータから4000個ランダムに取る)
y0_idx = train[train['Y']==0].index
y1_idx = train[train['Y']==1].index
y0_choice = np.random.choice(y0_idx,size=4000)
train = train.iloc[np.concatenate([y0_choice,y1_idx])]

2. 学習

訓練用に与えられたデータを、学習用データ(X_tr,y_tr)と評価用データ(X_va,y_va)に分割します。

main.py
from sklearn.model_selection import train_test_split

X = train.drop(['id','Y'],axis=1).values
y = train['Y'].values

X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.2, random_state=42)

TabNetでは教師なし事前学習を行うと精度が上がるようです。以下は論文から引用した表です。
image.png

データの数が少ないほど、事前学習による精度向上の幅が大きいことがわかりますね。

pytorch-tabnetでは、TabNetPretrainerに事前学習用のクラスが実装されているのでそれを使って事前学習ができます。

参考サイト[3]のサンプルコードを参考にしました。

optimizer_fn:最適化の手法(AdamとかSGDとか)。torchのやつが使える。
optimizer_params:最適化パラメータ(learning rateとかmomentumとか)。
device_name:使用する計算リソース。'cuda'とかにするとGPUを使える。
mask_type:特徴量選択のために使われるマスク関数('sparsemax'or'entmax')。

mask_typeでどちらを選ぶかについては、参考サイト[4]が参考になります。

entmaxの方がソフトに特徴選択してくれるおかげか、比較的sparsemaxよりは過学習しにくいことが多かったです。

main.py
from pytorch_tabnet.pretraining import TabNetPretrainer
import torch

# 事前学習
unsupervised_model = TabNetPretrainer(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    device_name = 'cpu',
    mask_type='entmax' # "sparsemax"
)
unsupervised_model.fit(
    X_tr,
    eval_set=[X_va],
    batch_size = 64,
    pretraining_ratio=0.8,
)

本学習(ファインチューニング)では、分類タスクならTabNetClassifierを、回帰タスクならTabNetRegressorを使用します。

main.py
from pytorch_tabnet.tab_model import TabNetClassifier

# 本学習
model = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    device_name = 'cpu',
    verbose = 1,
    seed = 42
)
model.fit(
    X_tr, y_tr,
    eval_set = [(X_tr, y_tr),(X_va, y_va)],
    eval_metric= ['accuracy','accuracy'],
    eval_name = ['train', 'valid'],
    batch_size = 64,
    max_epochs = 10,
    patience = 10,  # early_stopping_roundに相当
    from_unsupervised=unsupervised_model  # 事前学習のモデル
)

3. 予測

学習済みモデルの予測は.predict()で行うことができます。

main.py
X_test = test.drop(['id','Y'],axis=1).values
y_test = test['Y'].values
y_pred = model.predict(X_test)

実行結果

損失、正解率のエポック推移は次のグラフのとおりになりました。
image.png

評価用データに対する正解率は80.425%(リーダーボード評価:79.854%)でした。

また、TabNetではLightGBMなどのような特徴量重要度(feature importance)を算出することが可能です。特徴量重要度をプロットしたのが下のグラフです。
image.png

TabNetの特徴量重要度は下図(論文[2]より引用)の中Attentive Transformerからとってきているようです。(参考サイト[5])
image.png

したがって、TabNetの特徴量重要度は0以上1以下の値をとり、総和をとると1になります。

(参考)LightGBMとの比較

LightGBMでもモデルを構築し、精度を比較しました。
評価データに対する正解率は80.425%(リーダーボード評価:81.807%)でした。
今回の例では、TabNetはLightGBMに匹敵する精度を出すことができました。

なお、参考までにLightGBMの特徴量重要度を下に載せておきます。
image.png

参考サイト

[1] 【論文解説】TabNetを理解する
  https://data-analytics.fun/2021/09/04/understanding-tabnet/
[2] TabNet: Attentive Interpretable Tabular Learning
  https://arxiv.org/abs/1908.07442
[3] pytorch-tabnetドキュメント
  https://dreamquark-ai.github.io/tabnet/generated_docs/README.html
[4] TabNetメインに使ってみての振り返り
  https://www.guruguru.science/competitions/16/discussions/12b403d8-2106-4ae7-9294-383f080b87a7/
[5] TabNetを頑張って調べて見たりする遊び(1/2)
  https://tanico-kazuyo.net/archives/1649

1
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
4