個包装だんごとしてみる torch.Tensor シリーズの目次はこちら
この記事での約束
0 階のテンソル x.shape == () を、箱に入っていないだんごとします (スカラー)。
1 階のテンソル x.shape == (a,) を、a 個入りだんごの箱とします。
2 階のテンソル x.shape == (b, a) を、b 袋の a 個入りだんごの箱とします。
3 階のテンソル x.shape == (c, b, a) を、c 袋の b 袋の a 個入りだんごの箱とします。
なので、2 階以上のテンソルでは、だんごは袋に包装されています。
r 階のテンソルにおいて、だんごは r - 1 重に袋に包装されています。
袋のことを外側から袋 0 、袋 1 、...、 袋 r - 2 と呼ぶことにします。
torch.cat
torch.cat([x, y], dim=k) は、「箱 x に対して、各袋 k の続きに、対応する箱 y の袋 k を追加した箱」を返す関数です (k = r - 1 のときは袋でなくだんご)。※ torch.Tensor のメソッドではありません。
箱 x の形が (a_0, ..., a_k, ..., a_{r-1})、箱 y の形が (a_0, ..., b_k, ..., a_{r-1}) だったら、結果の箱の形は (a_0, ..., a_k + b_k, ..., a_{r-1}) となります。 なお、渡す箱の形どうしは、
- k 階より小さい次元数列の完全一致が必要です (さもなくば対応がとれない)。
- k 階より大きい次元数列の完全一致が必要です (さもなくば続けて並べられない)。
- k 階の次元数だけは一致していなくても構いません。
例えば言語処理で 2 文を続けるときは torch.cat([x, y], dim=1) となります (箱 x の袋 1 たち=単語たちに箱 y の袋 1 たち=単語たちを続ける)。
なお、torch.cat は 3 箱以上渡しても処理できます (k 階以外の次元数の一致が必要です)。
import torch
# 箱 x: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
# 箱 y: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
y = torch.tensor([[[7., 8., 9.], [10., 11., 12.]]])
assert x.shape == (1, 2, 3)
assert y.shape == (1, 2, 3)
# 箱 x の袋 0 の続きに箱 y の袋 0 を続ける
z = torch.cat([x, y], dim=0)
assert z.shape == (2, 2, 3)
assert torch.allclose(z, torch.tensor([
[[1., 2., 3.], [4., 5., 6.]],
[[7., 8., 9.], [10., 11., 12.]],
]))
import torch
# 箱 x: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
# 箱 y: 袋 0 が 1 袋あって、袋 1 が 1 袋あって、袋 1 内に 3 個入りのだんご
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
y = torch.tensor([[[7., 8., 9.]]])
assert x.shape == (1, 2, 3)
assert y.shape == (1, 1, 3)
# 箱 x の袋 1 の続きに箱 y の袋 1 を続ける
z = torch.cat([x, y], dim=1)
assert z.shape == (1, 3, 3)
assert torch.allclose(z, torch.tensor([[
[1., 2., 3.], [4., 5., 6.], [7., 8., 9.],
]]))
import torch
# 箱 x: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
# 箱 y: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 2 個入りのだんご
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
y = torch.tensor([[[7., 8.], [9., 10.]]])
assert x.shape == (1, 2, 3)
assert y.shape == (1, 2, 2)
# 箱 x のだんごの続きに箱 y のだんごを続ける
z = torch.cat([x, y], dim=2)
assert z.shape == (1, 2, 5)
assert torch.allclose(z, torch.tensor([[
[1., 2., 3., 7., 8.],
[4., 5., 6., 9., 10.],
]]))
torch.stack
torch.stack([x, y], dim=k) は、「箱 x に対して、袋 k たちを包んで、その続きに箱 y の袋 k たちを包んだものを並べる」関数です (k = r - 1 のときは袋でなくだんご)。※ torch.Tensor のメソッドではありません。箱 x 由来の袋たちと箱 y 由来の袋たちを包む点が torch.cat と異なります。つまり、 r - 1 重包装 (r 階) から r 重包装 (r + 1 階) になります。
箱 x の形と箱 y の形は完全に一致していなければなりません (k 階より小さい次元数列が不一致なら対応がとれず、k 階以上の次元数が不一致なら並べられない)。
箱 x の形と箱 y の形が (a_0, ..., a_k, ..., a_{r-1}) だったら、結果の箱の形は (a_0, ..., 2, a_k, ..., a_{r-1}) となります (箱 x と箱 y の袋 k たちが包まれて並ぶので k 階の次元数が 2 になり、それ以降の袋が 1 ずつ内側にずれます)。
なお、torch.stack も 3 箱以上渡しても処理できます (次元数の完全一致が必要です)。
x.repeat
形が (c, b, a) である箱 x に x.repeat(l, m, n) を適用すると、形が (l*c, m*b, n*a) になります。つまり、a 個並んだだんごが n 回繰り返され、b 袋並んだ袋 1 が m 回繰り返され、c 袋並んだ袋 0 が l 回繰り返されます。これは参照ではなくコピーされます (だんごが増えます!)。
import torch
# 箱 x: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
x = torch.tensor([
[[1., 2., 3.], [4., 5., 6.]],
])
assert x.shape == (1, 2, 3)
# 袋 1 内の 3 個入りだんごを 2 回繰り返し、2 袋並んだ袋 1 を 4 回繰り返す
x = x.repeat(1, 4, 2)
assert x.shape == (1, 8, 6)
assert torch.allclose(x, torch.tensor([[
[1., 2., 3., 1., 2., 3.], [4., 5., 6., 4., 5., 6.],
[1., 2., 3., 1., 2., 3.], [4., 5., 6., 4., 5., 6.],
[1., 2., 3., 1., 2., 3.], [4., 5., 6., 4., 5., 6.],
[1., 2., 3., 1., 2., 3.], [4., 5., 6., 4., 5., 6.],
]]))
x.mean() で平均した袋 (だんご) を元の個数に複製したいときなどに使えます。
import torch
# 箱の中の 1 袋の 2 袋の 3 個入りだんご
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
assert x.shape == (1, 2, 3)
# 袋 1 たちを平均的なただ 1 つの袋にする
x = x.mean(dim=1, keepdim=True)
assert x.shape == (1, 1, 3)
assert torch.allclose(x, torch.tensor([[[2.5, 3.5, 4.5]]]))
# 袋 1 たちを平均的なただ 1 つの袋にして、元々の袋 1 の個数コピーする
x = x.repeat(1, 2, 1)
assert x.shape == (1, 2, 3)
assert torch.allclose(x, torch.tensor([[[2.5, 3.5, 4.5], [2.5, 3.5, 4.5]]]))
x.unfold
x.unfold(k, size, step) は、
- 各だんごを、そのだんごを先頭にした「だんごが size 個入りの袋」にしてしまいます。2 個目以降のだんごは「次の袋 k」、「次の次の袋 k」から連れてきます。ただし、コピーするわけではないので、見せかけ (view) のだんごです。本当のだんごは増えません。
- そうやってだんごを見せかけのだんご袋にしていくとき、次の袋 k に進むときは step 進みます (つまり、step=1 なら普通に次の袋 k に進み、step=2 なら 1 つとばします)。かつ、袋 k の最後の size - 1 袋は無視します (ここらのだんごは既に連れてこられたので)。
結果、だんご箱のサイズは (a_0, a_1, ..., a_k, ..., a_{r-1}) から、(a_0, a_1, ..., floor((a_k - size)/step) + 1, ..., a_{r-1}, size) になります。
- 例えば自然言語処理
(B, L, D)で.unfold(1, 2, 1)すると(B, N, D, 2)となり、これは各バッチ内の文章から取れた N 個の 2-gram です。
import torch
# 箱 x: 袋 0 が 5 袋あって、袋 1 が 5 袋あって、袋 1 内に 4 個入りのだんご
x = torch.tensor([
[[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [13, 14, 15, 16], [17, 18, 19, 20]],
[[21, 22, 23, 24], [25, 26, 27, 28], [29, 30, 31, 32], [33, 34, 35, 36], [37, 38, 39, 40]],
[[41, 42, 43, 44], [45, 46, 47, 48], [49, 50, 51, 52], [53, 54, 55, 56], [57, 58, 59, 60]],
[[61, 62, 63, 64], [65, 66, 67, 68], [69, 70, 71, 72], [73, 74, 75, 76], [77, 78, 79, 80]],
[[81, 82, 83, 84], [85, 86, 87, 88], [89, 90, 91, 92], [93, 94, 95, 96], [97, 98, 99, 0]],
])
assert x.shape == (5, 5, 4)
# dimension=0, size=3, step=2
# 各だんごに「次の袋 0 の同じ位置のだんご」「次の次の袋 0 の同じ位置のだんご」が付いてきます
# そして「次の袋 0」に行くとき 2 つ先に進みます
x0 = x.unfold(0, size=3, step=2)
assert x0.shape == (2, 5, 4, 3)
assert torch.allclose(x0[0, 0, 0], torch.tensor([ 1, 21, 41])) # だんご 1 におまけが付いてきます
assert torch.allclose(x0[0, 0, 1], torch.tensor([ 2, 22, 42])) # だんご 2 におまけが付いてきます
assert torch.allclose(x0[0, 0, 2], torch.tensor([ 3, 23, 43])) # だんご 3 におまけが付いてきます
assert torch.allclose(x0[0, 0, 3], torch.tensor([ 4, 24, 44])) # だんご 4 におまけが付いてきます
assert torch.allclose(x0[0, 1, 0], torch.tensor([ 5, 25, 45])) # 次の袋 1 に進んでも付いてきます
assert torch.allclose(x0[0, 1, 1], torch.tensor([ 6, 26, 46]))
assert torch.allclose(x0[1, 0, 0], torch.tensor([41, 61, 81])) # 次の袋 0 に進むとき step 進みます
assert torch.allclose(x0[1, 0, 1], torch.tensor([42, 62, 82]))
assert torch.allclose(x0[1, 4, 3], torch.tensor([60, 80, 0]))
# dimension=1, size=3, step=2
# 各だんごに「次の袋 1 の同じ位置のだんご」「次の次の袋 1 の同じ位置のだんご」が付いてきます
# そして「次の袋 1」に行くとき 2 つ先に進みます
x1 = x.unfold(1, size=3, step=2)
assert x1.shape == (5, 2, 4, 3)
assert torch.allclose(x1[0, 0, 0], torch.tensor([ 1, 5, 9])) # だんご 1 におまけが付いてきます
assert torch.allclose(x1[0, 0, 1], torch.tensor([ 2, 6, 10])) # だんご 2 におまけが付いてきます
assert torch.allclose(x1[0, 0, 2], torch.tensor([ 3, 7, 11])) # だんご 3 におまけが付いてきます
assert torch.allclose(x1[0, 0, 3], torch.tensor([ 4, 8, 12])) # だんご 4 におまけが付いてきます
assert torch.allclose(x1[0, 1, 0], torch.tensor([ 9, 13, 17])) # 次の袋 1 に進むとき step 進みます
assert torch.allclose(x1[0, 1, 1], torch.tensor([10, 14, 18]))
assert torch.allclose(x1[1, 0, 0], torch.tensor([21, 25, 29]))
assert torch.allclose(x1[1, 0, 1], torch.tensor([22, 26, 30]))
assert torch.allclose(x1[4, 1, 3], torch.tensor([92, 96, 0]))
# dimension=2, size=3, step=2
# 各だんごに「次のだんご」「次の次のだんご」が付いてきます
# そして「次の袋だんご」に行くとき 2 つ先に進みます (が、4 個入りだんごで 2 つ先に進んだら
# 次の次のだんごまでとれないので、3 階の次元数は 1 だけでもう終わり、次の袋 1 に進みます)
x2 = x.unfold(2, size=3, step=2)
assert x2.shape == (5, 5, 1, 3)
assert torch.allclose(x2, torch.tensor([
[[[ 1, 2, 3]], [[ 5, 6, 7]], [[ 9, 10, 11]], [[13, 14, 15]], [[17, 18, 19]]],
[[[21, 22, 23]], [[25, 26, 27]], [[29, 30, 31]], [[33, 34, 35]], [[37, 38, 39]]],
[[[41, 42, 43]], [[45, 46, 47]], [[49, 50, 51]], [[53, 54, 55]], [[57, 58, 59]]],
[[[61, 62, 63]], [[65, 66, 67]], [[69, 70, 71]], [[73, 74, 75]], [[77, 78, 79]]],
[[[81, 82, 83]], [[85, 86, 87]], [[89, 90, 91]], [[93, 94, 95]], [[97, 98, 99]]],
]))



