12
6

More than 1 year has passed since last update.

TabNetのアーキテクチャを詳しく解説

Posted at

はじめに

kaggleでは2020年ぐらいから当たり前のように使われるようになってきたTabNetですが使ったこともないですし、どんなアーキテクチャなのかも知りませんでした。今回は実装というよりもTabNetについて理解が深められたらと思います。

参考:

TabNetとは一体何者なのか? :TabNetの概要についてわかりやすく解説している記事です。

TabNetを使えるようになりたい【追記①lgbmとstacking(ちょっと上がる)】 :atmacup #10のディスカッションで上がっていたものです。恐らくコンペ参加しないと(後からでも同意すればできます)見れないものなので、ここに上がっている図などは使えませんが、TabNetについて(正確にはpytorch-tabnetについて)とても詳しく書いてある記事なので、是非読んでいただきたいです。今日はpytorch-tabnetの実装の話も含みますが、基本的にはTabNet自体の解説をします。

https://github.com/dreamquark-ai/tabnet :TabNetの非公式Pytorch実装。コードもファイル構成も実装が惚れ惚れするほど綺麗で読みやすい。論文読んで詰まったら論文読み込むよりこっちを読んだ方が早い気がします。

TabNet概要

※ 元論文のIntroductionに書いてあることです

TabNetは一言でいえばテーブルデータ向けのDNNになります。画像や音声、自然言語の分野で、それぞれのドメインに適したディープラーニングのモデルが様々開発されているにも関わらず、テーブルデータではあまり発展してきませんでした。

発展してこなかった背景として以下の2つが挙げられます。

  • 決定木ベースのモデルが強かった
    • 学習速度も早いし精度も高い。また解釈性も高いため特に文句が無かった
  • CNNやMLPなど既存のDNNのアーキテクチャがテーブルデータに適さなかった
    • テーブルデータにおいてはCNNが捉える空間的情報は必要ないし、MLPはパラメータ数が多すぎて過学習してしまうし学習速度が遅くなってしまう

ただ、DNNを使うことのメリットとして、大規模なデータセットにおいてであればDNNをうまく活用すれば決定木ベースのモデルより精度が上がるかもしれないし、ディープの力によってよしなにデータから特徴を学習してくれるのであれば、特徴量エンジニアリングの必要性が軽減される、といったことが期待されるので、DNNをテーブルデータに活用したい!というモチベーションはありました。

そこで、TabNetでは、元々強力な手法だった決定木ベース(GBDT)の考え方を取り込みながら、DNNの良いところ(特徴量エンジニアリングの必要性が低くなるなど)も使って、良い感じのアーキテクチャを作る、ということを目的に考案されました。

TabNetの貢献としては以下になります。

  • 生データをそのまま突っ込んで学習できること(カテゴリ変数についてはembeddingが必要ですが)
  • 各インスタンスごとに特徴選択をして学習を行うようにしたことで、高い解釈性、学習効率性を得ることができた
  • ↑ 適切な特徴選択を行うデザインにしたことで、LightGBMなど決定木ベースのモデルなどとも同等かデータセットによってはそれ以上の精度を出すことができたこと、インスタンスごとの解釈性、モデル全体の解釈性が得られる構造になっていること(要はfeature importanceが計算できる)
  • 事前学習を使うことで精度が向上できること(事前学習をどうやるかは後述)

取り敢えずTabNetが何を目指しているかは分かったが、決定木の考え方を取り込むってどういうこと?(※ 正直これについてはちゃんと把握できているかと言われると怪しいです)テーブルデータで事前学習ってどうやって行うの?といった疑問があるかと思います。以下ではそれらに答えるため、TabNetの具体的なアーキテクチャについてみていきます。

TabNetの詳細

全体像

67426894-c47f-49f7-bb88-0febe18348b2.png
まず、TabNetの事前学習、ファインチューニング時にどう使うのか、ということを示した図が上記になります。(左図が事前学習、右図がファインチューニング時になります)TabNetはEncoder-Decoderモデルを採用しています。

