LoginSignup
21
8

More than 5 years have passed since last update.

PyTorchの+=は危険。。。!?

Last updated at Posted at 2018-11-26

事の経緯

デバッグをしている最中にある変数が不可解な動きをしているのを見つけた

loss = torch.ones(1)*10
a = loss.detach()
print(a)
#return tensor([10.])
loss += 20
print(a)
#return tensor([30.])

結果を見て目を丸くした、detach()はデータの安全なコピーだと思っていたが

違うみたいだった。

PyTorchにおけるデータの安全なコピーは
tensor.clone()
だそうです。

loss = torch.ones(1)*10
a = loss
print(a)
#return tensor([10.])
loss = 20 + loss
print(a)
#return tensor([10.])

これだと正常に動いた、detach(),.dataでコピーしたデータのコピー元tensorに+=の演算をすると

コピーしたtensorまで変更されてしまうようだ。

言いたいこと

PyTorchにおいて.detach()と.dataは保存領域共有しているようです。

完全なるデータのコピーを作るには.clone()を使いましょう

また+=はinplaceの計算なので気を使いましょう

変なところで値が動く可能性があります。

21
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
21
8