What I want to do
- Permute the axes of a higher-order tensor
- Pass the axis order as a list, the way np.transpose does
- Einstein summation notation looks like it can do the same job "in that style"
- Conclusion: it is slow, so better avoided

Update 2020.1.15
- torch.Tensor.permute can be used exactly like np.transpose
Comparing transpose and permute
```python
# np.transpose
>>> a = np.arange(8).reshape(2,2,2)
>>> a.transpose(1,2,0)
array([[[0, 4],
        [1, 5]],

       [[2, 6],
        [3, 7]]])

# torch.Tensor.permute
>>> a = torch.arange(8).reshape(2,2,2)
>>> a.permute(1,2,0)
tensor([[[0, 4],
         [1, 5]],

        [[2, 6],
         [3, 7]]])
```
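As a sanity check, here is a minimal sketch (assuming the usual `numpy`/`torch` imports) confirming that the two calls agree element for element:

```python
import numpy as np
import torch

a_np = np.arange(8).reshape(2, 2, 2)
a_pt = torch.arange(8).reshape(2, 2, 2)

# permute takes the same axis order that np.transpose does,
# so the results should match exactly.
print(np.array_equal(a_np.transpose(1, 2, 0),
                     a_pt.permute(1, 2, 0).numpy()))  # True
```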
Code
With numpy.transpose
```python
>>> a = np.arange(8).reshape(2,2,2)
>>> print(a)
[[[0 1]
  [2 3]]

 [[4 5]
  [6 7]]]
>>> np.transpose(a, [1,2,0])
array([[[0, 4],
        [1, 5]],

       [[2, 6],
        [3, 7]]])
```
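One reason this is fast: np.transpose does not copy any data, it only rearranges strides. A quick check:

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)
b = np.transpose(a, [1, 2, 0])

# b is a view: it reuses a's buffer instead of copying it.
print(np.shares_memory(a, b))  # True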
With np.einsum
```python
>>> a = np.arange(8).reshape(2,2,2)
>>> a
array([[[0, 1],
        [2, 3]],

       [[4, 5],
        [6, 7]]])
>>> np.einsum("ijk->jki", a)
array([[[0, 4],
        [1, 5]],

       [[2, 6],
        [3, 7]]])
```
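Whether np.einsum copies for a pure axis permutation depends on your NumPy version (recent versions can return a view here); np.shares_memory reveals what actually happened:

```python
import numpy as np

a = np.arange(8).reshape(2, 2, 2)
c = np.einsum("ijk->jki", a)

# True if einsum returned a view, False if it copied the data.
print(np.shares_memory(a, c))
```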
With torch.transpose
```python
>>> a = torch.arange(8).reshape(2,2,2)
>>> print(a)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
>>> torch.transpose(a, 1, 2).transpose(0, 2)
tensor([[[0, 4],
         [1, 5]],

        [[2, 6],
         [3, 7]]])
```
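The chained transpose is also stride-only: the result shares storage with `a` and is simply non-contiguous, and it is equivalent to the single permute call from the addendum. A small sketch:

```python
import torch

a = torch.arange(8).reshape(2, 2, 2)
b = torch.transpose(a, 1, 2).transpose(0, 2)

print(b.data_ptr() == a.data_ptr())        # True: same storage, no copy
print(b.is_contiguous())                   # False: only the strides changed
print(torch.equal(b, a.permute(1, 2, 0)))  # True: same result as permute
```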
With torch.einsum
```python
>>> a = torch.arange(8).reshape(2,2,2)
>>> print(a)
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])
>>> torch.einsum("ijk->jki", a)
tensor([[[0, 4],
         [1, 5]],

        [[2, 6],
         [3, 7]]])
```
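Whether torch.einsum reuses the input storage for a pure permutation is also version-dependent; comparing data pointers shows whether a copy was made:

```python
import torch

a = torch.arange(8).reshape(2, 2, 2)
d = torch.einsum("ijk->jki", a)

# True means d is a view of a; False means einsum materialized a copy.
print(d.data_ptr() == a.data_ptr())
```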
The speed is different
Run on Colaboratory:
```python
# Wrap each variant in a function so %%timeit measures the same call pattern.
def run_np_transpose(tensor):
    np.transpose(tensor, [1, 2, 0])

def run_np_einsum(tensor):
    np.einsum("ijk->jki", tensor)

def run_torch_transpose(tensor):
    torch.transpose(tensor, 1, 2).transpose(0, 2)

def run_torch_einsum(tensor):
    torch.einsum("ijk->jki", tensor)

tensor_numpy = np.arange(8).reshape(2, 2, 2)
tensor_torch = torch.arange(8).reshape(2, 2, 2)
```
np.transpose
```
%%timeit
run_np_transpose(tensor_numpy)

The slowest run took 45.15 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 915 ns per loop
```
np.einsum
```
%%timeit
run_np_einsum(tensor_numpy)

The slowest run took 54.64 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 1.07 µs per loop
```
torch.transpose
```
%%timeit
run_torch_transpose(tensor_torch)

The slowest run took 96.76 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.17 µs per loop
```
torch.einsum
```
%%timeit
run_torch_einsum(tensor_torch)

The slowest run took 122.75 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 5.57 µs per loop
```
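Since a 2×2×2 tensor is tiny and most of these calls only shuffle strides, the numbers above mostly measure fixed per-call overhead rather than data movement. Here is a sketch of the same comparison on a larger tensor (the 64×64×64 shape is an arbitrary choice), using plain timeit instead of the %%timeit magic; if a call copies data (torch.einsum may, depending on version), its time will grow with size, while pure view operations stay flat:

```python
import timeit
import numpy as np
import torch

big_np = np.arange(64**3).reshape(64, 64, 64)
big_pt = torch.arange(64**3).reshape(64, 64, 64)

cases = [
    ("np.transpose",    "np.transpose(t, [1, 2, 0])",               {"np": np, "t": big_np}),
    ("np.einsum",       "np.einsum('ijk->jki', t)",                 {"np": np, "t": big_np}),
    ("torch.transpose", "torch.transpose(t, 1, 2).transpose(0, 2)", {"torch": torch, "t": big_pt}),
    ("torch.einsum",    "torch.einsum('ijk->jki', t)",              {"torch": torch, "t": big_pt}),
]
for label, stmt, env in cases:
    # Total seconds for 10,000 calls of each variant.
    print(label, timeit.timeit(stmt, globals=env, number=10_000))
```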
Conclusion
- Fastest to slowest: np.transpose > np.einsum > torch.transpose > torch.einsum