LoginSignup
30
16

More than 5 years have passed since last update.

Pytorch test時にメモリが溢れてしまう

Posted at

学習したネットワークにtrain時より大きなデータ(画像等)をtest時に投入するとメモリが溢れてしまうことがある.

TrainとTestについて

pytorchではtrain時,forward計算時に勾配計算用のパラメータを保存しておくことでbackward計算の高速化を行っているらしい.
これは,model.eval()で行っていてもパラメータが保存されているようなので,下記対策が必要になる.

with torch.no_grad()

pytorch 0.4.0に追加されたtorch.no_grad()を使用してパラメータの保存を止める

pytorch-MNISTサンプルより一部抜粋

sample.py
def test(args, model, device, test_loader):
    model.eval()
    with torch.no_grad():
        model(data) # forward実行

参考

30
16
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
30
16