LoginSignup
14
11

More than 5 years have passed since last update.

PyTorchで巨大テンソルのLossを計算していたらメモリ不足になった

Last updated at Posted at 2018-11-23

事の経緯

動画を処理する機械学習コードで

巨大なループの中で誤差を蓄積したかった

for n in range(len(datalist)):

     (省略)
     loss = loss_fn(GroundTruth, output)
     loss.backward()
     losssum += loss
print(losssum/len(datelist))

するどどんどんメモリが圧迫されていき

out of memoryエラーに。。。。

原因

lossテンソルをそのままコピーしていたために

losssum内に勾配データが蓄積されてしまったのが原因みたいだった

正しくは

for n in range(len(datalist)):

     (省略)
     loss = loss_fn(GroundTruth, output)
     loss.backward()
     losssum += loss.detach()
print(losssum/len(datelist))

メモリ問題も解消

教訓

学習に関係ないところにテンソルをコピーするときは

tensor.detach()
#どうやらdetach()はコピー元と同じ記憶領域を共有しているみたいなので
tensor.clone()
#を推奨します!
14
11
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
11