0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Pytorch LightningでTabM学習・予測

Posted at

はじめに

本記事では、2024/10に公開されたTabMという、テーブルデータに適したDeepLearningの学習手法をPytorch Lightningで学習・予測させる流れを記載します。

TabM

TabMそのものの詳細な解説はここでは行いませんが、概要としては、GBDTで行っているようなアンサンブルをNNで実施することで精度を高める手法で、最終的に複数のモデルが生成されます。
TabMはDNNの特定の構造を指すものではなく、パラメータ探索の手法ということになります。
TabMはテーブルデータに適しているということになっていますが、元々時系列データに対する優位性は検証されていません。
ですが、TabRedという時系列要素のあるベンチマークでTabMがベストモデルになっていたり、Jane Streetなどの時系列データのKaggleコンペにTabMモデルが使われているそうだ、と筆者がXで言及しています。

Pytorch Lightning

Pytorch Lightningの詳細な解説もここでは行いませんが、PytorchでNNモデルを組むためのフレームワークで、LightningModule, DataModule, Trainerという3つのクラスを定義することでモデルの学習・予測を行うことができます。
Pytorch自体は非常に自由度が高い一方、自由度が高すぎてイチから実装すると時間がかかるため、多くのユーザは独自のWrapperを書いたりするのですが、共有や再利用コストが高くなってしまうという面があります。
Pytorch Lightningは初めからクラス構成やメソッド名が決められた中でコーディングをすることを強いるため、可読性・再現性が非常に高くなります。

TabM with Pytorch Ligntning

TabMは新しいモデルですがPytorch Lightningの表現力の中に十分収まります。
ただし、TabMのアーキテクチャの特性により、少し工夫が必要なので、これからTabMを使う方がスムーズに使い始められるよう、今回、Pytorch LightningによるTabM学習・予測の記事を執筆することにしました。

Pytorch Lightningによる実装

準備

必要な資材をcloneやinstallします。
Google colabだと下記のような形になります。

!git clone https://github.com/yandex-research/tabm
!cp /content/tabm/tabm_reference.py .
!pip install rtdl_num_embeddings

Pytorch Lightningクラス定義

DataModule

DataModuleは下記の通りになります。
注意すべき点は、CustomDatasetがcont(連続値)とcat(カテゴリ値)を分けて保持し、get_itemでもそれぞれ出力している点です。

class CustomDataset(Dataset):
    def __init__(self, df, accelerator):
        self.cont = torch.FloatTensor(df[col_feature_cont].values).to(accelerator)
        self.cat = torch.LongTensor(df[col_feature_cat].values).to(accelerator)
        self.labels = torch.FloatTensor(df[col_target].values).to(accelerator)
        self.weights = torch.FloatTensor(df[col_weight].values).to(accelerator)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x_cont = self.cont[idx]
        x_cat = self.cat[idx]
        y = self.labels[idx]
        w = self.weights[idx]
        return x_cont, x_cat, y, w, y*w


class DataModule(LightningDataModule):
    def __init__(self, train_df, batch_size, valid_df=None, accelerator='cpu'):
        super().__init__()
        self.df = train_df
        self.batch_size = batch_size
        self.accelerator = accelerator
        self.train_df = train_df
        self.train_dataset = None
        self.valid_df = None
        if valid_df is not None:
            self.valid_df = valid_df
        self.val_dataset = None

    def setup(self):
        self.train_dataset = CustomDataset(self.train_df, self.accelerator)
        if self.valid_df is not None:
            df_valid = self.valid_df
            self.val_dataset = CustomDataset(df_valid, self.accelerator)

    def train_dataloader(self, n_workers=0):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=n_workers)

    def val_dataloader(self, n_workers=0):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=n_workers)


LightningModule

LightningModuleは下記の通りになります。
ここで、kはアンサンブルするモデルの数になります。
training, validationの各ステップで、batchから返されるxもcontとcatに分かれている点にご注意ください。

