LoginSignup
3
5

【論文丁寧解説】最強NNと名高いTabNet解説!!

Last updated at Posted at 2024-02-20

はじめに

この記事では、最近Kaggleなどのコンペで話題のたテーブルデータ特化型DNN、最強TabNetの論文を丁寧に解説していきます、論文の内容をそのまま数式など展開させて説明すると説明が難しくなってしまうので、各技術を数式などで説明した後にわかりやすく説明するとの形で、できるだけわかりやすく説明をしていきます!!

なぜ最強なのか

ビジネスの場において、NNもしくはDNNを使えばもっと精度出るのかな…でも説明性が皆無だしなぁ…精度…説明性…精度…説明性…...Lightgbm最強!!!ってなってる人たちに朗報です!

TabNetはDNNでありながら、なんと説明性を持っているんです!そして、もちろんDNNの強みである、エンドツーエンド学習ができ、現存の勾配ブースティングを上回る精度を出してくれます!!

それだけでなく、TabNetは標準アーキテクチャとして自己教師あり学習を導入しています。もう、これが最推しの理由です!

まず、自己教師あり学習ってのは、ラベルがないデータからでも学習できちゃうっていう、もう超絶便利な学習方法なんです!普通、ディープラーニングモデルっていうのは、大量のラベル付きデータが必要で、それがないとろくなパフォーマンスが出せないことが多いんだ。でも、TabNetはここで一味違う。ラベルがなくても、データの内在するパターンや構造を学習して、その知識を後の教師あり学習フェーズで活かすんだから、まさに天才的!!これがLLMの世界ではもう革新的な技術なわけで……
一生ついていきます!!

用語前提知識

エンドツーエンド学習 (End-to-End Learning)

従来の手法では、特徴抽出、特徴選択、分類(予測)などのステップを別々に行い、各ステップで最適化を進めていました。しかしエンドツーエンド学習とは、入力データから特徴量エンジニアリングなどの前処理、目的の出力(予測結果など)を直接得るプロセスを、一つのモデルが完結して行う学習方法です。

マスク (Masking)

マスクとは、データの一部を選択的に「無視」する技術のことで、不要な情報をフィルタリングするために使用されます。TabNetでは、選択プロセスにマスキング技術を用いており、それによってモデルの注目点を動的に調整し、不要な特徴量によるノイズの影響を減らします。

スパーシティ (Sparsity)

スパーシティとは、モデルのパラメータやデータのうち、大部分がゼロ(あるいはゼロに近い値)である状態を指します。TabNetでは、スパーシティを利用して、重要な特徴量のみを活用し、それ以外は無視する(つまりゼロに近い重みを持たせる)ことで、計算効率とモデルの解釈性を向上させています。

公式リファレンス

論文

なお本記事の構成は論文の内容に沿った形展開させていきます。

概要

TabNetは、表形式データ学習のための新しい高性能かつ解釈可能な標準的なディープラーニングアーキテクチャです。
TabNetは、各決定ステップでどの特徴から推論するかを選択するために逐次的な注意を使用し、解釈可能性を可能にし、最も顕著な特徴のために学習容量が使用されるため、より効率的な学習を実現するものです。

導入

テーブルデータへのDNNを探求する価値

DNNがもたらすテーブルデータへの利点

  • 表形式データと他のデータタイプ(例:画像)を統合して効率的にエンコードする能力。
  • ツリーベースの方法で必要とされる特徴エンジニアリングの負担を軽減。
  • ストリーミングデータからの学習能力。
  • エンドツーエンドのモデルが表現学習を可能にし、データ効率の良いドメイン適応、生成モデリング、半教師あり学習など、多くの貴重なアプリケーションシナリオを実現する。

TabNetの貢献

  1. エンドツーエンド学習:
    TabNetは生の表形式データを直接入力として扱い、事前処理なしで勾配降下ベースの最適化を用いてデータの前処理に関する手間を省き、エンドツーエンドの学習プロセスを実現します。
  2. 逐次的注意による特徴選択:
    各決定ステップで特徴を選択するために逐次的な注意メカニズムを使用。
    このアプローチは、モデルが最も重要な特徴に焦点を当てることを可能にし、行ごとに異なる特徴選択を行うことで、高い解釈可能性と効率的な学習を実現します。

tabnet_1.png

3. 優れたパフォーマンスと解釈可能性:
TabNetは、多様なデータセット上で他の表形式学習モデルと比較して優れた、あるいは同等のパフォーマンスを示します。また、特徴の重要性を視覚化するローカル解釈可能性と、訓練されたモデルにおける各特徴の寄与を定量化するグローバル解釈可能性の両方を提供します。

tabnet_2.png

