1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

BERTベースモデル学習時のGPUメモリ不足エラー(RuntimeError: CUDA out of memory)の原因調査と対策

Last updated at Posted at 2023-03-26

BERTベースモデル(TabBERT)の学習時、入力データのサイズを大きくしたらGPUメモリ不足のエラーがでてしまうようになってしまいました。

AlgorithmError: ExecuteUserScriptError: ExitCode 1 ErrorMessage "RuntimeError: CUDA out of memory.Tried to allocate 2.69 GiB (GPU 0; 14.76 GiB total capacity; 12.32 GiB already allocated; 835.88 MiB free; 13.06 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

RNNのようにメモリ消費がデータサイズに依存するようなモデルではないという認識だったので、なぜこのようなエラーがでたのか直感的にわからなかったのですが、ありえそうな仮説をたてて、一つずつあたってみることにしました。

  • バッチサイズが大きい?
  • データ前処理でメモリ消費している?
  • データセットがGPUメモリに一度にロードされている?
  • ラベル付けで一部カラムのカテゴリ数が増えたことが影響している?
  • モデル構造がデータサイズに依存している?

調査

1. バッチサイズが大きい?

データサイズが小さいときにはバッチサイズ8で学習できていたのですが、念のため確かめてみました。
バッチサイズを8から1に落としても、同様のエラーがでてしまいました。

2. データ前処理でメモリ消費している?

前処理を行うクラスをみる限り、GPUメモリを消費するようなことは行っていませんでした。

3. データセットがGPUメモリに一度にロードされている?

学習にはTransformersのTrainerを使い、バッチサイズもちゃんと指定していたため、そのあたりはいい感じにやってくれていそうでした。

4. ラベル付けで一部カラムのカテゴリ数が増えたことが影響している?

ラベル付けの実装ミスで一部カラムで80万カテゴリくらいつくってしまっていたのですが、該当箇所を消しても同様のエラーが出ました(実装ミスがみつかったのはよかったですが...)。

5. モデル構造がデータサイズに依存している?

BERTベースのモデルなのでその線はなさそう...?と思ったのですが、各層のパラメータ数で極端に多くなっているところはないかどうかを調べてみることにしました。

まず、torchinfoのsummary()関数をコードに追加してみました。

summary(
    model=tab_net.model.tab_embeddings,
    input_size=(8, 2, 13)
)

そしたらRuntimeErrorがでてきました。

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

torchinfoのsummary()関数は入力テンソルのサイズが直接モデルに渡されるモデルで使用されるのですが、調査しているTabFormerHierarchicalLMの場合、TabFormerEmbeddingsが先に実行され、その出力がTabFormerBertForMaskedLMの入力として使用されるようなモデルとなっているため、モデル全体に対してsummary()関数を適用することは困難だということがわかりました。

ライブラリに頼るのはやめ、print(model)count_parametersで各層の詳細情報やパラメータ数(総パラメータ数も)を得ることにしました。

def count_parameters(model):
    total_parameters = 0
    for name, param in model.named_parameters():
        param_count = param.numel()
        total_parameters += param_count
        print(f"{name}: {param_count}")
    print(f"Total parameters: {total_parameters}")

データサイズ小(語彙サイズ2912)と大(語彙サイズ867813)で結果を比較してみたところ、一部の層のパラメータ数に大きな差があり、結果として総パラメータ数にもかなりの差がありました。

データサイズ小

tab_embeddings.word_embeddings.weight: 186368
tb_model.bert.embeddings.word_embeddings.weight: 2422784
Total parameters: 98510560

データサイズ大

tab_embeddings.word_embeddings.weight: 55540032
tb_model.bert.embeddings.word_embeddings.weight: 722020416
Total parameters: 874326757

tab_embeddings.word_embeddings.weightは、TabFormerEmbeddingsモジュールのEmbeddingレイヤー(単語やカテゴリ変数のような離散的な入力を連続的なベクトル表現に変換する層)で使用される重み行列です。
Embeddingレイヤーは、タブラー(表形式)データのカテゴリ変数を埋め込むために使用され、要素数(186368, 55540032)は、語彙サイズと埋め込み次元数の積によって決まります。

データサイズ小の場合は以下のような計算となります。
2912(語彙サイズ=各カラムのカテゴリ数の合計)*64(埋め込み次元数) = 186368(要素数)

一意なvisitor_idが2817個あるため、語彙サイズが大きくなってしまっています。

