0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

表形式データの分類(TabNet)

Posted at

はじめに

表形式データの分類モデルを構築します。
コードは下記にあります。

概要

  • adultデータセットを用いて収入を予測するモデルを構築します。
  • モデルはTabNetを使用します。

TabNet

TabNetは表形式データのために提案されたニューラルネットワークです。
下記の提案論文の図に示されている通り、モデルの主な特徴は2点あります。

  1. Attentive Transformerという機構を利用して、入力の少数の特徴を選択して予測
  2. 上記を繰り返し、各予測を統合して全体の予測結果とする。ただし、2回目以降は、前回の予測も入力に加える。

tabnet.png

実装

1. ライブラリのインポート

TabNetはpytorch_tabularの実装を使用します。

!pip install ucimlrepo
from ucimlrepo import fetch_ucirepo

!pip install pytorch_tabular

import sys
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import collections
from sklearn.model_selection import train_test_split
from sklearn import metrics

import torch
import torch.nn as nn
import torch.nn.functional as F

2. 実行環境の確認

使用するライブラリのバージョンや、GPU環境を確認します。

print('Python:', sys.version)
print('PyTorch:', torch.__version__)
!nvidia-smi
実行結果
Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
PyTorch: 2.1.0+cu121
Sat Jan 20 05:20:32 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

3. データセットの用意

adultデータセットをダウンロードして、学習に使用できる形式に整形します。

adult = fetch_ucirepo(id=2)

X = adult.data.features
y = adult.data.targets['income']

y = y.replace({'<=50K.': 0, '<=50K':0, '>50K.': 1, '>50K': 1})

# カテゴリ変数の特定
categorical = X.columns[X.dtypes == 'object'].tolist()
continuous = X.columns[X.dtypes != 'object'].tolist()


# 教師データとテストデータにランダムに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

print(X_train.shape, X_test.shape)
print(collections.Counter(y_train), collections.Counter(y_test))

X_train['income'] = y_train
X_test['income'] = y_test

4. ニューラルネットワークの定義

学習などの設定は、configに引数として渡し、TabularModelを作成します。

from pytorch_tabular import TabularModel
from pytorch_tabular.models import TabNetModelConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig

data_config = DataConfig(
    target=['income'],
    continuous_cols=continuous,
    categorical_cols=categorical,
)
trainer_config = TrainerConfig(
    auto_lr_find=False,
    batch_size=128,
    max_epochs=100,
)
optimizer_config = OptimizerConfig(
    optimizer_params = {'weight_decay':1e-4}
)
model_config = TabNetModelConfig(
    task="classification",
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)

5. 学習

ニューラルネットワークの学習を行います。

tabular_model.fit(train=X_train)

6. 学習結果の表示

テストデータの損失と精度を評価します。

res = tabular_model.evaluate(X_test)
実行結果
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.8428308367729187     │
│         test_loss         │    0.3377942144870758     │
└───────────────────────────┴───────────────────────────┘

おわりに

今回の結果

今回の設定では、テスト精度は84%程度となりました。
全結合ニューラルネットワークよりも少し低い結果となっています。
ただし、データセットやハイパーパラメータによっても性能は異なるため、もう少し検証は必要です。

次にやること

他の表形式データのために提案されたニューラルネットワークも試してみようと思います。

参考資料

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?