0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

個包装だんごとしてみる torch.Tensor ― #2. torch.cat と torch.stack と x.repeat と x.unfold

0
Last updated at Posted at 2026-04-15

個包装だんごとしてみる torch.Tensor シリーズの目次はこちら

この記事での約束

0 階のテンソル x.shape == () を、箱に入っていないだんごとします (スカラー)。
1 階のテンソル x.shape == (a,) を、a 個入りだんごの箱とします。
2 階のテンソル x.shape == (b, a) を、b 袋の a 個入りだんごの箱とします。
3 階のテンソル x.shape == (c, b, a) を、c 袋の b 袋の a 個入りだんごの箱とします。

なので、2 階以上のテンソルでは、だんごは袋に包装されています。
r 階のテンソルにおいて、だんごは r - 1 重に袋に包装されています。
袋のことを外側から袋 0 、袋 1 、...、 袋 r - 2 と呼ぶことにします。

自然言語処理 (B, L, D) なら、箱はバッチ、袋 0 は文章、袋 1 は単語、だんごは単語分散ベクトルの各成分値です。画像処理 (B, C, H, W) なら、箱はバッチ、袋 0 は画像、袋 1 はある色チャネル、袋 2 は画像内のある行、だんごはピクセル値です。

image.png

torch.cat

torch.cat([x, y], dim=k) は、「箱 x に対して、各袋 k の続きに、対応する箱 y の袋 k を追加した箱」を返す関数です (k = r - 1 のときは袋でなくだんご)。※ torch.Tensor のメソッドではありません。

箱 x の形が (a_0, ..., a_k, ..., a_{r-1})、箱 y の形が (a_0, ..., b_k, ..., a_{r-1}) だったら、結果の箱の形は (a_0, ..., a_k + b_k, ..., a_{r-1}) となります。 なお、渡す箱の形どうしは、

  • k 階より小さい次元数列の完全一致が必要です (さもなくば対応がとれない)
  • k 階より大きい次元数列の完全一致が必要です (さもなくば続けて並べられない)
  • k 階の次元数だけは一致していなくても構いません。

例えば言語処理で 2 文を続けるときは torch.cat([x, y], dim=1) となります (箱 x の袋 1 たち=単語たちに箱 y の袋 1 たち=単語たちを続ける)。

なお、torch.cat は 3 箱以上渡しても処理できます (k 階以外の次元数の一致が必要です)。

image.png

dim=0
import torch

# 箱 x: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
# 箱 y: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
y = torch.tensor([[[7., 8., 9.], [10., 11., 12.]]])
assert x.shape == (1, 2, 3)
assert y.shape == (1, 2, 3)

# 箱 x の袋 0 の続きに箱 y の袋 0 を続ける
z = torch.cat([x, y], dim=0)
assert z.shape == (2, 2, 3)
assert torch.allclose(z, torch.tensor([
    [[1., 2., 3.], [4., 5., 6.]],
    [[7., 8., 9.], [10., 11., 12.]],
]))
dim=1
import torch

# 箱 x: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
# 箱 y: 袋 0 が 1 袋あって、袋 1 が 1 袋あって、袋 1 内に 3 個入りのだんご
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
y = torch.tensor([[[7., 8., 9.]]])
assert x.shape == (1, 2, 3)
assert y.shape == (1, 1, 3)

# 箱 x の袋 1 の続きに箱 y の袋 1 を続ける
z = torch.cat([x, y], dim=1)
assert z.shape == (1, 3, 3)
assert torch.allclose(z, torch.tensor([[
    [1., 2., 3.], [4., 5., 6.], [7., 8., 9.],
]]))
dim=2
import torch

# 箱 x: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
# 箱 y: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 2 個入りのだんご
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
y = torch.tensor([[[7., 8.], [9., 10.]]])
assert x.shape == (1, 2, 3)
assert y.shape == (1, 2, 2)

# 箱 x のだんごの続きに箱 y のだんごを続ける
z = torch.cat([x, y], dim=2)
assert z.shape == (1, 2, 5)
assert torch.allclose(z, torch.tensor([[
    [1., 2., 3., 7., 8.],
    [4., 5., 6., 9., 10.],
]]))

torch.stack

torch.stack([x, y], dim=k) は、「箱 x に対して、袋 k たちを包んで、その続きに箱 y の袋 k たちを包んだものを並べる」関数です (k = r - 1 のときは袋でなくだんご)。※ torch.Tensor のメソッドではありません。箱 x 由来の袋たちと箱 y 由来の袋たちを包む点が torch.cat と異なります。つまり、 r - 1 重包装 (r 階) から r 重包装 (r + 1 階) になります。

