torch.repeatとtorch.expandの違いを知りたい!
導入
例えば、
1階テンソル[0,0,0]を複製して
[[0,0,0], [0,0,0], [0,0,0]]みたいに2階テンソルにしたい時があると思います。
そんな時使うメソッドとして候補に上がるのがtorch.repeatとtorch.expandですが、これらはどう違うのか
気になったので実験しました。
実験
>>> import torch
>>> a = torch.zeros(3, 1)
>>> b = a.repeat(1, 4)
>>> c = a.expand(3, 4) # a.expand(-1, 4)でも可能
>>> b.shape
torch.Size([3, 4])
>>> c.shape
torch.Size([3, 4])
>>> b
tensor([[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]])
>>> c
tensor([[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.],
[ 0., 0., 0., 0.]])
>>> b[:, 2] += 1
>>> c[:, 2] += 1
>>> b
tensor([[ 0., 0., 1., 0.],
[ 0., 0., 1., 0.],
[ 0., 0., 1., 0.]])
>>> c
tensor([[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.],
[ 1., 1., 1., 1.]])
要はこういうことでした。
元となるテンソルaをdataごと複製して新しいメモリに割り当てるのかどうかという違いでした。
てかdocumentationにもそう書いてあったわ
結論
複製したテンソルを個別にいじりたいならdataごと複製できるtorch.repeat
そうでないならメモリを節約できるtorch.expandを使おう
ということでした。
また注意としてtorch.expandは次元数が1の軸でないとexpandできないです。
torch.repeatは次元数が1より大きくてもrepeatできる
メモリの制約のせいかな