事前学習時はテーブルの中でいくつかの値をマスクして、(マスクする値はランダム。pytorch-tabnetの実装ではベルヌーイ分布を用いていた。どれくらいマスクするかはハイパーパラメータとして決めます)encoderにつっこんであげて、decoderで再構成します。(上記の画像だとdecoderの出力がマスクされた部分だけ見たいになっているが、一応出力はマスクされた箇所含めたテーブル全部の値のはずです)

正直なぜこの事前学習が異なるドメインのテーブルデータのfine-tuningにも使えるのか疑問があるのですが、どうやらうまくいくらしいです。

ファインチューニング時は、事前学習によって獲得したEncoderの重みを初期値として使います。Encoderの出力を入力として、タスクに合わせRegressorなりClassifierを使って予測を行ってあげれば良い、ということになります。

TabNetを理解する上では、上記のEncoderとDecoderがどのように構成されているのかを理解することが重要です。次からはTabNetのアーキテクチャについてみていきます。

TabNetのアーキテクチャ

1a40f990-ce3d-4623-ad45-ef3fb44a4843.png
上記がTabNetのアーキテクチャの全体像です。(a)がEncoder (b)がDecoderに対応しており、(c), (d)はその中で使われているblockになります。

まずは(a)のEncoderから説明していきます。

Encoder

以下がEncoderの全体像です。EncoderがTabNetの肝なので、Encoderの実装が理解できればTabNetを理解したといっても過言ではありません。
d5b544bf-c184-4e2e-9933-451983079ca7.png

最初にEncoderで意識してもらいたいのは赤線の流れです。それ以外のところは一旦無視していただいて構いません。赤線の横に書いてあるのは次元です(バッチサイズは同じなので省略)

全体の流れとして、左下にあるFeaturesというのが入力で、これがstepの数だけ入力として使われて、各stepごとに出力が計算されて、各stepごとの出力を足し合わせた結果に最後FC層で線形変換したものがoutputになっている、ということを理解していただければと思います。つまり、stepがNあれば、N個の出力が得られる訳です。

さらに、実は各stepの出力は、前のstepの出力(splitというブロックから赤線が伸びているところ)を使っています。前のstepの情報を使いながら、各stepごとに出力を行い、それらを足し合わせる… なんだかGBDTの風味が漂ってくる実装になっていることが分かります。

とはいえ、「GB」っぽさは感じられるものの、「DT」の要素はここまでの情報だけでは読み取れません。これについてはもう少し後で触れます。まずは赤線の部分の流れを詳細に追うことから始めます。

まず、左下にあるFeaturesが入力であり、これがバッチサイズB、次元Dの行列とします。この行列にバッチ正規化をかけた上で、各stepの入力として使っています(stepはNとする)

このstepごとの入力はMaskという層を通ります(※ このマスクは事前学習の入力の際に使うMaskとは異なります)。このMaskという層はFeaturesと同じくB*Dの行列になっていて、各値は0 ~ 1になっています。

Maskの各値は、同じ位置にあるFeaturesの値の重要度みたいなものだと解釈ができて、1に近ければ特徴量として残りますし、もし0だったら同じ位置にあるFeaturesの値は予測に役立たないとモデルが判断した、ということになります。

Featureを上記のMaskによって重み付けするためアダマール積を計算してあげます(単に要素ごとの積をとるだけです)
69c38b71-58e0-4f12-bab8-cd2d8cae51f7.png

こうすると、「インスタンスごとに特徴選択が行われた後の」Featuresが手に入ります。(インスタンスごとにマスクの重みは異なるので、インスタンスごとに特徴量選択が行われていると言える)これを次の層であるFeature Transformerに入れると、B*(n_d+n_a)次元の出力が得られます。つまり、Feature Transfomerは D → (n_d + n_a)にmappingする関数となっています。Feature Transformerの詳細は後述します。