箱 x の形と箱 y の形は完全に一致していなければなりません (k 階より小さい次元数列が不一致なら対応がとれず、k 階以上の次元数が不一致なら並べられない)

箱 x の形と箱 y の形が (a_0, ..., a_k, ..., a_{r-1}) だったら、結果の箱の形は (a_0, ..., 2, a_k, ..., a_{r-1}) となります (箱 x と箱 y の袋 k たちが包まれて並ぶので k 階の次元数が 2 になり、それ以降の袋が 1 ずつ内側にずれます)

なお、torch.stack も 3 箱以上渡しても処理できます (次元数の完全一致が必要です)。

image.png

x.repeat

形が (c, b, a) である箱 xx.repeat(l, m, n) を適用すると、形が (l*c, m*b, n*a) になります。つまり、a 個並んだだんごが n 回繰り返され、b 袋並んだ袋 1 が m 回繰り返され、c 袋並んだ袋 0 が l 回繰り返されます。これは参照ではなくコピーされます (だんごが増えます!)。

import torch

# 箱 x: 袋 0 が 1 袋あって、袋 1 が 2 袋あって、袋 1 内に 3 個入りのだんご
x = torch.tensor([
    [[1., 2., 3.], [4., 5., 6.]],
])
assert x.shape == (1, 2, 3)

# 袋 1 内の 3 個入りだんごを 2 回繰り返し、2 袋並んだ袋 1 を 4 回繰り返す
x = x.repeat(1, 4, 2)
assert x.shape == (1, 8, 6)
assert torch.allclose(x, torch.tensor([[
    [1., 2., 3., 1., 2., 3.], [4., 5., 6., 4., 5., 6.],
    [1., 2., 3., 1., 2., 3.], [4., 5., 6., 4., 5., 6.],
    [1., 2., 3., 1., 2., 3.], [4., 5., 6., 4., 5., 6.],
    [1., 2., 3., 1., 2., 3.], [4., 5., 6., 4., 5., 6.],
]]))

x.mean() で平均した袋 (だんご) を元の個数に複製したいときなどに使えます。

import torch

# 箱の中の 1 袋の 2 袋の 3 個入りだんご
x = torch.tensor([[[1., 2., 3.], [4., 5., 6.]]])
assert x.shape == (1, 2, 3)

# 袋 1 たちを平均的なただ 1 つの袋にする
x = x.mean(dim=1, keepdim=True)
assert x.shape == (1, 1, 3)
assert torch.allclose(x, torch.tensor([[[2.5, 3.5, 4.5]]]))

# 袋 1 たちを平均的なただ 1 つの袋にして、元々の袋 1 の個数コピーする
x = x.repeat(1, 2, 1)
assert x.shape == (1, 2, 3)
assert torch.allclose(x, torch.tensor([[[2.5, 3.5, 4.5], [2.5, 3.5, 4.5]]]))

x.unfold

x.unfold(k, size, step) は、

  • 各だんごを、そのだんごを先頭にした「だんごが size 個入りの袋」にしてしまいます。2 個目以降のだんごは「次の袋 k」、「次の次の袋 k」から連れてきます。ただし、コピーするわけではないので、見せかけ (view) のだんごです。本当のだんごは増えません。
  • そうやってだんごを見せかけのだんご袋にしていくとき、次の袋 k に進むときは step 進みます (つまり、step=1 なら普通に次の袋 k に進み、step=2 なら 1 つとばします)。かつ、袋 k の最後の size - 1 袋は無視します (ここらのだんごは既に連れてこられたので)。

結果、だんご箱のサイズは (a_0, a_1, ..., a_k, ..., a_{r-1}) から、(a_0, a_1, ..., floor((a_k - size)/step) + 1, ..., a_{r-1}, size) になります。

  • 例えば自然言語処理 (B, L, D).unfold(1, 2, 1) すると (B, N, D, 2) となり、これは各バッチ内の文章から取れた N 個の 2-gram です。

image.png

import torch

# 箱 x: 袋 0 が 5 袋あって、袋 1 が 5 袋あって、袋 1 内に 4 個入りのだんご
x = torch.tensor([
    [[ 1,  2,  3,  4], [ 5,  6,  7,  8], [ 9, 10, 11, 12], [13, 14, 15, 16], [17, 18, 19, 20]],
    [[21, 22, 23, 24], [25, 26, 27, 28], [29, 30, 31, 32], [33, 34, 35, 36], [37, 38, 39, 40]],
    [[41, 42, 43, 44], [45, 46, 47, 48], [49, 50, 51, 52], [53, 54, 55, 56], [57, 58, 59, 60]],
    [[61, 62, 63, 64], [65, 66, 67, 68], [69, 70, 71, 72], [73, 74, 75, 76], [77, 78, 79, 80]],
    [[81, 82, 83, 84], [85, 86, 87, 88], [89, 90, 91, 92], [93, 94, 95, 96], [97, 98, 99,  0]],
])
assert x.shape == (5, 5, 4)