class NN(LightningModule):
    def __init__(self, n_cont_features, cat_cardinalities, n_classes, lr, weight_decay):
        super().__init__()
        self.save_hyperparameters()
        self.k = 16
        self.model = Model(
                n_num_features=n_cont_features,
                cat_cardinalities=cat_cardinalities,
                n_classes=n_classes,
                backbone={
                    'type': 'MLP',
                    'n_blocks': 3 ,
                    'd_block': 512,
                    'dropout': 0.25,
                },
                bins=None,
                num_embeddings= None,
                arch_type='tabm',
                k=self.k,
            )
        self.lr = lr
        self.weight_decay = weight_decay
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.loss_fn = weighted_mse_loss

    def forward(self, x_cont, x_cat):
        return self.model(x_cont, x_cat).squeeze(-1)

    def training_step(self, batch):
        x_cont,x_cat, y, w , w_y= batch
        x_cont = x_cont + torch.randn_like(x_cont) * 0.02
        y_hat = self(x_cont, x_cat)

        loss = self.loss_fn(y_hat.flatten(0, 1), y.repeat_interleave(self.k))
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=x_cont.size(0))
        self.training_step_outputs.append((y_hat.mean(1), y, w))
        return loss

    def validation_step(self, batch):
        x_cont,x_cat, y, w, w_y = batch
        y_hat = self(x_cont, x_cat)
        loss = self.loss_fn(y_hat.flatten(0, 1), y.repeat_interleave(self.k))
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, batch_size=x_cont.size(0))
        self.validation_step_outputs.append((y_hat.mean(1), y, w))
        return loss

    def on_validation_epoch_end(self):
        """Calculate validation WRMSE at the end of the epoch."""
        y = torch.cat([x[1] for x in self.validation_step_outputs]).cpu().numpy()
        if self.trainer.sanity_checking:
            prob = torch.cat([x[0] for x in self.validation_step_outputs]).cpu().numpy()
        else:
            prob = torch.cat([x[0] for x in self.validation_step_outputs]).cpu().numpy()
            weights = torch.cat([x[2] for x in self.validation_step_outputs]).cpu().numpy()
            # r2_val
            val_r_square = r2_val(y, prob, weights)
            self.log("val_r_square", val_r_square, prog_bar=True, on_step=False, on_epoch=True)
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(make_parameter_groups(self.model), lr=self.lr, weight_decay=self.weight_decay)
        # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5,
        #                                                        verbose=True)
        return {
            'optimizer': optimizer,
            # 'lr_scheduler': {
            #     'scheduler': scheduler,
            #     'monitor': 'val_r_square',
            # }
        }

    def on_train_epoch_end(self):
        if self.trainer.sanity_checking:
            return

        y = torch.cat([x[1] for x in self.training_step_outputs]).cpu().numpy()
        prob = torch.cat([x[0] for x in self.training_step_outputs]).detach().cpu().numpy()
        weights = torch.cat([x[2] for x in self.training_step_outputs]).cpu().numpy()
        # r2_training
        train_r_square = r2_val(y, prob, weights)
        self.log("train_r_square", train_r_square, prog_bar=True, on_step=False, on_epoch=True)
        self.training_step_outputs.clear()

        epoch = self.trainer.current_epoch
        metrics = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in self.trainer.logged_metrics.items()}
        formatted_metrics = {k: f"{v:.5f}" for k, v in metrics.items()}
        print(f"Epoch {epoch}: {formatted_metrics}")

Trainer

pytorch_lightning.TrainerそのままでOKです。

インスタンス生成

インスタンス生成時に、n_cont_features(連続値の特徴量の数)とcat_cardinalities(カテゴリ変数のカーディナリティのリスト)を与える必要があります。

model = NN(
    n_cont_features = args.n_cont_features,
    cat_cardinalities = args.cat_cardinalities,
    n_classes = args.n_classes,
    lr=args.lr,
    weight_decay=args.weight_decay
)

カテゴリ変数の処理について注意が必要で、ここで与えたカーディナリティより大きい値がデータ中に存在すると、学習時にエラーとなります。
そのため、あるカテゴリ変数が1,3,5,8,11,...のように飛び飛びの値になっている場合、0,1,2,3...と変換する必要があります。
また、予測時に未知(学習時に存在しない)カテゴリがやってくる可能性がある場合には、何かしらのカテゴリ(学習時のカテゴリ最大値+1等)に変換するなどのケアが必要です。
下記のような関数をかませるイメージです。

def encode_column(df, column, mapping):
    max_value = max(mapping.values())  
    col_dtype = df[column].dtype

    def encode_category(category):
        return mapping.get(category, max_value + 1)  
    
    return df.with_columns(
        pl.col(column).map_elements(encode_category, return_dtype=col_dtype).alias(column)
    )

Trainerは特筆することはないので、好きなパラメータをお使いください。

# Initialize Logger
logger = None
# Initialize Callbacks
early_stopping = EarlyStopping('val_loss', patience=args.patience, mode='min', verbose=False)
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=False, filename=output_path + f"/tabm_{MODEL_ID}.model")
lr_monitor = LearningRateMonitor(logging_interval='epoch')
timer = Timer()
# Initialize Trainer
trainer = Trainer(
    max_epochs=args.max_epochs,
    accelerator=accelerator,
    devices=[args.gpuid] if args.usegpu else None,
    logger=logger,
    callbacks=[lr_monitor, early_stopping, checkpoint_callback, timer],
    enable_progress_bar=True
)

学習

学習は通常のPytorch Lightningモデルと同様に行えます。

# Start Training
trainer.fit(model, data_module.train_dataloader(args.loader_workers), data_module.val_dataloader(args.loader_workers))

予測

予測時にもcontとcatそれぞれ与える必要がある点にご注意ください。
予測結果はパラメータkで指定した数だけ出力されるので、meanなどで平均をとる必要があります。

pred = model(
    torch.FloatTensor(df[col_feature_cont].values).to("cuda:0"),
    torch.LongTensor(df[col_feature_cat].values).to("cuda:0")
    ).cpu().numpy().mean(1)

おわりに

TabMをPythorch Lightningで学習・予測させる際のポイントは下記の通りです。

  • DataModuleやLightningModule定義時にデータを連続列(cont)とカテゴリ列(cat)に分ける
  • カテゴリ列は0,1,2,...となるように変換しておく。
  • 予測結果は.mean(1)などで束ねる

また、Kaggleコンペなど、インターネットが使えない環境のときの使用方法については、Kaggle上でコードを公開していますので、こちらも参考にしていただけると幸いです。

以上となります。
今後、KaggleなどでもTabMを見かけることが増えるのではないかなと思うので、この記事が誰かの役に立てば幸いです。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?