4. 教師なし事前学習によるパフォーマンス向上:
表形式データにおける教師なし事前学習を初めて導入し、一部のデータ(特徴)を意図的に隠し(マスクし)、モデルにその隠された部分を予測させることで顕著なパフォーマンス向上を実現します。

tabnet_3.png

TabNetによる表形式データ学習概要

スパース特徴選択

特徴選択は、次の式によって表されます。

$$
f \in \mathbb{R}^{B \times D}
$$

ここで、$f$は特徴セットを表し、$B$はバッチサイズ、$D$は特徴の次元数です。このスパース特徴選択メカニズムにより、TabNetはデータの重要な情報を効率的に抽出します。

わかりやすく説明すると、TabNetは大量の情報の中から「これが重要だ!」というデータの部分だけをピックアップし、それを使って学習することで、データセット全体の理解を深めます。この「賢い選択」により、モデルはより効率的に、かつ正確に学習することができるようになります。

シーケンシャルなマルチステップアーキテクチャ

TabNetの核となるのは、シーケンシャルなマルチステップ処理です。これは、一連の決定ステップ$N_{steps}$を通じてデータを精緻化していくプロセスです。各$i^{th}$ステップでは、前の$(i-1)^{th}$ステップからの情報を基に、どの特徴を選択し処理するかを決定します。このプロセスは次のように表されます。

$Step_i​:\text{Input from Step}_{i−1}​→\text{Feature Selection}→{Decision Contribution}$

わかりやすく説明すると、TabNetはデータを見ながら、「これは重要そうだから次も使おう」とか「これはあまり役に立たなさそうだから別の特徴を試そう」といった判断をステップごとに繰り返し精度を高めていきます。

非線形処理

選択された特徴は非線形変換を経て処理されます。この非線形性により、モデルはデータ内の複雑なパターンや関係を捉え、より表現力豊かな学習が可能になります。非線形処理は、モデルの学習容量を大幅に向上させます。

わかりやすく説明すると、一つのデータを多面的にみてパターンや関係性を見つけていきます。

アンサンブルの模倣

TabNetは、その設計を通じてアンサンブル学習の利点を模倣します。これは、複数のステップと高次元特徴を通じて、複数のモデルの予測を組み合わせるアンサンブル手法の効果を一つのモデル内で実現することです。このアプローチは、計算効率を保ちながら、予測の精度を向上させます。

わかりやすく説明すると、いくつかの小さなステップを経てデータを分析することで、複数のモデルを使う「アンサンブル学習」と同じような効果を1つのモデルで実現します。

アーキテクチャ概要

TabNetのアーキテクチャは、生の数値特徴と、訓練可能なエンベディングを使用してマッピングされたカテゴリカル特徴を組み合わせています。モデルは、バッチ正規化を適用し、全ての特徴を同じ次元で処理します。このシーケンシャルな処理は、一連の決定ステップ$N_{steps}$を通じて行われ、各ステップは前のステップからの情報を基に特徴を選択し、最終的な決定に寄与します。

わかりやすく説明すると、数字で表される特徴と、カテゴリーを数字に変換した特徴を一緒に一定の処理を経てすべて同じ形式に整えられます。そして、TabNetはステップごとにデータを少しずつ分析し、どの情報が重要かを選びながら、最終的な判断を下していきます。

tabnet_4.png

TabNetによる特徴量選択

マスキングとスパーシティ

特徴選択のためのマスキングは乗算的に行われ、具体的な計算式は$M[i]\times f$です。マスク $M[i]$は、以下の式に従って計算されます。

$$a[i−1]:M[i]=sparsemax(P[i−1]⋅hi​(a[i−1]))$$

ここで、$a[i−1]$ は前のステップの出力、$h_i$は訓練可能な関数(全結合層とバッチ正規化を含む)、$P[i−1]$ は特定の特徴が以前にどれだけ使用されたかを示す事前スケール項です。Sparsemax正規化は、スパーシティを促進し、解釈可能な特徴選択を可能にするために使用されます。

わかりやすく説明すると、特別な方法でフィルター(マスク)を使います。このフィルターは、前のステップの結果と特別な計算を組み合わせて作られ、どの特徴が次に重要かを選び出します。要するに、TabNetは一連のステップを通じて、最も重要な情報に集中するようにデータを絞り込んでいきます。

事前スケール項とスパーシティ正則化

事前スケール項 $P[i]$は次のように計算されます。

$$
P[i]=∏_{j=1}^i​(γ−M[j])
$$

ここで、$\gamma$は緩和パラメータであり、特徴が複数の決定ステップで使用される柔軟性を制御します。また、選択された特徴のスパーシティをさらに制御するために、スパーシティ正則化を導入しています。

