#事の経緯
動画を処理する機械学習コードで
巨大なループの中で誤差を蓄積したかった
for n in range(len(datalist)):
(省略)
loss = loss_fn(GroundTruth, output)
loss.backward()
losssum += loss
print(losssum/len(datelist))
するどどんどんメモリが圧迫されていき
out of memoryエラーに。。。。
#原因
lossテンソルをそのままコピーしていたために
losssum内に勾配データが蓄積されてしまったのが原因みたいだった
正しくは
for n in range(len(datalist)):
(省略)
loss = loss_fn(GroundTruth, output)
loss.backward()
losssum += loss.detach()
print(losssum/len(datelist))
メモリ問題も解消
###教訓
学習に関係ないところにテンソルをコピーするときは
tensor.detach()
#どうやらdetach()はコピー元と同じ記憶領域を共有しているみたいなので
tensor.clone()
#を推奨します!