180
138

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【PyTorch】Tensorを操作する関数(transpose、view、reshape)

Last updated at Posted at 2019-02-25

PyTorch1でTensorを扱う際、transpose、view、reshapeはよく使われる関数だと思います。
それぞれTensorのサイズ数(次元)を変更する関数ですが、機能は少しずつ異なります。

そもそも、PyTorchのTensorとは何ぞや?という方はチュートリアルをご覧下さい。
簡単に言うと、numpyとほぼ同じで、GPUに載るか載らないかの違いです。

transpose

まず、最も基本的な関数はtransposeでしょう。
その名の通り、Tensorを転置するだけです。
torch.transpose(x, 0, 1)torch.t(x)と略することもできます。

>>> import torch
>>> x = torch.randn(4, 3)

>>> x
tensor([[ 0.2062, -1.0431, -0.5528],
        [ 1.8057,  0.7966, -0.6941],
        [-1.3884, -2.0070, -0.2932],
        [-0.6781, -0.0142,  0.8535]])

>>> torch.transpose(x, 0, 1)
tensor([[ 0.2062,  1.8057, -1.3884, -0.6781],
        [-1.0431,  0.7966, -2.0070, -0.0142],
        [-0.5528, -0.6941, -0.2932,  0.8535]])

>>> torch.t(x) # これでもOK
tensor([[ 0.2062,  1.8057, -1.3884, -0.6781],
        [-1.0431,  0.7966, -2.0070, -0.0142],
        [-0.5528, -0.6941, -0.2932,  0.8535]])

view

viewもよく使われる関数です。
1つ目の引数に-1を入れることで、2つ目の引数で指定した値にサイズ数を自動的に調整してくれます。
Tensorの要素数が指定したサイズ数に合わない(割り切れない)場合、エラーになります。
もちろん、サイズ数を指定することもできます。

>>> x
tensor([[ 0.2062, -1.0431, -0.5528],
        [ 1.8057,  0.7966, -0.6941],
        [-1.3884, -2.0070, -0.2932],
        [-0.6781, -0.0142,  0.8535]])

>>> x.view(-1, 2) # サイズ数を自動的に調整してくれる
tensor([[ 0.2062, -1.0431],
        [-0.5528,  1.8057],
        [ 0.7966, -0.6941],
        [-1.3884, -2.0070],
        [-0.2932, -0.6781],
        [-0.0142,  0.8535]])

>>> x.view(-1, 6) # サイズ数を自動的に調整してくれる
tensor([[ 0.2062, -1.0431, -0.5528,  1.8057,  0.7966, -0.6941],
        [-1.3884, -2.0070, -0.2932, -0.6781, -0.0142,  0.8535]])

>>> x.view(-1, 5) # Tensorの要素数が指定したサイズ数に合わない
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: size '[-1 x 5]' is invalid for input with 12 elements at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/THStorage.cpp:80

>>> x.view(3, 4) # サイズ数を指定
tensor([[ 0.2062, -1.0431, -0.5528,  1.8057],
        [ 0.7966, -0.6941, -1.3884, -2.0070],
        [-0.2932, -0.6781, -0.0142,  0.8535]])

ただ、viewは一つだけ注意点があります。
それは、viewでサイズ数を変更するTensorの各要素は、メモリ上でも要素順に並んでいなければならないということです。
例えば、転置したTensorに対してviewでサイズ数を変更したい場合、そのまま実行すると下記のようにエラーになります。

>>> torch.t(x).view(-1, 2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view(). at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensor.cpp:237

そこで、viewの前にcontiguous()を付ければメモリ上で要素順に並び、上記のエラーを回避できます。
また、後述するreshapeを使うこともできます。

>>> torch.t(x).contiguous().view(-1, 2)
tensor([[ 0.2062,  1.8057],
        [-1.3884, -0.6781],
        [-1.0431,  0.7966],
        [-2.0070, -0.0142],
        [-0.5528, -0.6941],
        [-0.2932,  0.8535]])

>>> torch.t(x).reshape(-1, 2) # これでもOK
tensor([[ 0.2062,  1.8057],
        [-1.3884, -0.6781],
        [-1.0431,  0.7966],
        [-2.0070, -0.0142],
        [-0.5528, -0.6941],
        [-0.2932,  0.8535]])

reshape

reshapeはviewと書き方が少し異なるものの、ほぼ同じ働きをします。
サイズ数を自動的に調整してくれることも、指定することもできます。

>>> x
tensor([[ 0.2062, -1.0431, -0.5528],
        [ 1.8057,  0.7966, -0.6941],
        [-1.3884, -2.0070, -0.2932],
        [-0.6781, -0.0142,  0.8535]])

>>> torch.reshape(x, (-1, 2)) # サイズ数を自動的に調整してくれる
tensor([[ 0.2062, -1.0431],
        [-0.5528,  1.8057],
        [ 0.7966, -0.6941],
        [-1.3884, -2.0070],
        [-0.2932, -0.6781],
        [-0.0142,  0.8535]])

>>> torch.reshape(x, (-1, 6)) # サイズ数を自動的に調整してくれる
tensor([[ 0.2062, -1.0431, -0.5528,  1.8057,  0.7966, -0.6941],
        [-1.3884, -2.0070, -0.2932, -0.6781, -0.0142,  0.8535]])

>>> torch.reshape(x, (-1, 5)) # Tensorの要素数が指定したサイズ数に合わない
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: shape '[-1, 5]' is invalid for input of size 12

>>> torch.reshape(x, (3, 4)) # サイズ数を指定
tensor([[ 0.2062, -1.0431, -0.5528,  1.8057],
        [ 0.7966, -0.6941, -1.3884, -2.0070],
        [-0.2932, -0.6781, -0.0142,  0.8535]])

reshapeがviewと異なるのは、サイズ数を変更するTensorの各要素がメモリ上で要素順に並んでいない場合の挙動です。
メモリ上で要素順に並んでいる場合はviewと同一の挙動になるのに対し、メモリ上で要素順に並んでいない場合は物理的なコピーを作ります。

>>> torch.reshape(torch.t(x), (-1, 2)) # 物理的なコピーを作成
tensor([[ 0.2062,  1.8057],
        [-1.3884, -0.6781],
        [-1.0431,  0.7966],
        [-2.0070, -0.0142],
        [-0.5528, -0.6941],
        [-0.2932,  0.8535]])

>>> torch.t(x).reshape(-1, 2) # これでもOK
tensor([[ 0.2062,  1.8057],
        [-1.3884, -0.6781],
        [-1.0431,  0.7966],
        [-2.0070, -0.0142],
        [-0.5528, -0.6941],
        [-0.2932,  0.8535]])

reshapeはPyTorchのversion 0.4で導入された関数で、numpy.reshapeと同じ機能を持つ関数として創られたようです。

  1. Pythonで書かれたディープラーニング(深層学習)のフレームワーク

180
138
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
180
138

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?