$$
L_{sparse}​=∑_{i=1}^{N_{steps}}​​∑_{b=1}^B​∑_{j=1}^D​−M_{b,j}​[i]log(M_{b,j}​[i]+ϵ)/(N_{steps}​⋅B)
$$

ここで、$ϵ$ は数値的安定性のための小さな定数です。この正則化項は、全体の損失関数に$λsparse​$ の係数で加えられます。

わかりやすく説明すると、TabNetは重要な特徴を適切に選びながら、データの扱いをシンプルに保つための工夫(スパーシティ(特徴の選択が偏らないようにする))をしています。これはモデルが情報を効率的に扱い、過剰に複雑にならないように助けます。

tabnet_1.png

特徴量処理のプロセス

特徴量処理の核心は、特徴トランスフォーマーによる処理です。この処理は次の式に従って行われます。

$$
[d[i],a[i]]=fi​(M[i]⋅f)
$$

ここで、$d[i]∈ \mathbb{R}^{B×N_d}​$は決定ステップの出力、$a[i]∈ \mathbb{R}^{B×N_a}​$は次のステップのための情報を表します。$M[i]⋅f$ は、選択された特徴がマスキングされた後の入力特徴です。

特徴トランスフォーマーの構成

特徴トランスフォーマーは、高い学習容量とパラメータ効率の良い、堅牢な学習を実現するために設計されています。そのために、全ての決定ステップで共有される層と、決定ステップ依存の層が含まれます。具体的には、2つの共有層と2つの決定ステップ依存層が実装され、各全結合層(FC層)の後にはバッチ正規化(BN)とゲーテッドリニアユニット(GLU)非線形性が続きます。これらは最終的に正規化された残差接続によって接続されます。

わかりやすく説明すると、データを学習するために、特定の部品を使ってスマートに処理します。これには、全ステップ共通の部品と、各ステップごとに特別な部品があり、データを整えたり、重要な点を見つけたりします。

学習の安定化と高速化

学習プロセスの安定化のために、特徴トランスフォーマーは$√5$で正規化されます。また、大きなバッチサイズを使用して高速な訓練を実現します。ゴーストBNは、仮想バッチサイズ$B_V$とモーメントム$m_B$ を使用して、特定の層で適用されます。ただし、入力特徴に対しては、ゴーストBNを避けることで低分散の平均化の利点が観察されます。

わかりやすく説明すると、学習をスムーズに行うために特別な調整をします。計算を安定させるために特別な値でデータを整えたり、学習を早く進めるためにたくさんのデータを一度に扱ったりします。また、入力されるデータにはこの方法を使わないで、より良い結果が得られるようにします。

決定エンベディングの構築

最終的な決定エンベディングは、以下のように構築されます。

$$
d_{out}=∑_{i=1}^{N_{steps}}ReLU(d[i])
$$

これにより、決定木のような集約に触発された全体的な決定エンベディングが形成されます。最終的な出力マッピングは、$d_{out}$に対して線形マッピング$W_{final}​d_{out}​$ を適用することで得られます。

わかりやすく説明すると、最後の結果を出すために、すべてのステップからの情報を合わせて一つの大きなデータ(決定エンベディング)にします。そして、その大きなデータに最後の計算を一度行うことで、最終的な答えを出します。

解釈可能性

決定ステップの重要度の計算

決定ステップの重要度を計算するために、以下の式を使用します。

$$
η_b​[i]=∑_{c=1}^{N_{d}}​​ReLU(d_{b,c}​[i])
$$

これは、$i$番目の決定ステップでの$b$番目のサンプルにおける集約された決定寄与を示します。直感的には、もし$d_{b,c}​[i]<0$ であれば、そのステップの特徴は全体的な決定に対して負の寄与をしており、その値が増加するにつれて、全体的な決定においてより大きな役割を果たすことになります。

わかりやすく説明すると、各ステップで、あるデータが決定にどれだけ貢献しているかを示す数値を計算します。この数値は、そのステップの情報がプラスの影響を与えているか、マイナスの影響を与えているかを示します。値が大きければ大きいほど、その情報は決定により大きなプラスの影響を与えると考えられます。

集約された特徴重要度マスクの計算

集約された特徴重要度マスクは、以下の式によって計算されます。

$$
M_{agg−b,j​}=\frac{∑_{i=1}^{N_{steps}}​η_b​[i]M_b​[j][i]}{∑_{j=1}^D​∑_{i=1}^{N_{steps}}​​η_b​[i]M_b​[j][i]}​
$$

この式により、各特徴が全体的な決定にどの程度寄与しているかの相対的な重要度が定量化されます。この計算によって、TabNetは高いレベルの解釈可能性を提供し、特徴が決定プロセスにどのように影響を与えているかを理解するのに役立ちます。

