個包装だんごとしてみる torch.Tensor シリーズの目次はこちら
説明
テンソル ― この記事では「だんごを袋で包んだものを入れた箱」から、特定のインデクスの袋 or だんごだけを選んで新しい箱にしたいことはあると思います。スライスやインデクシング (ブラケットにインデクス配列を指定すること) で対応できる場合はそれでよいですが、「選びたいだんごが規則的に並んでおらずインデクシングで取れない」とか、「流れてくるテンソルの階数が未確定のためブラケット表記では一括で扱えない」とかのときは、下記のようなメソッドを使うと思います。
-
torch.take(x, index)は箱内の通し番号の列 (1階のテンソル) を index に指定してだんごを注文します。結果の箱は、注文通りにだんごが並べられた 1 階のテンソルになります。 -
torch.index_select(x, dim, index)は「袋 dim をこの順に並べ替える」を指示するインデクスの列 (1階のテンソル) を index に指定してだんごを注文します。「並べ替え」といっても、すべての袋 dim を選ぶ必要はないし、重複して選んでも構いません。結果の箱は、dim 階の次元数以外は元の箱と同じサイズになり、dim 階の次元数は index に指定した列の長さになります。- dim = r - 1 のとき (r は階数) は、袋でなくだんごの並べ替えになります。
- すべての袋 dim (or だんご) に同じ並べ替えを適用したいときにしかつかえません。
-
torch.gather(x, dim, index)は元の箱の各だんごに対して「何番目の袋 dim のそれと入れ替える」を指示した元の箱と同階数のテンソルを index に指定してだんごを注文します。結果の箱は、この index と同じサイズになります。- dim = r - 1 のとき (r は階数) は、袋でなくだんごの並べ替えになります。
- index のサイズは dim 階以外の階の次元数以外は元の箱の対応する階の次元数と同じかより小さくしなければなりません (より小さくすると、単にもうだんごを選ばないことになり、元の箱よりサイズが縮みます)。dim 階は好きな長さにできます。
-
torch.index_selectと異なり、「最初の袋 dim - 1 内の袋 dim はこう並べて、その次の袋 dim - 1 内の袋 dim はこう並べて、…」といった制御が可能です。反面、各だんごをどう選ぶかをすべて指示し切る必要があるので、torch.index_selectがつかえる状況ならそちらのほうが楽なことが多いかもしれません。
-
torch.take_along_dim(x, dim, indices)は、dim=Noneとしたときはtorch.takeと同じ挙動になり、dim=0, ..., r - 1としたときはtorch.gatherと同じ挙動になります。- ただし、
torch.gatherと異なり、index の dim 階以外の階の次元数を 1 にすると、勝手にブロードキャストされます。つまり、torch.gatherにおいては次元数を 1 で止めることは「1 袋しか選ばない」を意味したのに対して、torch.take_along_dimにおいては「1 袋分の指示しか書いていないが、以下同様である」を意味します。
- ただし、
この記事での約束
このシリーズ全体の約束にあるイラストを見ていただくのが早いですが、言葉でいうと、形が (a_0, ..., a_k, ..., a_{r-1}) の r 階テンソルを、「だんごを a_{r-1} 個ずつ袋に包み、その袋を a_{r-2} 袋ずつ袋に包み、…、その袋を a_1 袋ずつ袋に包んだ a_0 袋を入れた箱」と捉え、袋を外側から袋 0 、袋 1 、...、 袋 r - 2 と呼びます。
さらにこの記事では、個々の袋 0 を順に袋 0-0, 袋 0-1, ... と呼び、個々の袋 0 の中の袋 1 を順に袋 1-0, 袋 1-1, ... と呼ぶことにします (袋 1-0 はどの袋 0 の中にもあるので、箱の中でユニークな呼び名ではないことに注意してください)。
torch.take
スクリプト
import torch
x = torch.tensor([
[ # 袋 0-0
[10, 11, 12], # 袋 0-0 の袋 1-0
[20, 21, 22], # 袋 0-0 の袋 1-1
],
[ # 袋 0-1
[30, 31, 32], # 袋 0-1 の袋 1-0
[40, 41, 42], # 袋 0-1 の袋 1-1
],
])
# 通し番号で 1, 3, 5, 8, 10, 11 番のだんごを選ぶ
y = torch.take(x, index=torch.tensor([1, 3, 5, 8, 10, 11]))
assert torch.allclose(y, torch.tensor([11, 20, 22, 32, 41, 42]))
torch.take_along_dim
スクリプト
import torch
x = torch.tensor([
[ # 袋 0-0
[10, 11, 12], # 袋 0-0 の袋 1-0
[20, 21, 22], # 袋 0-0 の袋 1-1
],
[ # 袋 0-1
[30, 31, 32], # 袋 0-1 の袋 1-0
[40, 41, 42], # 袋 0-1 の袋 1-1
],
])
y = torch.take_along_dim(x, dim=None, indices=torch.tensor([1, 3, 5, 8, 10, 11]))
assert torch.allclose(y, torch.tensor([11, 20, 22, 32, 41, 42]))
indices = torch.argsort(x, dim=0, descending=True)
assert torch.allclose(indices, torch.tensor([
[[1, 1, 1], [1, 1, 1]],
[[0, 0, 0], [0, 0, 0]],
]))
y = torch.take_along_dim(x, indices=indices, dim=0)
assert torch.allclose(y, torch.tensor([
[[30, 31, 32], [40, 41, 42]],
[[10, 11, 12], [20, 21, 22]],
]))
indices = torch.argsort(x, dim=1, descending=True)
assert torch.allclose(indices, torch.tensor([
[[1, 1, 1], [0, 0, 0]],
[[1, 1, 1], [0, 0, 0]],
]))
y = torch.take_along_dim(x, indices=indices, dim=1)
assert torch.allclose(y, torch.tensor([
[[20, 21, 22], [10, 11, 12]],
[[40, 41, 42], [30, 31, 32]],
]))
indices = torch.argsort(x, dim=2, descending=True)
assert torch.allclose(indices, torch.tensor([
[[2, 1, 0], [2, 1, 0]],
[[2, 1, 0], [2, 1, 0]],
]))
y = torch.take_along_dim(x, indices=indices, dim=2)
assert torch.allclose(y, torch.tensor([
[[12, 11, 10], [22, 21, 20]],
[[32, 31, 30], [42, 41, 40]],
]))
indices = torch.tensor([
[[2, 1, 2], [1, 1, 0]],
[[0, 0, 1], [2, 2, 2]],
])
y = torch.take_along_dim(x, indices=indices, dim=2)
assert torch.allclose(y, torch.tensor([
[[12, 11, 12], [21, 21, 20]],
[[30, 30, 31], [42, 42, 42]],
]))
スクリプト
import torch
x = torch.tensor([
[[10, 11, 12], [20, 21, 22]],
[[30, 31, 32], [40, 41, 42]],
])
indices = torch.tensor([
[[2, 1, 2], [1, 1, 0]],
[[0, 0, 1], [2, 2, 2]],
])
y = torch.take_along_dim(x, indices=indices, dim=2)
z = torch.gather(x, index=indices, dim=2)
assert torch.allclose(y, z)
assert torch.allclose(y, torch.tensor([
[[12, 11, 12], [21, 21, 20]],
[[30, 30, 31], [42, 42, 42]],
]))
indices = torch.tensor([[[2, 1, 2]]])
y = torch.take_along_dim(x, indices=indices, dim=2)
z = torch.gather(x, index=indices, dim=2)
assert not torch.allclose(y, z)
assert torch.allclose(y, torch.tensor([
[[12, 11, 12], [22, 21, 22]],
[[32, 31, 32], [42, 41, 42]],
]))
assert torch.allclose(z, torch.tensor([
[[12, 11, 12]],
]))
torch.index_select
スクリプト
import torch
x = torch.tensor([
[ # 袋 0-0
[10, 11, 12], # 袋 0-0 の袋 1-0
[20, 21, 22], # 袋 0-0 の袋 1-1
],
[ # 袋 0-1
[30, 31, 32], # 袋 0-1 の袋 1-0
[40, 41, 42], # 袋 0-1 の袋 1-1
],
])
y = torch.index_select(x, dim=0, index=torch.tensor([1]))
assert torch.allclose(y, torch.tensor([
[ # 袋 0-1
[30, 31, 32], # 袋 0-1 の袋 1-0
[40, 41, 42], # 袋 0-1 の袋 1-1
],
]))
y = torch.index_select(x, dim=1, index=torch.tensor([1, 0, 1]))
assert torch.allclose(y, torch.tensor([
[ # 袋 0-0
[20, 21, 22], # 袋 0-0 の袋 1-1
[10, 11, 12], # 袋 0-0 の袋 1-0
[20, 21, 22], # 袋 0-0 の袋 1-1
],
[ # 袋 0-1
[40, 41, 42], # 袋 0-1 の袋 1-1
[30, 31, 32], # 袋 0-1 の袋 1-0
[40, 41, 42], # 袋 0-1 の袋 1-1
],
]))
y = torch.index_select(x, dim=2, index=torch.tensor([2, 0, 0]))
assert torch.allclose(y, torch.tensor([
[ # 袋 0-0
[12, 10, 10], # 袋 0-0 の袋 1-0 の 2 番目と 0 番目と 0 番目のだんご
[22, 20, 20], # 袋 0-0 の袋 1-1 の 2 番目と 0 番目と 0 番目のだんご
],
[ # 袋 0-1
[32, 30, 30], # 袋 0-1 の袋 1-0 の 2 番目と 0 番目と 0 番目のだんご
[42, 40, 40], # 袋 0-1 の袋 1-1 の 2 番目と 0 番目と 0 番目のだんご
],
]))
torch.gather
スクリプト
import torch
v = torch.tensor([
[ # 袋 0-0
[10, 11, 12], # 袋 0-0 の袋 1-0
[20, 21, 22], # 袋 0-0 の袋 1-1
],
[ # 袋 0-1
[30, 31, 32], # 袋 0-1 の袋 1-0
[40, 41, 42], # 袋 0-1 の袋 1-1
],
])
# ----- dim=2: 「近隣のだんご」から好きなだんごを選ぶ -----
indices = torch.tensor([
[ # 袋 0-0
[0, 1, 2], # 袋 0-0 の袋 1-0: そのまま
[0, 1, 2], # 袋 0-0 の袋 1-1: そのまま
],
[ # 袋 0-1
[1, 2, 0], # 袋 0-1 の袋 1-0: 常に 1 つ後ろのだんごをとる
[2, 2, 2], # 袋 0-1 の袋 1-1: 末尾のだんごを繰り返す
],
])
assert torch.allclose(torch.gather(v, dim=2, index=indices), torch.tensor([
[ # 袋 0-0
[10, 11, 12],
[20, 21, 22],
],
[ # 袋 0-1
[31, 32, 30],
[42, 42, 42],
],
]))
# ----- dim=1: 「近隣の袋 1 の同じ位置のだんご」から好きなだんごを選ぶ -----
indices = torch.tensor([
[ # 袋 0-0
[0, 0, 0], # 袋 0-0 の袋 1-0: そのまま
[1, 1, 1], # 袋 0-0 の袋 1-1: そのまま
],
[ # 袋 0-1
[0, 1, 0], # 袋 0-1 の袋 1-0: 真ん中のだんごを袋 1-1 からとる
[0, 0, 0], # 袋 0-1 の袋 1-1: すべてのだんごを袋 1-0 からとる
],
])
assert torch.allclose(torch.gather(v, dim=1, index=indices), torch.tensor([
[ # 袋 0-0
[10, 11, 12],
[20, 21, 22],
],
[ # 袋 0-1
[30, 41, 32],
[30, 31, 32],
],
]))
# ----- dim=0: 「すべての袋 0 の同じ位置のだんご」から好きなだんごを選ぶ -----
indices = torch.tensor([
[ # 袋 0-0
[0, 0, 0], # 袋 0-0 の袋 1-0: そのまま
[0, 0, 0], # 袋 0-0 の袋 1-1: そのまま
],
[ # 袋 0-1
[1, 0, 1], # 袋 0-1 の袋 1-0: 真ん中のだんごを袋 0-0 からとる
[0, 0, 0], # 袋 0-1 の袋 1-1: すべてのだんごを袋 0-0 からとる
],
])
assert torch.allclose(torch.gather(v, dim=0, index=indices), torch.tensor([
[ # 袋 0-0
[10, 11, 12],
[20, 21, 22],
],
[ # 袋 0-1
[30, 11, 32],
[20, 21, 22],
],
]))