はじめに
本記事では、2024/10に公開されたTabMという、テーブルデータに適したDeepLearningの学習手法をPytorch Lightningで学習・予測させる流れを記載します。
TabM
TabMそのものの詳細な解説はここでは行いませんが、概要としては、GBDTで行っているようなアンサンブルをNNで実施することで精度を高める手法で、最終的に複数のモデルが生成されます。
TabMはDNNの特定の構造を指すものではなく、パラメータ探索の手法ということになります。
TabMはテーブルデータに適しているということになっていますが、元々時系列データに対する優位性は検証されていません。
ですが、TabRedという時系列要素のあるベンチマークでTabMがベストモデルになっていたり、Jane Streetなどの時系列データのKaggleコンペにTabMモデルが使われているそうだ、と筆者がXで言及しています。
In our paper, we did not evaluate TabM on time series, though two things come to mind:
— Yura Gorishniy (@YuraFiveTwo) January 9, 2025
- TabM was the best model on TabReD (a benchmark with time-based data splits)
- I know that people use TabM in the ongoing Kaggle competition by Jane Street, where the data is time series
Pytorch Lightning
Pytorch Lightningの詳細な解説もここでは行いませんが、PytorchでNNモデルを組むためのフレームワークで、LightningModule, DataModule, Trainerという3つのクラスを定義することでモデルの学習・予測を行うことができます。
Pytorch自体は非常に自由度が高い一方、自由度が高すぎてイチから実装すると時間がかかるため、多くのユーザは独自のWrapperを書いたりするのですが、共有や再利用コストが高くなってしまうという面があります。
Pytorch Lightningは初めからクラス構成やメソッド名が決められた中でコーディングをすることを強いるため、可読性・再現性が非常に高くなります。
TabM with Pytorch Ligntning
TabMは新しいモデルですがPytorch Lightningの表現力の中に十分収まります。
ただし、TabMのアーキテクチャの特性により、少し工夫が必要なので、これからTabMを使う方がスムーズに使い始められるよう、今回、Pytorch LightningによるTabM学習・予測の記事を執筆することにしました。
Pytorch Lightningによる実装
準備
必要な資材をcloneやinstallします。
Google colabだと下記のような形になります。
!git clone https://github.com/yandex-research/tabm
!cp /content/tabm/tabm_reference.py .
!pip install rtdl_num_embeddings
Pytorch Lightningクラス定義
DataModule
DataModuleは下記の通りになります。
注意すべき点は、CustomDatasetがcont(連続値)とcat(カテゴリ値)を分けて保持し、get_itemでもそれぞれ出力している点です。
class CustomDataset(Dataset):
def __init__(self, df, accelerator):
self.cont = torch.FloatTensor(df[col_feature_cont].values).to(accelerator)
self.cat = torch.LongTensor(df[col_feature_cat].values).to(accelerator)
self.labels = torch.FloatTensor(df[col_target].values).to(accelerator)
self.weights = torch.FloatTensor(df[col_weight].values).to(accelerator)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
x_cont = self.cont[idx]
x_cat = self.cat[idx]
y = self.labels[idx]
w = self.weights[idx]
return x_cont, x_cat, y, w, y*w
class DataModule(LightningDataModule):
def __init__(self, train_df, batch_size, valid_df=None, accelerator='cpu'):
super().__init__()
self.df = train_df
self.batch_size = batch_size
self.accelerator = accelerator
self.train_df = train_df
self.train_dataset = None
self.valid_df = None
if valid_df is not None:
self.valid_df = valid_df
self.val_dataset = None
def setup(self):
self.train_dataset = CustomDataset(self.train_df, self.accelerator)
if self.valid_df is not None:
df_valid = self.valid_df
self.val_dataset = CustomDataset(df_valid, self.accelerator)
def train_dataloader(self, n_workers=0):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=n_workers)
def val_dataloader(self, n_workers=0):
return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=n_workers)
LightningModule
LightningModuleは下記の通りになります。
ここで、kはアンサンブルするモデルの数になります。
training, validationの各ステップで、batchから返されるxもcontとcatに分かれている点にご注意ください。
class NN(LightningModule):
def __init__(self, n_cont_features, cat_cardinalities, n_classes, lr, weight_decay):
super().__init__()
self.save_hyperparameters()
self.k = 16
self.model = Model(
n_num_features=n_cont_features,
cat_cardinalities=cat_cardinalities,
n_classes=n_classes,
backbone={
'type': 'MLP',
'n_blocks': 3 ,
'd_block': 512,
'dropout': 0.25,
},
bins=None,
num_embeddings= None,
arch_type='tabm',
k=self.k,
)
self.lr = lr
self.weight_decay = weight_decay
self.training_step_outputs = []
self.validation_step_outputs = []
self.loss_fn = weighted_mse_loss
def forward(self, x_cont, x_cat):
return self.model(x_cont, x_cat).squeeze(-1)
def training_step(self, batch):
x_cont,x_cat, y, w , w_y= batch
x_cont = x_cont + torch.randn_like(x_cont) * 0.02
y_hat = self(x_cont, x_cat)
loss = self.loss_fn(y_hat.flatten(0, 1), y.repeat_interleave(self.k))
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=x_cont.size(0))
self.training_step_outputs.append((y_hat.mean(1), y, w))
return loss
def validation_step(self, batch):
x_cont,x_cat, y, w, w_y = batch
y_hat = self(x_cont, x_cat)
loss = self.loss_fn(y_hat.flatten(0, 1), y.repeat_interleave(self.k))
self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=x_cont.size(0))
self.validation_step_outputs.append((y_hat.mean(1), y, w))
return loss
def on_validation_epoch_end(self):
"""Calculate validation WRMSE at the end of the epoch."""
y = torch.cat([x[1] for x in self.validation_step_outputs]).cpu().numpy()
if self.trainer.sanity_checking:
prob = torch.cat([x[0] for x in self.validation_step_outputs]).cpu().numpy()
else:
prob = torch.cat([x[0] for x in self.validation_step_outputs]).cpu().numpy()
weights = torch.cat([x[2] for x in self.validation_step_outputs]).cpu().numpy()
# r2_val
val_r_square = r2_val(y, prob, weights)
self.log("val_r_square", val_r_square, prog_bar=True, on_step=False, on_epoch=True)
self.validation_step_outputs.clear()
def configure_optimizers(self):
optimizer = torch.optim.AdamW(make_parameter_groups(self.model), lr=self.lr, weight_decay=self.weight_decay)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5,
# verbose=True)
return {
'optimizer': optimizer,
# 'lr_scheduler': {
# 'scheduler': scheduler,
# 'monitor': 'val_r_square',
# }
}
def on_train_epoch_end(self):
if self.trainer.sanity_checking:
return
y = torch.cat([x[1] for x in self.training_step_outputs]).cpu().numpy()
prob = torch.cat([x[0] for x in self.training_step_outputs]).detach().cpu().numpy()
weights = torch.cat([x[2] for x in self.training_step_outputs]).cpu().numpy()
# r2_training
train_r_square = r2_val(y, prob, weights)
self.log("train_r_square", train_r_square, prog_bar=True, on_step=False, on_epoch=True)
self.training_step_outputs.clear()
epoch = self.trainer.current_epoch
metrics = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self.trainer.logged_metrics.items()}
formatted_metrics = {k: f"{v:.5f}" for k, v in metrics.items()}
print(f"Epoch {epoch}: {formatted_metrics}")
Trainer
pytorch_lightning.TrainerそのままでOKです。
インスタンス生成
インスタンス生成時に、n_cont_features(連続値の特徴量の数)とcat_cardinalities(カテゴリ変数のカーディナリティのリスト)を与える必要があります。
model = NN(
n_cont_features = args.n_cont_features,
cat_cardinalities = args.cat_cardinalities,
n_classes = args.n_classes,
lr=args.lr,
weight_decay=args.weight_decay
)
カテゴリ変数の処理について注意が必要で、ここで与えたカーディナリティより大きい値がデータ中に存在すると、学習時にエラーとなります。
そのため、あるカテゴリ変数が1,3,5,8,11,...のように飛び飛びの値になっている場合、0,1,2,3...と変換する必要があります。
また、予測時に未知(学習時に存在しない)カテゴリがやってくる可能性がある場合には、何かしらのカテゴリ(学習時のカテゴリ最大値+1等)に変換するなどのケアが必要です。
下記のような関数をかませるイメージです。
def encode_column(df, column, mapping):
max_value = max(mapping.values())
col_dtype = df[column].dtype
def encode_category(category):
return mapping.get(category, max_value + 1)
return df.with_columns(
pl.col(column).map_elements(encode_category, return_dtype=col_dtype).alias(column)
)
Trainerは特筆することはないので、好きなパラメータをお使いください。
# Initialize Logger
logger = None
# Initialize Callbacks
early_stopping = EarlyStopping('val_loss', patience=args.patience, mode='min', verbose=False)
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=False, filename=output_path + f"/tabm_{MODEL_ID}.model")
lr_monitor = LearningRateMonitor(logging_interval='epoch')
timer = Timer()
# Initialize Trainer
trainer = Trainer(
max_epochs=args.max_epochs,
accelerator=accelerator,
devices=[args.gpuid] if args.usegpu else None,
logger=logger,
callbacks=[lr_monitor, early_stopping, checkpoint_callback, timer],
enable_progress_bar=True
)
学習
学習は通常のPytorch Lightningモデルと同様に行えます。
# Start Training
trainer.fit(model, data_module.train_dataloader(args.loader_workers), data_module.val_dataloader(args.loader_workers))
予測
予測時にもcontとcatそれぞれ与える必要がある点にご注意ください。
予測結果はパラメータkで指定した数だけ出力されるので、meanなどで平均をとる必要があります。
pred = model(
torch.FloatTensor(df[col_feature_cont].values).to("cuda:0"),
torch.LongTensor(df[col_feature_cat].values).to("cuda:0")
).cpu().numpy().mean(1)
おわりに
TabMをPythorch Lightningで学習・予測させる際のポイントは下記の通りです。
- DataModuleやLightningModule定義時にデータを連続列(cont)とカテゴリ列(cat)に分ける
- カテゴリ列は0,1,2,...となるように変換しておく。
- 予測結果は.mean(1)などで束ねる
また、Kaggleコンペなど、インターネットが使えない環境のときの使用方法については、Kaggle上でコードを公開していますので、こちらも参考にしていただけると幸いです。
以上となります。
今後、KaggleなどでもTabMを見かけることが増えるのではないかなと思うので、この記事が誰かの役に立てば幸いです。