わかりやすく説明すると、特徴の全体的な重要度を計算する式を使って、TabNetはどの情報が最終的な判断にどれだけ影響しているかを示します。この方法により、どのデータが重要かがわかりやすくなり、TabNetの判断過程を理解しやすくなります。

tabnet_2.png

テーブルデータに対する自己教師あり学習

自己教師あり学習アプローチにより、TabNetはラベルのない大量のデータから深い表現を学習することができます。特に、特徴間の相互作用や依存関係を捉え、データの理解を深めることが可能になります。

自己教師あり学習の枠組み

TabNetの自己教師あり学習フェーズでは、特定の特徴列から欠損している特徴列を予測するタスクを通じて、モデルがデータの内在的な構造を学習します。このプロセスは以下のステップで構成されます:

  1. バイナリマスクの導入:
    バイナリマスク$S \in {\lbrace0,1 \rbrace}^{B×D}$を使用して、入力特徴のうちどれが欠損しているか(マスクされているか)を示します。ここで、$B$はバッチサイズ、$D$は特徴の次元です。
  2. エンコーダーとデコーダーの動作:
    エンコーダーはマスクされていない特徴 $(1−S)⋅\hat{f}$ を入力として受け取り、デコーダーは再構築された特徴$S⋅\hat{f}$ を出力します。ここで、$\hat{f}$ は元の特徴セットを表します。
  3. 再構築損失の計算:
    自己教師ありフェーズでは、以下の再構築損失を考慮します。
\frac{\sum_{b=1}^{B}\sum_{j=1}^{D}\sqrt{(\hat{f}_{b,j}-f_{b,j})\cdot S_{b,j}}}{\sum_{b=1}^{B}(f_{b,j}-\frac{1}{B}\sum_{b=1}^{B}f_{b,j})^2}

この式により、予測された特徴と実際の特徴の差異に基づいて損失が計算され、特徴の異なる範囲を考慮して正規化されます。

わかりやすく説明すると

  1. バイナリマスクを使う:
    どのデータ情報が欠けているか(隠されているか)を示すために、0と1のマスクを使います。
  2. データの処理:
    モデルは、隠されていない情報を使って学習し、隠されたデータを推測(再構築)します。
  3. どれだけうまく推測できたかを評価:
    推測したデータと本当のデータの違いを計算して、モデルがどれだけ上手にデータを再構築できたかを評価します。

バイナリマスクのサンプリング:

各イテレーションで、バイナリマスク$S_{b,j}$のサンプリングにベルヌーイ分布を使用する、つまり各特徴$j$について、各サンプル$b$でその特徴が「使用される(1)」か「使用されない(0)」かをランダムに決定することを意味します。ここでのパラメータ$p_s$は、特徴が使用される(つまり、マスクが1である)確率を指定します。

わかりやすく説明すると、TabNetはランダムに決めた特定のデータ情報を使うかどうかを選びます。これは、コインを投げて表が出たらその情報を使い、裏が出たら使わないのと似ています。この「コイン投げ」の確率を決めるのがパラメータ$p_s$となります。

tabnet_3.png

実世界データセットでの性能

森林被覆タイプ(Dua and Graff 2017):

このタスクは、地図変数からの森林被覆タイプの分類です
tabnet_5.png

ポーカーハンド(Dua and Graff 2017):

このタスクは、カードのスーツとランクの属性からポーカーハンドを分類することです。入力と出力の関係は決定論的であり、手作りのルールで100%の正確さを達成することができます。それにもかかわらず、従来のDNN、DT、さらにはディープニューラルDTのハイブリッドモデルも、不均衡なデータから大きな影響を受けて精度が出せない中TabNetは、深さによる高度に非線形な処理を行うことができ、インスタンスごとの特徴選択のおかげで過学習せずに、他の方法よりも優れています。

tabnet_6.png

終わりに

TabNetのアルゴリズム解説は以上になります!どうだったでしょうか!!
僕の最推しモデルTabNetの凄さが少しでも皆さんに伝わったでしょうか!

今回はTabNetの論文を徹底解説と言うことで説明してきました。論文解説という内容なだけに、わかりにくくなってしまった部分もあるかと思います、
そしてTabNetの凄さはわかったけどどう実装すればいいの?ってなる方もいると思います!

僕もそうです!!なので次回は下記のGithubのリポジトリを解説します。
最後のPartでKaggleのTitanicもしくはもう少し大規模データでの実装編も記事にします。

このTabNet解説編はあと2つの合計3つで展開していくので、よかったら最後までご覧ください!

参照

公式論文:https://arxiv.org/abs/1908.07442
「TabNetはどう使えるのか」:https://qiita.com/ps010/items/ea83eea63162f6105641

3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5