4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

pytorchのtensorをsort, 同じ順で別のtensorを並べ替え, sortしたtensorを元に戻す.

Last updated at Posted at 2019-11-14

環境

python3
pytorch 1.3.0

準備

>>> import torch
>>> a = torch.rand(3,3)
>>> a
tensor([[0.8395, 0.2678, 0.5021],
        [0.1397, 0.9839, 0.8991],
        [0.6298, 0.6101, 0.6841]])

torch.sortで並べ替える

# 指定した次元の方向にsortする
# sortする次元はdimで指定
>>> sorted, idx = torch.sort(a, dim = -1) # デフォルトはdim = -1 (-1は最後の次元. 3次元なら3と書くのと同じ)

# 戻り値は2つ
>>> sorted # 並び替えられたtensor
tensor([[0.2678, 0.5021, 0.8395],
        [0.1397, 0.8991, 0.9839],
        [0.6101, 0.6298, 0.6841]])

>>> idx # もとの行列をどういう順番に並び替えたのかの情報(後述のtorch.argsortの戻り値と同じ)
tensor([[1, 2, 0],
        [0, 2, 1],
        [1, 0, 2]])

# 例えば今回の例では, aとsortedの関係は, idxを使って
# sorted[0] = a[0,1], a[0,2], a[0,0]
# sorted[1] = a[1,0], a[1,2], a[1,1]
# sorted[2] = a[2,1], a[2,0], a[2,2]

# 降順に並び替える. descending
>>> sorted, idx = torch.sort(a, descending = True)
>>> sorted
tensor([[0.8395, 0.5021, 0.2678],
        [0.9839, 0.8991, 0.1397],
        [0.6841, 0.6298, 0.6101]]

torch.argsort: どういう風に並べ替えたのかの情報を得る

# torch.sortの2つめの戻り値と同じ
>>> idx = torch.argsort(a)
tensor([[1, 2, 0],
        [0, 2, 1],
        [1, 0, 2]])

torch.gather: argsortの結果を使って, 別のtensorを同じように並べ替える

# tensor a と同じサイズのtensor bを用意
>>> b = torch.rand_like(a) # 3*3 random

# まずtensor aを並べ替えます
>>> sorted, idx = torch.sort(a, dim = -1)

# torch.gather(input, dim, idx)
# tensor a を並べ替えたときの情報idxを用いて, aと同じ順でbを並べかえ
>>> torch.gather(b, -1, idx) # sort時と並べ替える方向は合わせてdim=-1

# つまりtorch.gatherでaをidxを用いて並べ替えた結果はsortedと同じ
>>> torch.gather(a, -1, idx) 
tensor([[0.2678, 0.5021, 0.8395],
        [0.1397, 0.8991, 0.9839],
        [0.6101, 0.6298, 0.6841]]) 

sortした配列をもとに戻す

# idxをargsortした結果(idxのidx)をgatherに突っ込む
>>> invidx = torch.argsort(idx) # idxのidx
>>> torch.gather(sorted, -1, invidx)
tensor([[0.8395, 0.2678, 0.5021],
        [0.1397, 0.9839, 0.8991],
        [0.6298, 0.6101, 0.6841]]) # aと同じになる
4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?