LoginSignup
24
8

More than 3 years have passed since last update.

PyTorchのtransposeはnumpyのtransposeと若干違う(PyTorchで軸の順番を入れ替える方法について)

Last updated at Posted at 2019-08-18

PyTorch TutorialData Loading and Processing Tutorialをやってるときに気になったのでメモ

背景

Iterating through the dataset 中のコードでデータセットの画像に対してスケールやら,クロップやらの変換を施した結果を可視化したかった.

そのままshow_landmarks()を呼ぶとpyplotとPyTorchでサポートしている画像配列の軸の順番が違うため表示できない.
==> show_landmarks()の中で軸の順番を入れ替えよう.

結論

numpyのtransposeはPyTorchではpermute

Numpyのtranspose

numpyのtransposeといえば多次元配列の軸の順番を入れ替える関数ですね.

numpy
import numpy as np
sample0 = np.arange(15).reshape(5, 3)
# array([[ 0,  1,  2],
#        [ 3,  4,  5],
#        [ 6,  7,  8],
#        [ 9, 10, 11],
#        [12, 13, 14]])
sample0.transpose((1, 0))
# array([[ 0,  3,  6,  9, 12],
#        [ 1,  4,  7, 10, 13],
#        [ 2,  5,  8, 11, 14]])

sample1 = np.arange(30).reshape(2, 5, 3)
# array([[[ 0,  1,  2],
#         [ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11],
#         [12, 13, 14]],

#        [[15, 16, 17],
#         [18, 19, 20],
#         [21, 22, 23],
#         [24, 25, 26],
#         [27, 28, 29]]])
sample1.transpose((2, 0, 1))
# array([[[ 0,  3,  6,  9, 12],
#         [15, 18, 21, 24, 27]],
# 
#        [[ 1,  4,  7, 10, 13],
#         [16, 19, 22, 25, 28]],
# 
#        [[ 2,  5,  8, 11, 14],
#         [17, 20, 23, 26, 29]]])

PyTorchのtranspose

PyTorchでもtranspose はサポートされているのですがこれは2次元配列2軸の入れ替えにしか使えません
(ちなみにPyTorchの場合配列のサイズはtupleでは指定できません.)

PyTorch
import torch
sample0 = torch.arange(15).reshape(5, 3)
# tensor([[ 0,  1,  2],
#         [ 3,  4,  5],
#         [ 6,  7,  8],
#         [ 9, 10, 11],
#         [12, 13, 14]])

sample0.transpose(1, 0)
# tensor([[ 0,  3,  6,  9, 12],
#         [ 1,  4,  7, 10, 13],
#         [ 2,  5,  8, 11, 14]])

sample1 = torch.arange(30).reshape(2, 5, 3)
# tensor([[[ 0,  1,  2],
#          [ 3,  4,  5],
#          [ 6,  7,  8],
#          [ 9, 10, 11],
#          [12, 13, 14]],
# 
#         [[15, 16, 17],
#          [18, 19, 20],
#          [21, 22, 23],
#          [24, 25, 26],
#          [27, 28, 29]]])
sample1.transpose(2, 0, 1)
# TypeError: transpose() received an invalid combination of arguments - got (int, int, int), but expected one of:
# * (name dim0, name dim1)
# * (int dim0, int dim1)
sample1.transpose(2, 0)
# tensor([[[ 0, 15],
#          [ 3, 18],
#          [ 6, 21],
#          [ 9, 24],
#          [12, 27]],
# 
#         [[ 1, 16],
#          [ 4, 19],
#          [ 7, 22],
#          [10, 25],
#          [13, 28]],
# 
#         [[ 2, 17],
#          [ 5, 20],
#          [ 8, 23],
#          [11, 26],
#          [14, 29]]])

PyTorchでの軸の順番入れ替え

探してみたらPyTorchのフォーラムにありました.
Swap axes in pytorch?
PyTorchではpermuteを使うそうです.

Pytorch
sample1.permute(2, 0, 1)
# tensor([[[ 0,  3,  6,  9, 12],
#          [15, 18, 21, 24, 27]],
# 
#         [[ 1,  4,  7, 10, 13],
#          [16, 19, 22, 25, 28]],
# 
#         [[ 2,  5,  8, 11, 14],
#          [17, 20, 23, 26, 29]]])

ちなみに...

そのままoutputに使うなら(その後でPyTorchのTensorとして処理しないなら)以下でもいいんですけどね.

sample1.numpy().transpose(2, 0, 1)
# array([[[ 0,  3,  6,  9, 12],
#         [15, 18, 21, 24, 27]],
# 
#        [[ 1,  4,  7, 10, 13],
#         [16, 19, 22, 25, 28]],
# 
#        [[ 2,  5,  8, 11, 14],
#         [17, 20, 23, 26, 29]]])

その他の参考文献

配列の軸の順番を入れ替えるNumPyのtranspose関数の使い方 - DeepAge

24
8
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
24
8