学習したネットワークにtrain時より大きなデータ(画像等)をtest時に投入するとメモリが溢れてしまうことがある.
TrainとTestについて
pytorchではtrain時,forward計算時に勾配計算用のパラメータを保存しておくことでbackward計算の高速化を行っているらしい.
これは,model.eval()
で行っていてもパラメータが保存されているようなので,下記対策が必要になる.
with torch.no_grad()
pytorch 0.4.0に追加されたtorch.no_grad()
を使用してパラメータの保存を止める
pytorch-MNISTサンプルより一部抜粋
sample.py
def test(args, model, device, test_loader):
model.eval()
with torch.no_grad():
model(data) # forward実行
参考