# dimension=0, size=3, step=2
# 各だんごに「次の袋 0 の同じ位置のだんご」「次の次の袋 0 の同じ位置のだんご」が付いてきます
# そして「次の袋 0」に行くとき 2 つ先に進みます
x0 = x.unfold(0, size=3, step=2)
assert x0.shape == (2, 5, 4, 3)
assert torch.allclose(x0[0, 0, 0], torch.tensor([ 1, 21, 41]))  # だんご 1 におまけが付いてきます
assert torch.allclose(x0[0, 0, 1], torch.tensor([ 2, 22, 42]))  # だんご 2 におまけが付いてきます
assert torch.allclose(x0[0, 0, 2], torch.tensor([ 3, 23, 43]))  # だんご 3 におまけが付いてきます
assert torch.allclose(x0[0, 0, 3], torch.tensor([ 4, 24, 44]))  # だんご 4 におまけが付いてきます
assert torch.allclose(x0[0, 1, 0], torch.tensor([ 5, 25, 45]))  # 次の袋 1 に進んでも付いてきます
assert torch.allclose(x0[0, 1, 1], torch.tensor([ 6, 26, 46]))
assert torch.allclose(x0[1, 0, 0], torch.tensor([41, 61, 81]))  # 次の袋 0 に進むとき step 進みます
assert torch.allclose(x0[1, 0, 1], torch.tensor([42, 62, 82]))
assert torch.allclose(x0[1, 4, 3], torch.tensor([60, 80,  0]))

# dimension=1, size=3, step=2
# 各だんごに「次の袋 1 の同じ位置のだんご」「次の次の袋 1 の同じ位置のだんご」が付いてきます
# そして「次の袋 1」に行くとき 2 つ先に進みます
x1 = x.unfold(1, size=3, step=2)
assert x1.shape == (5, 2, 4, 3)
assert torch.allclose(x1[0, 0, 0], torch.tensor([ 1,  5,  9]))  # だんご 1 におまけが付いてきます
assert torch.allclose(x1[0, 0, 1], torch.tensor([ 2,  6, 10]))  # だんご 2 におまけが付いてきます
assert torch.allclose(x1[0, 0, 2], torch.tensor([ 3,  7, 11]))  # だんご 3 におまけが付いてきます
assert torch.allclose(x1[0, 0, 3], torch.tensor([ 4,  8, 12]))  # だんご 4 におまけが付いてきます
assert torch.allclose(x1[0, 1, 0], torch.tensor([ 9, 13, 17]))  # 次の袋 1 に進むとき step 進みます
assert torch.allclose(x1[0, 1, 1], torch.tensor([10, 14, 18]))
assert torch.allclose(x1[1, 0, 0], torch.tensor([21, 25, 29]))
assert torch.allclose(x1[1, 0, 1], torch.tensor([22, 26, 30]))
assert torch.allclose(x1[4, 1, 3], torch.tensor([92, 96,  0]))

# dimension=2, size=3, step=2
# 各だんごに「次のだんご」「次の次のだんご」が付いてきます
# そして「次の袋だんご」に行くとき 2 つ先に進みます (が、4 個入りだんごで 2 つ先に進んだら
# 次の次のだんごまでとれないので、3 階の次元数は 1 だけでもう終わり、次の袋 1 に進みます)
x2 = x.unfold(2, size=3, step=2)
assert x2.shape == (5, 5, 1, 3)
assert torch.allclose(x2, torch.tensor([
    [[[ 1,  2,  3]], [[ 5,  6,  7]], [[ 9, 10, 11]], [[13, 14, 15]], [[17, 18, 19]]],
    [[[21, 22, 23]], [[25, 26, 27]], [[29, 30, 31]], [[33, 34, 35]], [[37, 38, 39]]],
    [[[41, 42, 43]], [[45, 46, 47]], [[49, 50, 51]], [[53, 54, 55]], [[57, 58, 59]]],
    [[[61, 62, 63]], [[65, 66, 67]], [[69, 70, 71]], [[73, 74, 75]], [[77, 78, 79]]],
    [[[81, 82, 83]], [[85, 86, 87]], [[89, 90, 91]], [[93, 94, 95]], [[97, 98, 99]]],
]))
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?