1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

LightGBMの学習過程をTensorboardで表示する

Posted at

1.この記事の内容

LightGBMの学習過程をTensorboardで確認する方法を紹介します.

1-1.ポイント

  • Tensorboardが読み込む学習ログを出力するために,torch.utils.tensorboardを使う
  • LightGBMのcallbacksにadd_scalarsでログを出力するカスタムコールバックを指定する

2.背景

AI Dashboardには,最初にTensorFlowでDNNを学習する機能を実装しましたが,その際に学習過程をTensorboardで表示するようにしました.
その後,テーブルデータをLightGBMで学習する機能を実装した際に,学習過程の表示はTensorboardで統一したくなりました.

3.手順詳細

プログラム全体は筆者のGitHubで公開しています.
本文には主要な箇所のみを記載します.

3-1. ログ出力準備

  • SummaryWriterオブジェクト生成
create summary writer object
from torch.utils.tensorboard import SummaryWriter

...

self.writer = SummaryWriter(log_dir=Path(output_dir, 'logs'))
  • ログ出力するカスタムコールバックを定義します
custom callback
class LogSummaryWriterCallback:
    """ Writing log callable class
    """
    
    def __init__(self, period=1, writer=None):
        self.period = period
        self.writer = writer
    
    def __call__(self, env):
        if (self.period > 0) and (env.evaluation_result_list) and (((env.iteration+1) % self.period)==0):
            if (self.writer is not None):
                scalars = {}
                for (name, metric, value, is_higher_better) in env.evaluation_result_list:
                    if (metric not in scalars.keys()):
                        scalars[metric] = {}
                    scalars[metric][name] = value
                    
                for key in scalars.keys():
                    self.writer.add_scalars(key, scalars[key], env.iteration+1)
            else:
                print(env.evaluation_result_list)
 

3-2. LightGBMの学習処理

  • lightgbm.trainに与えるmetricパラメータでログ出力するメトリクスを指定します.
    下記はL1L2RMSEのログを出力する例です.
parameters
self.params = {
    'objective': 'regression',
    'metric': 'l1,l2,rmse',
    'num_leaves': 32,
    'max_depth': 4,
    'feature_fraction': 0.5,
    'subsample_freq': 1,
    'bagging_fraction': 0.8,
    'min_data_in_leaf': 5,
    'learning_rate': learning_rate,
    'boosting': 'gbdt',
    'lambda_l1': 1,
    'lambda_l2': 5,
    'verbosity': -1,
    'random_state': 42,
    'early_stopping_rounds': 100,
}
  • valid_namesvalid_setsにログ出力対象のデータを指定し,カスタムコールバックを指定します.
training
self.model = lgb.train(
    self.params,
    train_data,
    valid_names=valid_names,
    valid_sets=valid_sets,
    num_boost_round=50000,
    callbacks=[
        lgb.log_evaluation(period=100),
        LogSummaryWriterCallback(period=100, writer=self.writer)
    ]
)

4.Tensorboardの出力例

image.png

5.さいごに

AI DashboardにはTensorで学習する処理を実装していたので,TensorFlowのAPIからできないか考えてみたのですが,tf.keras.metricsクラスを介する必要があり,データ型の整合(tf.TorchとPython標準のデータ型との変換)が複雑になりそうでしたので,torch.utils.tensorboardを使用する方法を採用しました.

このような使い方をしたくなる人がどれくらいいるか分かりませんが,ご参考まで.

6.関連リンク

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?