
torch.transpose cannot be used like np.transpose, so let's substitute torch.einsum

Posted at 2019-07-29

What I want to do

  • Permute the axes of a higher-order tensor
  • Pass the axes as a list, the way np.transpose allows
  • Einstein summation notation looks like it can approximate this
  • Conclusion: it is slow, so you should avoid it
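The mismatch behind this article is in the signatures: np.transpose (and ndarray.transpose) accepts a full axis permutation in one call, while torch.transpose swaps exactly two dimensions per call. A minimal sketch of the difference:

```python
import numpy as np
import torch

a_np = np.arange(8).reshape(2, 2, 2)
a_t = torch.arange(8).reshape(2, 2, 2)

# numpy: the whole permutation in a single call
b_np = a_np.transpose(1, 2, 0)

# torch.transpose: only a pairwise swap of two dims per call
b_t = torch.transpose(a_t, 1, 2)  # axes become (0, 2, 1), not (1, 2, 0)
```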

Update 2020-01-15

  • torch.Tensor.permute can be used just like np.transpose

Comparing transpose and permute
# np.transpose
>>> a = np.arange(8).reshape(2,2,2)
>>> a.transpose(1,2,0)
array([[[0, 4],
        [1, 5]],

       [[2, 6],
        [3, 7]]])

# torch.Tensor.permute
>>> a = torch.arange(8).reshape(2,2,2)
>>> a.permute(1,2,0)
tensor([[[0, 4],
         [1, 5]],

        [[2, 6],
         [3, 7]]])
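One detail worth knowing alongside this comparison (an aside, not in the original): like torch.transpose, permute returns a view over the same storage rather than a copy, so operations that require contiguous memory may need a .contiguous() call first:

```python
import torch

a = torch.arange(8).reshape(2, 2, 2)
p = a.permute(1, 2, 0)    # a view over a's storage; no data is copied
print(p.is_contiguous())  # False for this permutation
c = p.contiguous()        # materializes the permuted memory layout
print(c.is_contiguous())  # True
```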

Code

With numpy.transpose
>>> a = np.arange(8).reshape(2,2,2)
>>> print(a)
[[[0 1]
  [2 3]]

 [[4 5]
  [6 7]]]
>>> np.transpose(a, [1,2,0])
array([[[0, 4],
        [1, 5]],

       [[2, 6],
        [3, 7]]])
With np.einsum
>>> a = np.arange(8).reshape(2,2,2)
>>> print(a)
array([[[0, 1],
        [2, 3]],

       [[4, 5],
        [6, 7]]])
>>> np.einsum("ijk->jki",a)
array([[[0, 4],
        [1, 5]],

       [[2, 6],
        [3, 7]]])
With torch.transpose
>>> a = torch.arange(8).reshape(2,2,2)
>>> print(a)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
>>> torch.transpose(a, 1, 2).transpose(0,2)
tensor([[[0, 4],
         [1, 5]],

        [[2, 6],
         [3, 7]]])
With torch.einsum
>>> a = torch.arange(8).reshape(2,2,2)
>>> print(a)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
>>> torch.einsum("ijk->jki",a)
tensor([[[0, 4],
         [1, 5]],

        [[2, 6],
         [3, 7]]])
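The einsum subscripts above are hard-coded for one permutation, but the axes-as-a-list interface of np.transpose can be mimicked by building the subscript string on the fly. A sketch (transpose_like_numpy is a hypothetical helper name, not part of any library):

```python
import torch

def transpose_like_numpy(t, axes):
    """Permute t's axes like np.transpose(t, axes), via torch.einsum."""
    src = "abcdefghijklmnop"[:t.dim()]    # one subscript letter per axis
    dst = "".join(src[a] for a in axes)   # reorder the letters by axes
    return torch.einsum(f"{src}->{dst}", t)

a = torch.arange(8).reshape(2, 2, 2)
print(transpose_like_numpy(a, [1, 2, 0]))
```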

The speeds differ

Run on Google Colaboratory
def run_np_transpose(tensor):
  np.transpose(tensor, [1, 2, 0])

def run_np_einsum(tensor):
  np.einsum("ijk->jki", tensor)

def run_torch_transpose(tensor):
  torch.transpose(tensor, 1, 2).transpose(0,2)

def run_torch_einsum(tensor):
  torch.einsum("ijk->jki", tensor)

tensor_numpy = np.arange(8).reshape(2,2,2)
tensor_torch = torch.arange(8).reshape(2,2,2)
np.transpose
%%timeit
run_np_transpose(tensor_numpy)
The slowest run took 45.15 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 915 ns per loop
np.einsum
%%timeit
run_np_einsum(tensor_numpy)
The slowest run took 54.64 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 1.07 µs per loop
torch.transpose
%%timeit
run_torch_transpose(tensor_torch)
The slowest run took 96.76 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.17 µs per loop
torch.einsum
%%timeit
run_torch_einsum(tensor_torch)
The slowest run took 122.75 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 5.57 µs per loop

Conclusion

  • Fastest to slowest: np.transpose > np.einsum > torch.transpose > torch.einsum
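Given the 2020 addendum, permute would also belong in this comparison. A sketch with the standard timeit module (absolute numbers will vary by machine and PyTorch version, so no figures are claimed here):

```python
import timeit
import torch

tensor_torch = torch.arange(8).reshape(2, 2, 2)

t_permute = timeit.timeit(lambda: tensor_torch.permute(1, 2, 0), number=100_000)
t_einsum = timeit.timeit(lambda: torch.einsum("ijk->jki", tensor_torch), number=100_000)
print(f"permute: {t_permute:.3f}s  einsum: {t_einsum:.3f}s")
```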