PyTorch1でTensorを扱う際、transpose、view、reshapeはよく使われる関数だと思います。
それぞれTensorのサイズ数(次元)を変更する関数ですが、機能は少しずつ異なります。
そもそも、PyTorchのTensorとは何ぞや?という方はチュートリアルをご覧下さい。
簡単に言うと、numpyとほぼ同じで、GPUに載るか載らないかの違いです。
transpose
まず、最も基本的な関数はtransposeでしょう。
その名の通り、Tensorを転置するだけです。
torch.transpose(x, 0, 1)
はtorch.t(x)
と略することもできます。
>>> import torch
>>> x = torch.randn(4, 3)
>>> x
tensor([[ 0.2062, -1.0431, -0.5528],
[ 1.8057, 0.7966, -0.6941],
[-1.3884, -2.0070, -0.2932],
[-0.6781, -0.0142, 0.8535]])
>>> torch.transpose(x, 0, 1)
tensor([[ 0.2062, 1.8057, -1.3884, -0.6781],
[-1.0431, 0.7966, -2.0070, -0.0142],
[-0.5528, -0.6941, -0.2932, 0.8535]])
>>> torch.t(x) # これでもOK
tensor([[ 0.2062, 1.8057, -1.3884, -0.6781],
[-1.0431, 0.7966, -2.0070, -0.0142],
[-0.5528, -0.6941, -0.2932, 0.8535]])
view
viewもよく使われる関数です。
1つ目の引数に-1
を入れることで、2つ目の引数で指定した値にサイズ数を自動的に調整してくれます。
Tensorの要素数が指定したサイズ数に合わない(割り切れない)場合、エラーになります。
もちろん、サイズ数を指定することもできます。
>>> x
tensor([[ 0.2062, -1.0431, -0.5528],
[ 1.8057, 0.7966, -0.6941],
[-1.3884, -2.0070, -0.2932],
[-0.6781, -0.0142, 0.8535]])
>>> x.view(-1, 2) # サイズ数を自動的に調整してくれる
tensor([[ 0.2062, -1.0431],
[-0.5528, 1.8057],
[ 0.7966, -0.6941],
[-1.3884, -2.0070],
[-0.2932, -0.6781],
[-0.0142, 0.8535]])
>>> x.view(-1, 6) # サイズ数を自動的に調整してくれる
tensor([[ 0.2062, -1.0431, -0.5528, 1.8057, 0.7966, -0.6941],
[-1.3884, -2.0070, -0.2932, -0.6781, -0.0142, 0.8535]])
>>> x.view(-1, 5) # Tensorの要素数が指定したサイズ数に合わない
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: size '[-1 x 5]' is invalid for input with 12 elements at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/THStorage.cpp:80
>>> x.view(3, 4) # サイズ数を指定
tensor([[ 0.2062, -1.0431, -0.5528, 1.8057],
[ 0.7966, -0.6941, -1.3884, -2.0070],
[-0.2932, -0.6781, -0.0142, 0.8535]])
ただ、viewは一つだけ注意点があります。
それは、viewでサイズ数を変更するTensorの各要素は、メモリ上でも要素順に並んでいなければならないということです。
例えば、転置したTensorに対してviewでサイズ数を変更したい場合、そのまま実行すると下記のようにエラーになります。
>>> torch.t(x).view(-1, 2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensor.cpp:237
そこで、viewの前にcontiguous()
を付ければメモリ上で要素順に並び、上記のエラーを回避できます。
また、後述するreshapeを使うこともできます。
>>> torch.t(x).contiguous().view(-1, 2)
tensor([[ 0.2062, 1.8057],
[-1.3884, -0.6781],
[-1.0431, 0.7966],
[-2.0070, -0.0142],
[-0.5528, -0.6941],
[-0.2932, 0.8535]])
>>> torch.t(x).reshape(-1, 2) # これでもOK
tensor([[ 0.2062, 1.8057],
[-1.3884, -0.6781],
[-1.0431, 0.7966],
[-2.0070, -0.0142],
[-0.5528, -0.6941],
[-0.2932, 0.8535]])
reshape
reshapeはviewと書き方が少し異なるものの、ほぼ同じ働きをします。
サイズ数を自動的に調整してくれることも、指定することもできます。
>>> x
tensor([[ 0.2062, -1.0431, -0.5528],
[ 1.8057, 0.7966, -0.6941],
[-1.3884, -2.0070, -0.2932],
[-0.6781, -0.0142, 0.8535]])
>>> torch.reshape(x, (-1, 2)) # サイズ数を自動的に調整してくれる
tensor([[ 0.2062, -1.0431],
[-0.5528, 1.8057],
[ 0.7966, -0.6941],
[-1.3884, -2.0070],
[-0.2932, -0.6781],
[-0.0142, 0.8535]])
>>> torch.reshape(x, (-1, 6)) # サイズ数を自動的に調整してくれる
tensor([[ 0.2062, -1.0431, -0.5528, 1.8057, 0.7966, -0.6941],
[-1.3884, -2.0070, -0.2932, -0.6781, -0.0142, 0.8535]])
>>> torch.reshape(x, (-1, 5)) # Tensorの要素数が指定したサイズ数に合わない
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: shape '[-1, 5]' is invalid for input of size 12
>>> torch.reshape(x, (3, 4)) # サイズ数を指定
tensor([[ 0.2062, -1.0431, -0.5528, 1.8057],
[ 0.7966, -0.6941, -1.3884, -2.0070],
[-0.2932, -0.6781, -0.0142, 0.8535]])
reshapeがviewと異なるのは、サイズ数を変更するTensorの各要素がメモリ上で要素順に並んでいない場合の挙動です。
メモリ上で要素順に並んでいる場合はviewと同一の挙動になるのに対し、メモリ上で要素順に並んでいない場合は物理的なコピーを作ります。
>>> torch.reshape(torch.t(x), (-1, 2)) # 物理的なコピーを作成
tensor([[ 0.2062, 1.8057],
[-1.3884, -0.6781],
[-1.0431, 0.7966],
[-2.0070, -0.0142],
[-0.5528, -0.6941],
[-0.2932, 0.8535]])
>>> torch.t(x).reshape(-1, 2) # これでもOK
tensor([[ 0.2062, 1.8057],
[-1.3884, -0.6781],
[-1.0431, 0.7966],
[-2.0070, -0.0142],
[-0.5528, -0.6941],
[-0.2932, 0.8535]])
reshapeはPyTorchのversion 0.4
で導入された関数で、numpy.reshape
と同じ機能を持つ関数として創られたようです。
-
Pythonで書かれたディープラーニング(深層学習)のフレームワーク ↩