Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationEventAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
1
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

@dokkozo

pytorchのtensorをsort, 同じ順で別のtensorを並べ替え, sortしたtensorを元に戻す.

環境

python3
pytorch 1.3.0

準備

>>> import torch
>>> a = torch.rand(3,3)
>>> a
tensor([[0.8395, 0.2678, 0.5021],
        [0.1397, 0.9839, 0.8991],
        [0.6298, 0.6101, 0.6841]])

torch.sortで並べ替える

# 指定した次元の方向にsortする
# sortする次元はdimで指定
>>> sorted, idx = torch.sort(a, dim = -1) # デフォルトはdim = -1 (-1は最後の次元. 3次元なら3と書くのと同じ)

# 戻り値は2つ
>>> sorted # 並び替えられたtensor
tensor([[0.2678, 0.5021, 0.8395],
        [0.1397, 0.8991, 0.9839],
        [0.6101, 0.6298, 0.6841]])

>>> idx # もとの行列をどういう順番に並び替えたのかの情報(後述のtorch.argsortの戻り値と同じ)
tensor([[1, 2, 0],
        [0, 2, 1],
        [1, 0, 2]])

# 例えば今回の例では, aとsortedの関係は, idxを使って
# sorted[0] = a[0,1], a[0,2], a[0,0]
# sorted[1] = a[1,0], a[1,2], a[1,1]
# sorted[2] = a[2,1], a[2,0], a[2,2]

# 降順に並び替える. descending
>>> sorted, idx = torch.sort(a, descending = True)
>>> sorted
tensor([[0.8395, 0.5021, 0.2678],
        [0.9839, 0.8991, 0.1397],
        [0.6841, 0.6298, 0.6101]]

torch.argsort: どういう風に並べ替えたのかの情報を得る

# torch.sortの2つめの戻り値と同じ
>>> idx = torch.argsort(a)
tensor([[1, 2, 0],
        [0, 2, 1],
        [1, 0, 2]])

torch.gather: argsortの結果を使って, 別のtensorを同じように並べ替える

# tensor a と同じサイズのtensor bを用意
>>> b = torch.rand_like(a) # 3*3 random

# まずtensor aを並べ替えます
>>> sorted, idx = torch.sort(a, dim = -1)

# torch.gather(input, dim, idx)
# tensor a を並べ替えたときの情報idxを用いて, aと同じ順でbを並べかえ
>>> torch.gather(b, -1, idx) # sort時と並べ替える方向は合わせてdim=-1

# つまりtorch.gatherでaをidxを用いて並べ替えた結果はsortedと同じ
>>> torch.gather(a, -1, idx) 
tensor([[0.2678, 0.5021, 0.8395],
        [0.1397, 0.8991, 0.9839],
        [0.6101, 0.6298, 0.6841]]) 

sortした配列をもとに戻す

# idxをargsortした結果(idxのidx)をgatherに突っ込む
>>> invidx = torch.argsort(idx) # idxのidx
>>> torch.gather(sorted, -1, invidx)
tensor([[0.8395, 0.2678, 0.5021],
        [0.1397, 0.9839, 0.8991],
        [0.6298, 0.6101, 0.6841]]) # aと同じになる
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
1
Help us understand the problem. What are the problem?