(n_d + n_a)次元の出力はSplitという部分で、n_d とn_aの2つに分割されます。つまり、ここからはB*n_dの部分だけが使われます。(※Splitについてはなんでわざわざこんなことするのかよく分かりません。n_d+n_a次元の出力ではなくn_d次元の出力にして、これをAttentiveTransformerの入力にも使ってあげれば良いと思うのですが…)

B*n_dの入力をReLUを通して非線形な変換を行ったら各stepごとにやることは終わりです! 後は各stepごとの出力を足して線形変換を行うだけです。Encoderの大枠の流れは理解していただけたでしょうか。

さて、このアーキテクチャのどこが決定木を模しているんだ、という話ですが、このmaskして、FC(線形変換)してReLUを通して各stepの出力を足し合わせる、という構造でのアウトプットが、決定木でやっているような領域分割に近い、というお話があります。

1ece0b98-142b-4422-978e-19dd9d482717.png

左図がDNNでの例(Encoderでやっていることと近い実装になっています)で、右図が決定境界を表したものです。入力の特徴量が[x1, x2]とあり、それらをマスクして線形変換を行うことで、ReLUに通した時に0以下が0になるので、決定境界はx1=a, x2=dのところとなります。決定木においても、特徴量x1がaより大く、x2がbより小さかったら、分類としてクラス0、といったように領域分割を行います。このように、左図のblockが決定木でやっていることと近いことを実現できている、ということらしいです。正直分かったようなわからないような、という感じですが、やりたいことは理解してもらえたかと思います。

このように、各stepごとの実装が「DT」っぽくなっていて、全体の実装が「GB」っぽいので、合わせてみるとGBDTっぽさをDNNで表現しているんかなーということが分かりました。これが「決定木の考え方を取り込む」ということです。

ここまででEncoderの構造とその構造の理由について説明しました。次に、ここまで詳細を飛ばしてきたAttentiveTransformerとFeatureTransformerの詳細について説明します。

AttentiveTransformer

9e638af3-bd97-4c73-b57c-3233eea675af.png
まずはAttentiveTransformerの役割について説明します。Encoderの全体像をみていただくとわかるのですが、AttentiveTransformerの役割は特徴量選択を行うMaskの生成です。このことを念頭において処理をみていきます。

図で矢印が循環していて処理が分かりづらい感じもありますが、処理として実際にやることは赤線の順序です。以下順番に説明します

  • 前のstepのFeatureTransformerのn_a次元の部分を入力とする
  • 入力の次元がn_a次元だが、最終的に欲しいのはB*Dのマスクなので、FCでD次元にmapping
  • BNでバッチ正規化
  • Prior Scalesという、前のstepのAttentiveTransformerの出力(実質Mask)から計算したB*Dの行列とアダマール積を取る
  • Sparsemaxでスパースな行列にする
sparsemax.py
import torch
import torch.nn.functional as F
from pytorch_tabnet.sparsemax import Sparsemax, Entmax15

x = torch.Tensor([0, 0.1, 0.2, 0.5, 0.9])

print(x)
# >>> tensor([0.0000, 0.1000, 0.2000, 0.5000, 0.9000])

print(F.softmax(x))
# >>> tensor([0.1345, 0.1486, 0.1643, 0.2218, 0.3308])

print(Entmax15()(x))
# >>> tensor([0.0607, 0.0879, 0.1200, 0.2464, 0.4850])

print(Sparsemax()(x))
# >>> tensor([0.0000, 0.0000, 0.0000, 0.3000, 0.7000])

Sparsemaxについては【解説+実装】Sparsemax関数を理解する に詳しいです。Softmax関数を変形したもので、Softmax関数はすべてのラベルに対して確率を与えますが、Sparsemaxは一定のラベルに対しては確率が0となります。スパースな方が重みパラメータの数が減って嬉しいので、sparsemaxを使っている、ということになります。

これでAttentiveTransformerの説明は以上です。n_a次元の入力があった上で、前のstepの情報を使いながらmaskの値をつくって、スパースにしてあげて出力するんだな、ということが理解できれば十分だと思います。

FeatureTransformer

