10
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LightGBMはメモリ効率が良くてとても使いやすいですが、それでも学習データが巨大すぎるとメモリエラーが発生します。
ここでは巨大な学習データに対して、逐次的に処理を行うことによって、メモリエラーを発生させずに学習を完遂させる方法をご紹介します。

概要

LightGBMは学習を行うtrain()メソッドを行うときに、その事前準備としてlightgbm.Dataset型の変数の構築を行っていますが、この構築部分でメモリの使用量が増大します。
そこで、一括で学習データをlightgbm.Dataset型に変換するのではなく、学習データを部分的にlightgbm.Dataset型に変換しconstruct()で構築するということを繰り返すことで、メモリエラーを回避します。
その上でtrain()メソッドを実行すればメモリエラーは起きません。

なお、ここでのLightGBMはTrain APIを使います。

# Training APIのスタイル
import lightgbm as lgb

train_data = lgb.Dataset(X_train, label=y_train)
eval_data = lgb.Dataset(X_valid, label=y_valid, reference= train_data)

gbm = lgb.train(
                  params,
                  train_data,
                  valid_sets=eval_data,
                  num_boost_round=100,
                  categorical_feature=cat_cols
                )

おそらく1番ポピュラーなのはScikit-learn APIだと思いますが、Training APIの方が色々小回りが効く印象があります。

# Scikit-learn APIのスタイル
import lightgbm as lgb

lgb_model = lgb.LGBMClassifier(**params)

lgb_model.fit(
         X_train,
         y_train,
         eval_set=[(X_valid, y_valid)],
         callbacks=[lgb.log_evaluation(100), lgb.early_stopping(100)],
     )

まあ、仮にScikit-learn APIを使っていても書き換えはそこまで大変ではないと思います。
Scikit-learnでも同じような方法でメモリエラーを回避できるのかどうかはわかりません・・・(誰か教えてください)

具体例

具体的には、

# さっきの再掲
import lightgbm as lgb

train_data = lgb.Dataset(X_train, label=y_train)
eval_data = lgb.Dataset(X_valid, label=y_valid, reference= train_data)

gbm = lgb.train(
                  params,
                  train_data,
                  valid_sets=eval_data,
                  num_boost_round=100,
                  categorical_feature=cat_cols
                )

のような場合に、

import lightgbm as lgb

train_cols = X_train.columns

# すべてのカラムを前から順に10ずつ変換していく。
for index in range(0, len(X_train), 10):
  # 学習データを10カラム単位のバッチに分割して処理する。
  if index + 10 <= len(X_train):
    former_lgb_train_initial = lgb.Dataset(X_train[train_cols[index:index+10]], label=y_train, free_raw_data=False)
  else:
    former_lgb_train_initial = lgb.Dataset(X_train[train_cols[index:]], label=y_train, free_raw_data=False)
  former_lgb_train_initial.construct()
  if index==0:
    lgb_train_initial = former_lgb_train_initial
  else:
  # 2回目以降のイテレーションでは前回のDatasetに付け加えていく。
    lgb_train_initial.add_features_from(former_lgb_train_initial)
    former_lgb_train_initial = lgb_train_initial

for index in range(0, len(X_valid), 10):
  # 検証データを10カラム単位のバッチに分割して処理する。
  # 検証データは学習データに比べてデータ量が小さいことが多いため、わざわざ分割しなくても良いかもしれない。 
  if index + 10 <= len(X_valid):
    former_lgb_valid_initial = lgb.Dataset(X_valid[train_cols[index:index+10]], label=y_valid, free_raw_data=False)
  else:
    former_lgb_valid_initial = lgb.Dataset(X_valid[train_cols[index:]], label=y_valid, free_raw_data=False)
  former_lgb_valid_initial.construct()
  if index==0:
    lgb_valid_initial = former_lgb_valid_initial
  else:
  # 2回目以降のイテレーションでは前回のDatasetに付け加えていく。
    lgb_valid_initial.add_features_from(former_lgb_valid_initial)
    former_lgb_valid_initial = lgb_valid_initial

train_data, eval_data = lgb_train_initial, lgb_valid_initial

# 以下は同じ。
gbm = lgb.train(
                  params,
                  train_data,
                  valid_sets=eval_data,
                  num_boost_round=100,
                  categorical_feature=cat_cols
                )

となります。

10
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?