環境
python3
pytorch 1.3.0
準備
>>> import torch
>>> a = torch.rand(3,3)
>>> a
tensor([[0.8395, 0.2678, 0.5021],
[0.1397, 0.9839, 0.8991],
[0.6298, 0.6101, 0.6841]])
torch.sortで並べ替える
# 指定した次元の方向にsortする
# sortする次元はdimで指定
>>> sorted, idx = torch.sort(a, dim = -1) # デフォルトはdim = -1 (-1は最後の次元. 3次元なら3と書くのと同じ)
# 戻り値は2つ
>>> sorted # 並び替えられたtensor
tensor([[0.2678, 0.5021, 0.8395],
[0.1397, 0.8991, 0.9839],
[0.6101, 0.6298, 0.6841]])
>>> idx # もとの行列をどういう順番に並び替えたのかの情報(後述のtorch.argsortの戻り値と同じ)
tensor([[1, 2, 0],
[0, 2, 1],
[1, 0, 2]])
# 例えば今回の例では, aとsortedの関係は, idxを使って
# sorted[0] = a[0,1], a[0,2], a[0,0]
# sorted[1] = a[1,0], a[1,2], a[1,1]
# sorted[2] = a[2,1], a[2,0], a[2,2]
# 降順に並び替える. descending
>>> sorted, idx = torch.sort(a, descending = True)
>>> sorted
tensor([[0.8395, 0.5021, 0.2678],
[0.9839, 0.8991, 0.1397],
[0.6841, 0.6298, 0.6101]]
torch.argsort: どういう風に並べ替えたのかの情報を得る
# torch.sortの2つめの戻り値と同じ
>>> idx = torch.argsort(a)
tensor([[1, 2, 0],
[0, 2, 1],
[1, 0, 2]])
torch.gather: argsortの結果を使って, 別のtensorを同じように並べ替える
# tensor a と同じサイズのtensor bを用意
>>> b = torch.rand_like(a) # 3*3 random
# まずtensor aを並べ替えます
>>> sorted, idx = torch.sort(a, dim = -1)
# torch.gather(input, dim, idx)
# tensor a を並べ替えたときの情報idxを用いて, aと同じ順でbを並べかえ
>>> torch.gather(b, -1, idx) # sort時と並べ替える方向は合わせてdim=-1
# つまりtorch.gatherでaをidxを用いて並べ替えた結果はsortedと同じ
>>> torch.gather(a, -1, idx)
tensor([[0.2678, 0.5021, 0.8395],
[0.1397, 0.8991, 0.9839],
[0.6101, 0.6298, 0.6841]])
sortした配列をもとに戻す
# idxをargsortした結果(idxのidx)をgatherに突っ込む
>>> invidx = torch.argsort(idx) # idxのidx
>>> torch.gather(sorted, -1, invidx)
tensor([[0.8395, 0.2678, 0.5021],
[0.1397, 0.9839, 0.8991],
[0.6298, 0.6101, 0.6841]]) # aと同じになる