はじめに
PyTorchの公式(1)にあるtorch.cat
の例だけでは良く分からなかったので他の例を使って理解するための書き残しです.
sample code
import torch
input1 = torch.randn(2, 3, 4)
input2 = torch.randn(2, 3, 4)
input3 = torch.randn(5, 3, 4)
input_list = [input1, input2, input3]
output1 = torch.cat(input_list, dim=0)
print(output1.size()) # torch.Size([9, 3, 4])
output2 = torch.cat(input_list, dim=1) # error
print(output2.size())
output3 = torch.cat(input_list, dim=2) # error
print(output3.size())
torch.catの例示
torch.catの入力を見てみると
tensors (sequence of Tensors) – any python sequence of tensors of the same type.
Non-empty tensors provided must have the same shape, except in the cat dimension.
dim (int, optional) – the dimension over which the tensors are concatenated
out (Tensor, optional) – the output tensor.
と3つ引数があるようです.sample集(2)を見ていても第一と第二を使用することが多く第三引数はあまり使われてないようです.なので,第一と第二を使っていきます.
第一引数(tensors)
第一引数はtensors (sequence of Tensors)
なミソですね.
これはTensorがlist型
かtuple型
で入力される必要があるので以下のように準備します.注意が必要なことは,次元の大きさがすべて同じことです.今回は3次元としています.
input1 = torch.randn(2, 3, 4)
input2 = torch.randn(2, 3, 4)
input3 = torch.randn(5, 3, 4)
input_list = [input1, input2, input3]
print(len(input1.size())) # 3
print(len(input2.size())) # 3
print(len(input3.size())) # 3
第二引数(dim)
説明を見ても結合する次元とだけ書かれているので理解できそうでできない雰囲気ですね.
今回の例(output1
)だとdim=0
ならinput1.size(0), input2.size(0), input3.size(0)
サイズを結合(足し合わせ)します.
この時,dim=1, dim=2
のsize
が同じである必要があります.
手書きですが,図で書くと以下のようになります.
dim=1, dim=2
のsize
が3, 4
なので,結合できますね.
したがって,output1.size()
は,torch.Size([9, 3, 4])
になります.
結合する次元をdim=1, dim=2
にしたときは,dim=0
が合わないので結合できないとエラーが返ってきます.
さいごに
ちなみにerrorとなっている部分に対して,input1,input2,input3のsizeをすべて同じ形にするとなくなります.でも,それはPyTorchの公式を見れば十分かと
torch.viewとtorch.catはCNNのPoolで見かけることが多いな~と思う今日この頃
参考文献
(1) TORCH.CAT
(2) Python torch.cat() Examples