#概要
ToTensor()のinputは(C, W, H)じゃなくて(W, H, C)でないとダメだという話
#コード
a = np.tile(np.arange(64*64).reshape(64,64), (3, 1, 1))#shape (3,64,64)
これをpytorchのtoTensorでtorch型に変えると
from torchvision import transforms
a_tensor = trans = transforms.ToTensor()(a)
print(a_tensor.shape)
とするとtorch.Size([64, 3, 64])
となる、なぜtorch.Size([3, 64, 64])
にならないんだ。。
調べてみるとToTensor()はnumpyやPIL Image(W,H,C)をTensor(C,W,H)に変換するようです。ですので
a = np.tile(np.arange(64*64).reshape(64,64), (3, 1, 1)).T #shape (64,64,3)
from torchvision import transforms
a_tensor = trans = transforms.ToTensor()(a)
print(a_tensor.shape)
とすることでtorch.Size([3, 64, 64])
が出力されます。