0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

CNNのconvolutionは、たたみこみではない

Last updated at Posted at 2023-12-15

TL; DR

  • nn.Conv[12]d は、実はたたみこみではない
  • nn.Conv[12]d は相互相関
  • ConvTranspose[12]d の演算が本来のたたみこみに相当
  • "Transposed convolution" は「転置たたみこみ」と呼ばれているが、実は転置ではなく反転

本来のたたみこみ

離散時間信号 $h[n]$ $(n=0,\ldots,K-1)$ と $x[n]$ のたたみこみは
$$(h*x)[n]=\sum_{k=0}^{K-1}h[k]x[n-k]$$
で定義される1

例として、信号 $x[n]=[\underline{0}, 0, 1, 0, 0, 0, 2, 0, 0, 0]$ をフィルタ $h[n]=[\underline{1}, 2, 3]$ に通すことを考えよう。このときのフィルタの出力は $h[n]$ と $x[n]$ のたたみこみ $(h*x)[n]=[\underline{0}, 0, 1, 2, 3, 0, 2, 4, 6, 0, 0, 0]$ となる。

>>> import numpy as np
>>> h = np.array([1., 2., 3.])
>>> x = np.array([0., 0., 1., 0., 0., 0., 2., 0., 0., 0])
>>> print(np.convolve(h, x))
[0. 0. 1. 2. 3. 0. 2. 4. 6. 0. 0. 0.]

たたみこみは、コピー&ペーストだと解釈できる。上の例では、フィルタのパターン 1, 2, 3 を信号の存在する位置にペーストすることを繰り返している。出力の前半に 1, 2, 3 が現れ、後半に 2, 4, 6 が現れているのは、信号の大きさに合わせてそれぞれ1倍、2倍されているからである。

ニューラルネットワークにおける"たたみこみ"

たたみこみニューラルネットワーク (CNN) では、次の演算を convolution と呼んでいる。2
$$\text{out}(N_i, C_{\text{out}}) = \text{bias}(C_{\text{out}}) +
\sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}}, k)
\star \text{input}(N_i, k)$$
簡単のため、バッチサイズが1で、入力も出力も1チャンネルの場合を考えると
$$\text{out} = \text{bias} + \text{weight}
\star \text{input}$$
となる。 $\star$ は相互相関演算子であり、次式で定義される。
$$(h\star x)[n]=\sum_{k=0}^{K-1}h[k]x[n+k]$$
これは、上で示した本来のたたみこみとは異なる。具体的には、信号 $x$ に対してフィルタが作用する向きが前後逆になっている。

>>> print(np.correlate(x, h))
[3. 2. 1. 0. 6. 4. 2. 0.]

