- 製造業出身のデータサイエンティストがお送りする記事
- 今回はTabNetを使ってみました。
TabNetとは
TabNetは、ツリーベースモデルとディープニューラルネットワークの利点を持ち合わせた高パフォーマンスなモデルだそうです。
細かい部分は論文を参照して頂けますと幸いです。
TabNetの実装
今回もUCI Machine Learning Repositoryで公開されているボストン住宅の価格データを用いて予測モデルを構築します。
# ライブラリーのインポート
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import StratifiedKFold
from pytorch_tabnet.tab_model import TabNetRegressor
import os
import random
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
# ボストンの住宅価格データ
from sklearn.datasets import load_boston
# 前処理
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
# 評価指標
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
def seed_everything(seed_value):
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
os.environ["PYTHONHASHSEED"] = str(seed_value)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(10)
# データセットの読込み
boston = load_boston()
# 説明変数の格納
df = pd.DataFrame(boston.data, columns = boston.feature_names)
# 目的変数の追加
df['MEDV'] = boston.target
# データの中身を確認
df.head()
次にデータセットを分割します(train, valid, test)。
# ランダムシード値
RANDOM_STATE = 10
# 学習データと評価データの割合
TEST_SIZE = 0.2
# 学習データと評価データを作成
x_train, x_test, y_train, y_test = train_test_split(
df.iloc[:, 0 : df.shape[1] - 1],
df.iloc[:, df.shape[1] - 1],
test_size=TEST_SIZE,
random_state=RANDOM_STATE,
)
# trainのデータセットの2割をモデル学習時のバリデーションデータとして利用する
x_train, x_valid, y_train, y_valid = train_test_split(
x_train, y_train, test_size=TEST_SIZE, random_state=RANDOM_STATE
)
次にパラメータをセットします。詳細はRepositoryのReadmeに記載されております。
# モデルのパラメータ
tabnet_params = dict(
n_d=15,
n_a=15,
n_steps=8,
gamma=0.2,
seed=10,
lambda_sparse=1e-3,
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2, weight_decay=1e-5),
mask_type="entmax",
scheduler_params=dict(
max_lr=0.05,
steps_per_epoch=int(x_train.shape[0] / 256),
epochs=200,
is_batch_level=True,
),
verbose=5,
)
まだ、各パラメータがモデルにどのように影響するのか把握できておりませんので、今後更に使い込んでみようと思います。
次にモデルの学習を行います。
# model
model = TabNetRegressor(**tabnet_params)
model.fit(
X_train=x_train.values,
y_train=y_train.values.reshape(-1, 1),
eval_set=[(x_valid.values, y_valid.values.reshape(-1, 1))],
eval_metric=["mae"],
max_epochs=200,
patience=30,
batch_size=256,
virtual_batch_size=128,
num_workers=2,
drop_last=False,
loss_fn=torch.nn.functional.l1_loss,
)
TabNetでは、変数重要度も算出できます。細かい算出ロジックはまだ理解できておりません。
# Feature Importance
feat_imp = pd.DataFrame(model.feature_importances_, index=boston.feature_names)
feature_importance = feat_imp.copy()
feature_importance["imp_mean"] = feature_importance.mean(axis=1)
feature_importance = feature_importance.sort_values("imp_mean")
plt.tick_params(labelsize=18)
plt.barh(feature_importance.index.values, feature_importance["imp_mean"])
plt.title("feature_importance", fontsize=18)
またTabNetでは、マスクという横軸を使用した特徴量、縦軸にデータを表し、重要な特徴量を濃淡で表す機能もあります。
# Mask(Local interpretability)
explain_matrix, masks = model.explain(x_test.values)
fig, axs = plt.subplots(1, 3, figsize=(10, 7))
for i in range(3):
axs[i].imshow(masks[i][:25])
axs[i].set_title(f"mask {i}")
最後に予測を行います。
# TabNet推論
y_pred = model.predict(x_test.values)
# 評価
def calculate_scores(true, pred):
"""全ての評価指標を計算する
Parameters
----------
true (np.array) : 実測値
pred (np.array) : 予測値
Returns
-------
scores (pd.DataFrame) : 各評価指標を纏めた結果
"""
scores = {}
scores = pd.DataFrame(
{
"R2": r2_score(true, pred),
"MAE": mean_absolute_error(true, pred),
"MSE": mean_squared_error(true, pred),
"RMSE": np.sqrt(mean_squared_error(true, pred)),
},
index=["scores"],
)
return scores
scores = calculate_scores(y_test, y_pred)
print(scores)
出力結果は下記のようになります。
R2 MAE MSE RMSE
scores 0.90156 2.466226 10.294959 3.208576
##さいごに
最後まで読んで頂き、ありがとうございました。
今回はTabNetを使ってみました。
訂正要望がありましたら、ご連絡頂けますと幸いです。