基本的な使い方
unsqueeze
メソッドはPyTorch
のテンソルの次元を上げる際に用いられるメソッドです。
import torch
x = torch.zeros([2, 2])
print(x.shape)
print(x.unsqueeze(dim=0).shape)
print(x.unsqueeze(dim=1).shape)
print(x.unsqueeze(dim=2).shape)
・実行結果
torch.Size([2, 2])
torch.Size([1, 2, 2])
torch.Size([2, 1, 2])
torch.Size([2, 2, 1])
unsqueezeメソッドの活用
unsqueeze
メソッドは次元の大きなテンソルにテンソルを加えるなどの処理に活用することができます。
import torch
x = torch.zeros([2, 2])
y = torch.ones([2, 5, 2])
print(x)
print(y)
print(x.shape)
print(y.shape)
print(x.unsqueeze(dim=1).shape)
print((y+x.unsqueeze(dim=1)).shape)
・実行結果
tensor([[0., 0.],
[0., 0.]])
tensor([[[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]]])
torch.Size([2, 2])
torch.Size([2, 5, 2])
torch.Size([2, 1, 2])
torch.Size([2, 5, 2])