8aefec6e-02d2-4ad8-a55b-e76b90f56db9.png

FeatureTransformerでは、FC + BN(torch-tabnetでの実装はGBNだが詳細は割愛 詳細は全体セミナー20170629 ) + GLUがセットになっているのと、skip connectionがある形を何層も組み合わせて作っています。FeatureTransformerの目的は、バッチ正規化された入力を線形変換してn_d + n_a次元にmappingしてあげることです。 ここではGLUについての説明と、shared across decision stepsとDecision step dependentの違いってなんだ、という話をします。

GLUについては論文解説 Convolutional Sequence to Sequence Learning (ConvS2S) - ディープラーニングブログに詳しいです。pytorchでの実装は以下です。

torch.mul(x[:, : self.output_dim], torch.sigmoid(x[:, self.output_dim :]))

入力について詳細は省きますが、BNが終わった段階で 2 * output_dimの次元になっているとします。この入力の片方にはsigmoidをかけ、もう片方はそのまま使った上でアダマール積を計算します。お気持ちとしてはsigmoidで重み作って、情報を取捨選択している感じです。maskした上でさらに重みかけるんか、という感じですがこういう実装になっているので受け入れましょう。

shared across decision stepsとDecision step dependentの違いについてですが、TabNetではパラメータを減らすのとロバスト性の向上のため、どのstepでも共有して持つ部分(shared across decision steps)と各step固有の部分(Decision step dependent)を使ってFeatureTransformerを構成しています。バッチによってはサンプルに偏りがある可能性がありますし、sharedを作っとく方がロバスト性もそうですし、予測精度もそっちの方がよくなりそうな気がします(感想です)

以上でFeatureTransformerの説明も終わりです。ここまででEncoderは終わりなので、Decoderの話に移ります。

decoder

d84514f1-72ea-4c80-ba63-6843700da1c0.png
decoderの話に移るといったものの、decoderでやることはEncoderに比べ単純です。

Encoderで各stepごとにn_d次元の出力があるのですが、各stepの出力に対して、n_d次元をn_d次元にmappingするFeatureTransformerと、FCでn_d次元をD次元にmappingしてあげて、各stepの出力を足せば、B * D次元の出力が得られ、これはまさにinputを再構成した出力となっています。事前学習時、再構成したテーブルは元のテーブルと損失を計算する際、平均二乗誤差を用いているようです(カテゴリ変数はembeddingされていて、テーブルには連続値しかないことを前提)

最後に、特徴量重要度の出し方です。

特徴量重要度

local interpretablityについては、基本的には各stepでのmask、各stepでのmaskを足し合わせたものを表示すれば良いです。
04e1bd4d-116a-46aa-8f8d-22d814a65802.png
上図がmaskの可視化例です。縦軸がサンプル横軸が列になっています。Syn2はとても単純で、どのインスタンスでもstepが同じなら同じ特徴量に注目しています。

Syn4では、同じstepのmaskでもインスタンスによって注目している特徴量が異なることが分かります。

global interpretablityが必要な場合は、インスタンスごとにではなく、列方向にmaskの値を合計してあげて求めてあげれば良いです。タイタニックでの例は以下です。feature importanceもモデルの構造上出せるのがTabNetの良い点ですね。
b7ebdefb-5d1e-4eab-9b3f-dca4987319dd.png

実装

実装はとても簡単にできて、sk-learnライクにfit, predictメソッドを使って実装できます。実装については公式docと既に公開されている記事に詳しいので、これらを参照していただければと思います。
公式doc example: https://github.com/dreamquark-ai/tabnet/blob/develop/pretraining_example.ipynb
zennの記事に貼られているcode: https://www.kaggle.com/code/sinchir0/selfsupervisedtabnet-titanic-comparing-lgbm-nn/notebook

おわりに

今回はTabNetについて詳しくみてみました。TabNetのアーキテクチャはだいぶ分かったのですが、理解が不十分な部分もあります。コンペ等で使いながら理解を深めていきたいと思います。

12
6
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
6