1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

初学者が遭遇しがちな機械学習モデルの推論時エラーと解決策

Posted at

機械学習モデルをトレーニングして推論タスクに適用する際、特に異なるコンピューティング環境(トレーニングはGPU、推論はCPU)を利用する場合に推論がうまくいかないことがよくありました(2クラス予測で0と1を交互に推論するなど)。

そのときにやってしまっていたこととその解決策を簡単にまとめます。

1. モデルのロードミス

推論エンドポイントをCPU(ml.m5.large)で構築したところ、GPU(g4dn.2xlarge)でトレーニングしたモデルが期待通りの推論結果を出力しませんでした。

ログを開いたら以下のWarningがでていました。

2024-04-25T05:47:04,365 [WARN ] W-9000-model_1.0-stderr MODEL_LOG - Some weights of the model checkpoint at /opt/ml/model/code/pytorch_model.bin were not used when initializing BertModel...

Fine-TuningされたPredictionModelで出力されたpytorch_model.binファイルをBertModel.from_pretrainedを用いて直接ロードしようとしたことで発生しました。

BertModel.from_pretrainedメソッドは、基本のBERTモデルの構造を前提としているため、PredictionModelに追加されたLSTM層や線形層のパラメータが考慮されず、結果的にこれらの重要なパラメータが無視されていました。

pretrained_config = path.join("/opt/ml/model/code/", f"config.json")
pretrained_model = path.join("/opt/ml/model/code/", f"pytorch_model.bin")
config = BertConfig.from_pretrained(pretrained_config)

model = PredictionModel(config=config, pretrained_model=pretrained_model)

モデルの全パラメータを含むfine_tuning_model.ptから適切に状態をロードすることで問題が解決しました。

model = PredictionModel(config=config, pretrained_model=None)
model_path = path.join("/opt/ml/model/code/", "fine_tuning_model.pt")
model.load_state_dict(torch.load(model_path))

2. モデルのデバイス割り当て

モデルがデフォルトでCUDAデバイスを使用する設定になっており、CUDA非対応の環境でエラーが発生しました。

2024-04-28T06:59:31,905 [INFO ] W-9001-model_1.0-stdout MODEL_LOG - Exception in model fn Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False...

CPUのみの環境でモデルをロードする際は、適切なデバイスにモデルを割り当てる必要がありました。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

3. 勾配計算の無効化

推論時には、モデルの入力データがテンソルとして与えられる際に、勾配計算を無効にすることが推奨されるとのことでした。

推論結果への影響はない(と思う)のですが、無駄なメモリ使用と計算時間の増加を招くとのことで直しました。

with torch.no_grad():
    model_input = torch.tensor([preprocessed_data.getitem()], dtype=torch.long).to(device)
    model_output = model(model_input)
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?