TL; DR
- nn.Conv[12]d は、実はたたみこみではない
- nn.Conv[12]d は相互相関
- ConvTranspose[12]d の演算が本来のたたみこみに相当
- "Transposed convolution" は「転置たたみこみ」と呼ばれているが、実は転置ではなく反転
本来のたたみこみ
離散時間信号 $h[n]$ $(n=0,\ldots,K-1)$ と $x[n]$ のたたみこみは
$$(h*x)[n]=\sum_{k=0}^{K-1}h[k]x[n-k]$$
で定義される1。
例として、信号 $x[n]=[\underline{0}, 0, 1, 0, 0, 0, 2, 0, 0, 0]$ をフィルタ $h[n]=[\underline{1}, 2, 3]$ に通すことを考えよう。このときのフィルタの出力は $h[n]$ と $x[n]$ のたたみこみ $(h*x)[n]=[\underline{0}, 0, 1, 2, 3, 0, 2, 4, 6, 0, 0, 0]$ となる。
>>> import numpy as np
>>> h = np.array([1., 2., 3.])
>>> x = np.array([0., 0., 1., 0., 0., 0., 2., 0., 0., 0])
>>> print(np.convolve(h, x))
[0. 0. 1. 2. 3. 0. 2. 4. 6. 0. 0. 0.]
たたみこみは、コピー&ペーストだと解釈できる。上の例では、フィルタのパターン 1, 2, 3 を信号の存在する位置にペーストすることを繰り返している。出力の前半に 1, 2, 3 が現れ、後半に 2, 4, 6 が現れているのは、信号の大きさに合わせてそれぞれ1倍、2倍されているからである。
ニューラルネットワークにおける"たたみこみ"
たたみこみニューラルネットワーク (CNN) では、次の演算を convolution と呼んでいる。2
$$\text{out}(N_i, C_{\text{out}}) = \text{bias}(C_{\text{out}}) +
\sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}}, k)
\star \text{input}(N_i, k)$$
簡単のため、バッチサイズが1で、入力も出力も1チャンネルの場合を考えると
$$\text{out} = \text{bias} + \text{weight}
\star \text{input}$$
となる。 $\star$ は相互相関演算子であり、次式で定義される。
$$(h\star x)[n]=\sum_{k=0}^{K-1}h[k]x[n+k]$$
これは、上で示した本来のたたみこみとは異なる。具体的には、信号 $x$ に対してフィルタが作用する向きが前後逆になっている。
>>> print(np.correlate(x, h))
[3. 2. 1. 0. 6. 4. 2. 0.]
相互相関の式において $h'[n]=h[K-1-n]$ として前後反転したフィルタを作り、$h[n]$ と置き換えると
$$(h'\star x)[n-K+1]=\sum_{k=0}^{K-1}h'[k]x[n-k]$$
となり、本来のたたみこみの結果を $K-1$ だけ左にずらしたものと一致する。
PyTorchにおけるconvolution
CNN の convolution が実際には相互相関であることを確かめよう。nn.Conv1d によって「たたみこみ層」を作り、フィルタの係数を無理矢理上の例に合わせる。バイアスもゼロにしておく。
>>> import torch
>>> h_ = torch.Tensor(h).view(1, 1, -1)
>>> x_ = torch.Tensor(x).view(1, 1, -1)
>>> conv = torch.nn.Conv1d(1, 1, kernel_size=3)
>>> conv.weight.data = h_
>>> conv.bias.data = torch.Tensor([0.])
>>> conv(x_)
tensor([[[3., 2., 1., 0., 6., 4., 2., 0.]]], grad_fn=<ConvolutionBackward0>)
NumPyで相互相関を求めた結果と一致した。
Transposed convolution との関係
ニューラルネットワークでは、「転置たたみこみ (transposed convolution)」と呼ばれる演算が登場する。Transposed convolution の要点は
- 0埋めによって入力が stride 分だけ拡大される
- カーネルの各軸を反転 (flip) したフィルタとの convolution (実際には相互相関) が演算される
と整理できる3。
Transposed convolution は専らアップサンプリングのために用いられるので、通常strideは2以上。stride=1 の場合は単にカーネルの各軸が反転した convolution と同等である。したがって、上の議論から明らかなように、これは本来のたたみこみに相当する。
>>> tconv = torch.nn.ConvTranspose1d(1, 1, kernel_size=3)
>>> tconv.weight.data = h_
>>> tconv.bias.data = torch.Tensor([0.])
>>> tconv(x_)
tensor([[[0., 0., 1., 2., 3., 0., 2., 4., 6., 0., 0., 0.]]],
grad_fn=<ConvolutionBackward0>)
2次元の場合
次元数が変わっても話は同じ。本来のたたみこみが
>>> from scipy.signal import convolve2d
>>> h = np.array([[1., 2., 3.], [5., 7., 11.], [13., 17., 19.]])
>>> x = torch.rand((2, 2))
>>> convolve2d(h, x)
array([[ 0.24204093, 1.03771567, 1.83339041, 1.66090143],
[ 1.96423751, 6.85992354, 10.57879001, 8.75817871],
[ 6.91669637, 21.03717667, 28.53072971, 20.30246735],
[ 9.80242705, 24.38078797, 29.44646275, 16.89864314]])
のとき、上下左右を反転させたフィルタは
>>> ht = np.flip(h, [0, 1]).copy()
>>> ht
array([[19., 17., 13.],
[11., 7., 5.],
[ 3., 2., 1.]])
となり、このフィルタとの 2d convolution は
>>> conv2 = torch.nn.Conv2d(1, 1, kernel_size=3, padding=2)
>>> conv2.weight.data = torch.Tensor(ht).view(1, 1, 3, 3)
>>> conv2.bias.data = torch.Tensor([0.])
>>> conv2(x[None, None, ...])
tensor([[[[ 0.2420, 1.0377, 1.8334, 1.6609],
[ 1.9642, 6.8599, 10.5788, 8.7582],
[ 6.9167, 21.0372, 28.5307, 20.3025],
[ 9.8024, 24.3808, 29.4465, 16.8986]]]],
grad_fn=<ConvolutionBackward0>)
で結果は一致する。さらに、反転する前の元のフィルタとの transposed convolution は
>>> tconv2 = torch.nn.ConvTranspose2d(1, 1, kernel_size=3)
>>> tconv2.weight.data = torch.Tensor(h).view(1, 1, 3, 3)
>>> tconv2.bias.data = torch.Tensor([0.])
>>> tconv2(x[None, None, ...])
tensor([[[[ 0.2420, 1.0377, 1.8334, 1.6609],
[ 1.9642, 6.8599, 10.5788, 8.7582],
[ 6.9167, 21.0372, 28.5307, 20.3025],
[ 9.8024, 24.3808, 29.4465, 16.8986]]]],
grad_fn=<ConvolutionBackward0>)
で、やはり一致する。
ところで、元のフィルタの転置は
>>> h.T
array([[ 1., 5., 13.],
[ 2., 7., 17.],
[ 3., 11., 19.]])
であって、上下左右反転したフィルタ ht とは異なる。Transposed convolution で実行されるのは反転であって、転置ではないのである。
結論
まぎらわしい。
-
樋口(監),川又・阿部・八巻(著)「MATLAB対応ディジタル信号処理 第2版」森北出版 (2021) 式(5.11), 式(9.1) ↩
-
これらは convolution の逆伝播計算を convolution 演算によって行うために必要な処理。Transposed convolution のルーツは convolution の逆伝播計算にある。参考: https://blog.cosnomi.com/posts/transposed-conv/) ↩