概要
PyTorchの torch.topk
関数において巨大なテンソルを扱った際に発生するエラーと解決方法について紹介します。
環境
この記事では、下記の環境で動作確認をしています。
- PyTorch 1.13.1
- CUDA 11.7
- NVIDIA GeForce RTX 3080Ti (VRAM 12GB)
環境構築は conda
コマンドにより行っています。
conda install pytorch==1.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
torch.topk
で巨大なテンソルを扱った際に生じるエラー
PyTorchでk近傍法などを実装する際に利用する関数として torch.topk
があります。
この torch.topk
に対して、一定のサイズを超えるテンソルを入力した際にエラーが発生します。筆者の環境で確認したところ、下記のコードでエラーとなりました。
import torch
x = torch.randn((65, 1000, 1000), device="cuda")
topk = torch.topk(x, 10, -1)
print(topk) # 正常
x = torch.randn((66, 1000, 1000), device="cuda")
topk = torch.topk(x, 10, -1)
print(topk) # エラー
(65, 1000, 1000)
の配列では動作していたものの、(66, 1000, 1000)
になるとエラーが発生します。この時のエラーメッセージとしては下記の通りです。
RuntimeError: CUDA error: an illegal memory access was encountered
巨大なテンソルを扱っているので原因としてメモリ不足が第一に考えられますが、その際は CUDA out of memory
と表示されるため、一般的なメモリ不足のエラーとは少し異なるようです。
解決方法
このエラーはPyTorchのIssueでも報告されていましたが、PyTorch 2.0.0で修正されたようです。
下記コマンドによりPyTorch 2.0.0の環境で試したところ、エラーとなっていた箇所も正常に実行できました。
conda install pytorch=2 pytorch-cuda=11.7 -c pytorch -c nvidia