30
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

NTTコミュニケーションズAdvent Calendar 2020

Day 15

TabNetをKaggleに使ってみる

Last updated at Posted at 2020-12-15

この記事は NTTコミュニケーションズ Advent Calendar 2020 の15日目の記事です。

TabNetという機械学習モデルについて紹介します!

概要

  • TabNetは昨年公開されたNNのアーキテクチャで、テーブルデータを学習するためのモデルである
  • 最近開催されたKaggleのコンペで(アンサンブルではありつつも)上位ソリューションで使われた
  • この記事では、タイタニックを使ってテーブルコンペでよく利用されるモデルであるLightGBMと比較した

TabNet

TabNetは昨年公開されたNNのアーキテクチャで、テーブルデータを高性能かつ解釈可能に学習できます。

Screen Shot 2020-12-08 at 21.34.28.png

下記の画像がネットワークの構造です。Attentive Transformerという機構を用いて特徴量の選択をしながら学習します。そのため、学習後にこのマスクを参照することで、どのような特徴量が重要であったかを解釈することができます。

Screen Shot 2020-12-08 at 21.36.03.png

現実世界のテーブルデータには相互依存する特徴量(例えば学歴と職業)が含まれるため、教師なし事前学習が有効であると考え、下記のような事前学習を導入しました。(Denoising AutoEncoderのようなイメージでしょうか)

Screen Shot 2020-12-08 at 21.51.54.png

GCP AI Platformの組み込みアルゴリズムとしても採用されています。
https://cloud.google.com/ai-platform/training/docs/algorithms

論文

ソースコード

公式の実装
https://github.com/google-research/google-research/tree/master/tabnet

論文の著者の公式GitHubではありませんが、Kaggleではこちらのソースコードが使われていました。scikit-learnライクなインターフェイスで学習や推論できる実装です。
https://github.com/dreamquark-ai/tabnet

こんな感じで簡単に使えます。(後述のKaggleノートブックより抜粋)

from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score

tabnet_params = dict(verbose=40)
clf = TabNetClassifier(**tabnet_params)
kf = KFold(n_splits=5, shuffle=False)
scores = []
for train_index, val_index in kf.split(x_train, y_train):
    x_tr = x_train[train_index]
    x_val = x_train[val_index]
    y_tr = y_train[train_index]
    y_val = y_train[val_index]
    clf.fit(
      x_tr, y_tr,
      eval_set=[(x_val, y_val)],
      patience=100,
      max_epochs=300,
    )
    scores.append(accuracy_score(y_val, clf.predict(x_val)))

print(scores)
print(np.mean(scores))

解説動画

ネットを調べてみても論文以外の情報があまり見つかりません。この動画を観ると少しイメージがつかめます。
https://www.youtube.com/watch?v=ysBaZO8YmX8

MoA

非常に簡単に言うと、あるレコード(遺伝子発現データと細胞生存能データ)がどのような薬の作用によるものかを予測するテーブルコンペです。(正確には、作用機序を予測するマルチラベル分類)
2020年9月3日〜12月3日まで開催されていたコンペで、4373チームが参加しました。

ちなみに私もTabNet片手に参加しましたが、残念ながらメダル獲得には至りませんでした。

優勝者のソリューション

Screen Shot 2020-12-08 at 21.54.44.png

NNを中心としたアンサンブルを用いており、その中にTabNetも用いています。テーブルコンペなのになぜEfficientNetやResNetを用いているのかについては、話の本筋からそれるのでこの記事では扱いませんが興味がある方はこちらを読むと良いでしょう。

優勝者に限らず、上位ソリューションのほとんどがTabNetを使っていました。TabNetを含めた浅めのNNが有効なコンペだったように感じました。

LightGBMとの比較

2020年現在、テーブルコンペではGBDT(勾配ブースティング決定木)と呼ばれるアルゴリズムを採用する競技者が多いです(もちろん、アンサンブル等でDeep Learningが利用されることもよくあります)。GBDTの中でも最もよく利用されるのがMicrosoft社のLightGBMと呼ばれるライブラリです。

今回はTabNetとLightGBMを比較してみたいと思います。比較用のデータとして、おなじみのTitanicデータを用います。5fold CrossValidationを行いAccuracyの平均値を比較しました。ソースコードはKaggleノートブックとして公開しています。

実行すると下記のような結果が得られます。

5fold CV
LightGBM 80.9
TabNet 80.5

まだ私の実力不足でTabNetのハイパーパラメータの勘所がわからないため、ハイパーパラメータのチューニングはしていません。また、条件を合わせるためにLightGBMについてもチューニングはしませんでした。

そのため、どちらの精度ももう少し上がるかとは思います。あくまで参考値です。それでもチューニング無しでTitanicでのCVが80%であればかなり使えるレベルといって良いと思います。

まとめ

TabNetやその応用は、今後のテーブルデータの機械学習における主要な武器の一つになっていくことが予想されます。機械学習エンジニアな方は、動向をキャッチアップされると良いのではないでしょうか!

30
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
30
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?