Conv2D演算
2次元畳み込みの演算を考えた場合、入力$(Batch,H,W,C_{in})$、出力$(Batch,H,W,C_{out})$、カーネルサイズ$(3,3)$、畳み込みの重み$W=(3,3,C_{in},C_{out})$とすれば、
実質的にConv2D演算は
$matmul(x,W)=matmul((Batch,HW,9C_{in}), (9C_{in}, C_{out}))=(Batch,HW,C_{out})$
というmatmulの行列演算を考えることに等しいです。
ここで$c=matmul(a,b)$という行列演算において$a=(i,j,k,m),b=(m,n)$ならば$c=(i,j,k,n)$です。
一方で入力$(Batch,H,W,C_{in})$に対して
x[:,0]=input[:,0:H-2,0:W-2,:] \\
x[:,1]=input[:,0:H-2,1:W-1,:] \\
x[:,2]=input[:,0:H-2,2:W-0,:] \\
x[:,3]=input[:,1:H-1,0:W-2,:] \\
x[:,4]=input[:,1:H-1,1:W-1,:] \\
x[:,5]=input[:,1:H-1,2:W-0,:] \\
x[:,6]=input[:,2:H-0,0:W-2,:] \\
x[:,7]=input[:,2:H-0,1:W-1,:] \\
x[:,8]=input[:,2:H-0,2:W-0,:]
のような$(H,W)$から$(H-2,W-2)$の抜き出しを行い、行列演算の前に$(Batch,HW,9C_{in})$のような行列に変換する必要があります。このような行列変形は$im2col$と呼ばれます。この処理によって入力チャンネル数がカーネルサイズ総数倍になると考えることができます。また$im2col$処理自体には重みを持ちません。
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
N, C, H, W = input_data.shape
out_h = (H + 2*pad - filter_h)//stride + 1
out_w = (W + 2*pad - filter_w)//stride + 1
img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
for y in range(filter_h):
y_max = y + stride*out_h
for x in range(filter_w):
x_max = x + stride*out_w
col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
return col
Pytorchにおいてim2col関数はUnfold関数と呼ばれます。
従って、Conv2D=(im2col+matmul)=(Unfold+matmul)となっている筈です。本題は本当にそうなっているか確認してみました。
PyTorchにおける比較
PyTorchはチャンネルファーストで入力$(Batch,C_{in},H,W)=(25,3,32,32)$、出力$(Batch,C_{out},H,W)=(25,16,30,30)$、カーネルサイズ$(3,3)$、重み$W=(C_{out}, 3×3×C_{in})=(16,27)$とします。
(Unfold+matmul)演算
import numpy as np
import torch
input = torch.tensor(np.random.rand(25,3,32,32)).float()
weight = torch.tensor(np.random.rand(16,3,3,3)).float()
weight2 = weight.reshape((16,27))
print('input.shape= ', input.shape)
print('weight.shape= ', weight.shape)
print('weight2.shape=', weight2.shape)
x = torch.nn.Unfold(kernel_size=(3,3), stride=(1,1), padding=(0,0), dilation=(1,1))(input)
output1 = torch.matmul(weight2, x).reshape((25,16,30,30))
print('x.shape= ', x.shape)
print('output1.shape=', output1.shape)
-----------------------------------------------------------
input.shape= torch.Size([25, 3, 32, 32])
weight.shape= torch.Size([16, 3, 3, 3])
weight2.shape= torch.Size([16, 27])
x.shape= torch.Size([25, 27, 900])
output1.shape= torch.Size([25, 16, 30, 30])
ここでUnfold関数を入力にかけると$x=(25, 3×3×3, 30×30)=(25,27,900)$となり、$W=(16,27)$のとき、$matmul(W,x)=(25,16,30×30)$となります。
Conv2D演算
一方、入力$(Batch,C_{in},H,W)=(25,3,32,32)$、Conv2D関数の重みを$W=(16,3,3,3)$とした場合、出力を求めるコードは以下です。
conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, bias=False)
conv1.weight.data = weight
output2 = conv1(input)
print('conv1.weight.shape=', conv1.weight.shape)
print('output2.shape= ', output2.shape)
-----------------------------------------------------------
conv1.weight.shape= torch.Size([16, 3, 3, 3])
output2.shape= torch.Size([25, 16, 30, 30])
これらから(Unfold+matmul)で求めたoutput1とConv2Dで求めたoutput2を比較すると値は完全に一致しました。従ってConv2D=(Unfold+matmul)と演算的には等価であるのが確認できました。
output1:
tensor([[[[7.4075, 7.1269, 6.2595, ..., 6.9860, 6.5256, 7.3597],
[6.4978, 7.3303, 6.7621, ..., 7.2054, 6.9357, 7.3798],
[5.9309, 5.5016, 6.3321, ..., 5.7143, 7.0358, 6.8819],
...,
[6.0168, 6.9415, 7.5508, ..., 5.4547, 4.7888, 6.0636],
[5.0191, 7.0944, 7.0875, ..., 3.9413, 4.1925, 5.5689],
[6.2448, 6.4813, 5.5424, ..., 4.2610, 5.8013, 5.3431]],
......
output2:
tensor([[[[7.4075, 7.1269, 6.2595, ..., 6.9860, 6.5256, 7.3597],
[6.4979, 7.3303, 6.7621, ..., 7.2054, 6.9357, 7.3798],
[5.9309, 5.5016, 6.3321, ..., 5.7143, 7.0358, 6.8819],
...,
[6.0168, 6.9415, 7.5508, ..., 5.4547, 4.7888, 6.0636],
[5.0191, 7.0944, 7.0874, ..., 3.9413, 4.1925, 5.5689],
[6.2448, 6.4813, 5.5424, ..., 4.2610, 5.8013, 5.3431]],
......
Unfold関数の他の使い道
kernel_sizeとstrideが等しい時、Vision Transformerのパッチ分割に相当します。
まあ、別にパッチ分割はUnfold使わずともreshapeとtransposeで代用出来るんですが…。
input = torch.tensor(np.random.rand(25,3,224,224)).float()
x = torch.nn.Unfold(kernel_size=(14,14), stride=(14,14), padding=(0,0), dilation=(1,1))(input)
-----------------------------------------------------------
input.shape= torch.Size([25, 3, 224, 224])
x.shape= torch.Size([25, 588, 256]) #(25,3*14*14,16*16)
Vision Transformerが全くConv2Dを使っていないという話において、Attention重みとValueとの演算にmatmulが含まれるのでViTでもUnfold+matmulが結局Conv2D相当なのではと根拠のない妄想しました。
まとめ
Unfold関数はPytorchにおけるim2col関数であり、Conv2D=(Unfold+matmul)である。
またtensorflowではextract_image_patches関数である。