03/27/2023 06:02:17 - INFO - dataset.action_history -   total vocabulary size: 2912
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : year, vocab size : 1
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : month, vocab size : 1
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : day, vocab size : 1
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : hour, vocab size : 24
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : visitor_id, vocab size : 2817
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : company_id, vocab size : 21
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : site_id, vocab size : 21
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : device, vocab size : 3
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : ma_crm, vocab size : 3
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : sfa, vocab size : 1
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : stay_seconds, vocab size : 9
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : day_of_week, vocab size : 1
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : reaction, vocab size : 2
03/27/2023 06:02:17 - INFO - dataset.action_history -   column : SPECIAL, vocab size : 7

また、データサイズを大きくすると、それに比例して一意なvisitor_idも増えるため、要素数(語彙サイズ)も大きくなっています。

03/27/2023 06:13:08 - INFO - dataset.action_history -   total vocabulary size: 867813
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : year, vocab size : 2
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : month, vocab size : 9
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : day, vocab size : 31
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : hour, vocab size : 24
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : visitor_id, vocab size : 867599
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : company_id, vocab size : 57
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : site_id, vocab size : 58
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : device, vocab size : 3
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : ma_crm, vocab size : 3
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : sfa, vocab size : 3
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : stay_seconds, vocab size : 9
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : day_of_week, vocab size : 7
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : reaction, vocab size : 1
03/27/2023 06:13:08 - INFO - dataset.action_history -   column : SPECIAL, vocab size : 7

なお、TabFormerEmbeddingsは、タブラーデータの各カテゴリカル変数(列)に対応する埋め込みベクトルを生成し、最終的にTransformerEncoderでそれらの埋め込みベクトルを組み合わせて、タブラーデータの行全体を表現する埋め込みベクトルを作成します。

TabFormerBertForMaskedLMは、TabFormerEmbeddingsによって生成された行全体を表現する埋め込みベクトルを入力として受け取り、その上でBERTモデルを適用して文脈埋め込みを計算します。

※BERTモデルは、Transformerベースのアーキテクチャで、入力された埋め込みベクトルの系列に対して文脈情報をエンコードします。これにより、行同士の関係やパターンを考慮した文脈埋め込みが得られます。

tb_model.bert.embeddings.word_embeddings.weightは、TabFormerBertForMaskedLMモジュール内のBERTモデルのEmbeddingレイヤーで使用される重み行列です。
こちらのEmbeddingレイヤーもトークンを埋め込むために使用され、要素数(2422784, 722020416)は、語彙サイズと埋め込み次元数の積によって決まります。

2912(語彙サイズ=各カラムのカテゴリ数の合計)*832(埋め込み次元数) = 2422784(要素数)

データサイズが大きくなることで語彙サイズも大きくなり、それに比例してパラメータ数も増えてしまっていたようです。

結局、4のラベル付けで一部カラムのカテゴリ数が増えたことが根本原因でした。

結論

今回のGPUメモリ不足は、以下のことが原因として考えられます。

  • 語彙サイズが増えて、Embeddingレイヤーで必要とされるメモリ容量が増加した。
  • パラメータの更新に必要なメモリも増加した(勾配、オプティマイザの状態など)。
  • バッチサイズが大きい場合、学習中に一度に処理するデータ量が多くなり、GPUメモリの消費量が増加した。

対策

メモリ不足対策として以下の方法がありそうでしたが、「語彙サイズを減らす(visitor_idカラムのデータを学習しない)」という方向でいくことにしました(visitor_idが入力カラムとして有用ではないと考えたため)。

  • バッチサイズを減らす(既に試してダメだった)。
  • より大きなGPUメモリをインスタンスに変える(少し調べてみたら、東京リージョンにGPU 16 GBより大きいインスタンスがなさそうだった)。
  • 語彙サイズを減らす。
  • 埋め込みベクトルの次元を減らす。
  • モデルのアーキテクチャをシンプルにする(レイヤー数やユニット数を減らす)。
  • 勾配計算の精度を下げる(例:半精度浮動小数点数(float16)を使用する)。

結果、ちゃんと学習できるようになりました。

image.png

wandb: Waiting for W&B process to finish... (success).
wandb:
wandb: Run history:
wandb:                    train/epoch ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
wandb:              train/global_step ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
wandb:            train/learning_rate ████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
wandb:                     train/loss ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:               train/total_flos ▁
wandb:               train/train_loss ▁
wandb:            train/train_runtime ▁
wandb: train/train_samples_per_second ▁
wandb:   train/train_steps_per_second ▁
wandb:
wandb: Run summary:
wandb:                    train/epoch 3.0
wandb:              train/global_step 302874
wandb:            train/learning_rate 0.0
wandb:                     train/loss 0.0
wandb:               train/total_flos 2.990981436994685e+16
wandb:               train/train_loss 0.0
wandb:            train/train_runtime 21314.0384
wandb: train/train_samples_per_second 113.68
wandb:   train/train_steps_per_second 14.21

参考資料

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?