相互相関の式において $h'[n]=h[K-1-n]$ として前後反転したフィルタを作り、$h[n]$ と置き換えると
$$(h'\star x)[n-K+1]=\sum_{k=0}^{K-1}h'[k]x[n-k]$$
となり、本来のたたみこみの結果を $K-1$ だけ左にずらしたものと一致する。

PyTorchにおけるconvolution

CNN の convolution が実際には相互相関であることを確かめよう。nn.Conv1d によって「たたみこみ層」を作り、フィルタの係数を無理矢理上の例に合わせる。バイアスもゼロにしておく。

>>> import torch
>>> h_ = torch.Tensor(h).view(1, 1, -1)
>>> x_ = torch.Tensor(x).view(1, 1, -1)
>>> conv = torch.nn.Conv1d(1, 1, kernel_size=3)
>>> conv.weight.data = h_
>>> conv.bias.data = torch.Tensor([0.])
>>> conv(x_)
tensor([[[3., 2., 1., 0., 6., 4., 2., 0.]]], grad_fn=<ConvolutionBackward0>)

NumPyで相互相関を求めた結果と一致した。

Transposed convolution との関係

ニューラルネットワークでは、「転置たたみこみ (transposed convolution)」と呼ばれる演算が登場する。Transposed convolution の要点は

  • 0埋めによって入力が stride 分だけ拡大される
  • カーネルの各軸を反転 (flip) したフィルタとの convolution (実際には相互相関) が演算される

と整理できる3

Transposed convolution は専らアップサンプリングのために用いられるので、通常strideは2以上。stride=1 の場合は単にカーネルの各軸が反転した convolution と同等である。したがって、上の議論から明らかなように、これは本来のたたみこみに相当する。

>>> tconv = torch.nn.ConvTranspose1d(1, 1, kernel_size=3)
>>> tconv.weight.data = h_
>>> tconv.bias.data = torch.Tensor([0.])
>>> tconv(x_)
tensor([[[0., 0., 1., 2., 3., 0., 2., 4., 6., 0., 0., 0.]]],
       grad_fn=<ConvolutionBackward0>)

2次元の場合

次元数が変わっても話は同じ。本来のたたみこみが

>>> from scipy.signal import convolve2d
>>> h = np.array([[1., 2., 3.], [5., 7., 11.], [13., 17., 19.]])
>>> x = torch.rand((2, 2))
>>> convolve2d(h, x)
array([[ 0.24204093,  1.03771567,  1.83339041,  1.66090143],
       [ 1.96423751,  6.85992354, 10.57879001,  8.75817871],
       [ 6.91669637, 21.03717667, 28.53072971, 20.30246735],
       [ 9.80242705, 24.38078797, 29.44646275, 16.89864314]])

のとき、上下左右を反転させたフィルタは

>>> ht = np.flip(h, [0, 1]).copy()
>>> ht
array([[19., 17., 13.],
       [11.,  7.,  5.],
       [ 3.,  2.,  1.]])

となり、このフィルタとの 2d convolution は

>>> conv2 = torch.nn.Conv2d(1, 1, kernel_size=3, padding=2)
>>> conv2.weight.data = torch.Tensor(ht).view(1, 1, 3, 3)
>>> conv2.bias.data = torch.Tensor([0.])
>>> conv2(x[None, None, ...])
tensor([[[[ 0.2420,  1.0377,  1.8334,  1.6609],
          [ 1.9642,  6.8599, 10.5788,  8.7582],
          [ 6.9167, 21.0372, 28.5307, 20.3025],
          [ 9.8024, 24.3808, 29.4465, 16.8986]]]],
       grad_fn=<ConvolutionBackward0>)

で結果は一致する。さらに、反転する前の元のフィルタとの transposed convolution は

>>> tconv2 = torch.nn.ConvTranspose2d(1, 1, kernel_size=3)
>>> tconv2.weight.data = torch.Tensor(h).view(1, 1, 3, 3)
>>> tconv2.bias.data = torch.Tensor([0.])
>>> tconv2(x[None, None, ...])
tensor([[[[ 0.2420,  1.0377,  1.8334,  1.6609],
          [ 1.9642,  6.8599, 10.5788,  8.7582],
          [ 6.9167, 21.0372, 28.5307, 20.3025],
          [ 9.8024, 24.3808, 29.4465, 16.8986]]]],
       grad_fn=<ConvolutionBackward0>)

で、やはり一致する。

ところで、元のフィルタの転置は

>>> h.T
array([[ 1.,  5., 13.],
       [ 2.,  7., 17.],
       [ 3., 11., 19.]])

であって、上下左右反転したフィルタ ht とは異なる。Transposed convolution で実行されるのは反転であって、転置ではないのである。

結論

まぎらわしい。

  1. 樋口(監),川又・阿部・八巻(著)「MATLAB対応ディジタル信号処理 第2版」森北出版 (2021) 式(5.11), 式(9.1)

  2. Conv1d – Pytorch Documentation

  3. これらは convolution の逆伝播計算を convolution 演算によって行うために必要な処理。Transposed convolution のルーツは convolution の逆伝播計算にある。参考: https://blog.cosnomi.com/posts/transposed-conv/)

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?