PyTorch TutorialのData Loading and Processing Tutorialをやってるときに気になったのでメモ
背景
Iterating through the dataset 中のコードでデータセットの画像に対してスケールやら,クロップやらの変換を施した結果を可視化したかった.
そのままshow_landmarks()
を呼ぶとpyplotとPyTorchでサポートしている画像配列の軸の順番が違うため表示できない.
==> show_landmarks()
の中で軸の順番を入れ替えよう.
結論
numpyのtranspose
はPyTorchではpermute
Numpyのtranspose
numpyのtransposeといえば多次元配列の軸の順番を入れ替える関数ですね.
import numpy as np
sample0 = np.arange(15).reshape(5, 3)
# array([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11],
# [12, 13, 14]])
sample0.transpose((1, 0))
# array([[ 0, 3, 6, 9, 12],
# [ 1, 4, 7, 10, 13],
# [ 2, 5, 8, 11, 14]])
sample1 = np.arange(30).reshape(2, 5, 3)
# array([[[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11],
# [12, 13, 14]],
# [[15, 16, 17],
# [18, 19, 20],
# [21, 22, 23],
# [24, 25, 26],
# [27, 28, 29]]])
sample1.transpose((2, 0, 1))
# array([[[ 0, 3, 6, 9, 12],
# [15, 18, 21, 24, 27]],
#
# [[ 1, 4, 7, 10, 13],
# [16, 19, 22, 25, 28]],
#
# [[ 2, 5, 8, 11, 14],
# [17, 20, 23, 26, 29]]])
PyTorchのtranspose
PyTorchでもtranspose はサポートされているのですがこれは2次元配列2軸の入れ替えにしか使えません
(ちなみにPyTorchの場合配列のサイズはtupleでは指定できません.)
import torch
sample0 = torch.arange(15).reshape(5, 3)
# tensor([[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11],
# [12, 13, 14]])
sample0.transpose(1, 0)
# tensor([[ 0, 3, 6, 9, 12],
# [ 1, 4, 7, 10, 13],
# [ 2, 5, 8, 11, 14]])
sample1 = torch.arange(30).reshape(2, 5, 3)
# tensor([[[ 0, 1, 2],
# [ 3, 4, 5],
# [ 6, 7, 8],
# [ 9, 10, 11],
# [12, 13, 14]],
#
# [[15, 16, 17],
# [18, 19, 20],
# [21, 22, 23],
# [24, 25, 26],
# [27, 28, 29]]])
sample1.transpose(2, 0, 1)
# TypeError: transpose() received an invalid combination of arguments - got (int, int, int), but expected one of:
# * (name dim0, name dim1)
# * (int dim0, int dim1)
sample1.transpose(2, 0)
# tensor([[[ 0, 15],
# [ 3, 18],
# [ 6, 21],
# [ 9, 24],
# [12, 27]],
#
# [[ 1, 16],
# [ 4, 19],
# [ 7, 22],
# [10, 25],
# [13, 28]],
#
# [[ 2, 17],
# [ 5, 20],
# [ 8, 23],
# [11, 26],
# [14, 29]]])
PyTorchでの軸の順番入れ替え
探してみたらPyTorchのフォーラムにありました.
Swap axes in pytorch?
PyTorchではpermute
を使うそうです.
sample1.permute(2, 0, 1)
# tensor([[[ 0, 3, 6, 9, 12],
# [15, 18, 21, 24, 27]],
#
# [[ 1, 4, 7, 10, 13],
# [16, 19, 22, 25, 28]],
#
# [[ 2, 5, 8, 11, 14],
# [17, 20, 23, 26, 29]]])
ちなみに...
そのままoutputに使うなら(その後でPyTorchのTensorとして処理しないなら)以下でもいいんですけどね.
sample1.numpy().transpose(2, 0, 1)
# array([[[ 0, 3, 6, 9, 12],
# [15, 18, 21, 24, 27]],
#
# [[ 1, 4, 7, 10, 13],
# [16, 19, 22, 25, 28]],
#
# [[ 2, 5, 8, 11, 14],
# [17, 20, 23, 26, 29]]])