
PyTorch's += is dangerous...!?


Background

While debugging, I noticed a variable behaving in a baffling way:

import torch

loss = torch.ones(1) * 10
a = loss.detach()
print(a)
# tensor([10.])
loss += 20
print(a)
# tensor([30.])

The result made my eyes pop. I had assumed detach() made a safe copy of the data, but apparently it does not.

The safe way to copy data in PyTorch is tensor.clone().
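A minimal sketch of the clone() fix, using the same numbers as above: because clone() allocates fresh storage, the in-place += on the original no longer leaks into the copy.

```python
import torch

loss = torch.ones(1) * 10
a = loss.clone()   # clone() copies the data into new storage
print(a)
# tensor([10.])
loss += 20         # in-place update touches only loss's own storage
print(a)
# tensor([10.])  -- the copy is unaffected this time
```

(If the tensor requires grad, detach().clone() is the usual combination for a gradient-free copy.)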

import torch

loss = torch.ones(1) * 10
a = loss
print(a)
# tensor([10.])
loss = 20 + loss
print(a)
# tensor([10.])

This version behaved correctly. It seems that if you apply += to the source tensor of a copy made with detach() or .data, the copied tensor gets modified along with it.

Takeaways

In PyTorch, .detach() and .data appear to share the same underlying storage as the original tensor.
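One way to check this sharing directly (a small diagnostic of my own, not from the experiment above) is to compare data_ptr() values, which report the address of a tensor's underlying storage:

```python
import torch

loss = torch.ones(1) * 10
# detach() and .data return views onto the same storage...
print(loss.detach().data_ptr() == loss.data_ptr())  # True
print(loss.data.data_ptr() == loss.data_ptr())      # True
# ...while clone() allocates new storage
print(loss.clone().data_ptr() == loss.data_ptr())   # False
```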

To make a fully independent copy of the data, use .clone().

Also, += is an in-place operation, so handle it with care: values can change in places you don't expect.
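To sum up the two cases side by side, here is a short sketch contrasting the in-place and out-of-place forms (same setup as the experiments above):

```python
import torch

# Out-of-place: loss + 20 builds a new tensor and rebinds the name,
# so the detached view keeps seeing the old storage.
loss = torch.ones(1) * 10
view = loss.detach()
loss = loss + 20
print(view)
# tensor([10.])

# In-place: += writes into the existing storage,
# which the detached view shares.
loss2 = torch.ones(1) * 10
view2 = loss2.detach()
loss2 += 20
print(view2)
# tensor([30.])
```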
