LoginSignup
10
7

More than 3 years have passed since last update.

Unfold関数について

Posted at

Conv2D演算

2次元畳み込みの演算を考えた場合、入力$(Batch,H,W,C_{in})$、出力$(Batch,H,W,C_{out})$、カーネルサイズ$(3,3)$、畳み込みの重み$W=(3,3,C_{in},C_{out})$とすれば、

実質的にConv2D演算は
$matmul(x,W)=matmul((Batch,HW,9C_{in}), (9C_{in}, C_{out}))=(Batch,HW,C_{out})$
というmatmulの行列演算を考えることに等しいです。
ここで$c=matmul(a,b)$という行列演算において$a=(i,j,k,m),b=(m,n)$ならば$c=(i,j,k,n)$です。

一方で入力$(Batch,H,W,C_{in})$に対して

x[:,0]=input[:,0:H-2,0:W-2,:] \\
x[:,1]=input[:,0:H-2,1:W-1,:] \\
x[:,2]=input[:,0:H-2,2:W-0,:] \\ 
x[:,3]=input[:,1:H-1,0:W-2,:] \\
x[:,4]=input[:,1:H-1,1:W-1,:] \\
x[:,5]=input[:,1:H-1,2:W-0,:] \\ 
x[:,6]=input[:,2:H-0,0:W-2,:] \\
x[:,7]=input[:,2:H-0,1:W-1,:] \\
x[:,8]=input[:,2:H-0,2:W-0,:]

のような$(H,W)$から$(H-2,W-2)$の抜き出しを行い、行列演算の前に$(Batch,HW,9C_{in})$のような行列に変換する必要があります。このような行列変形は$im2col$と呼ばれます。この処理によって入力チャンネル数がカーネルサイズ総数倍になると考えることができます。また$im2col$処理自体には重みを持ちません。

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):

    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col

Pytorchにおいてim2col関数はUnfold関数と呼ばれます。
従って、Conv2D=(im2col+matmul)=(Unfold+matmul)となっている筈です。本題は本当にそうなっているか確認してみました。

PyTorchにおける比較

PyTorchはチャンネルファーストで入力$(Batch,C_{in},H,W)=(25,3,32,32)$、出力$(Batch,C_{out},H,W)=(25,16,30,30)$、カーネルサイズ$(3,3)$、重み$W=(C_{out}, 3×3×C_{in})=(16,27)$とします。

(Unfold+matmul)演算

import numpy as np
import torch

input = torch.tensor(np.random.rand(25,3,32,32)).float()
weight = torch.tensor(np.random.rand(16,3,3,3)).float()
weight2 = weight.reshape((16,27))

print('input.shape=  ', input.shape)
print('weight.shape= ', weight.shape)
print('weight2.shape=', weight2.shape)

x = torch.nn.Unfold(kernel_size=(3,3), stride=(1,1), padding=(0,0), dilation=(1,1))(input)
output1 = torch.matmul(weight2, x).reshape((25,16,30,30))

print('x.shape=      ', x.shape)
print('output1.shape=', output1.shape)
-----------------------------------------------------------
input.shape=   torch.Size([25, 3, 32, 32])
weight.shape=  torch.Size([16, 3, 3, 3])
weight2.shape= torch.Size([16, 27])
x.shape=       torch.Size([25, 27, 900])
output1.shape= torch.Size([25, 16, 30, 30])

ここでUnfold関数を入力にかけると$x=(25, 3×3×3, 30×30)=(25,27,900)$となり、$W=(16,27)$のとき、$matmul(W,x)=(25,16,30×30)$となります。

Conv2D演算

一方、入力$(Batch,C_{in},H,W)=(25,3,32,32)$、Conv2D関数の重みを$W=(16,3,3,3)$とした場合、出力を求めるコードは以下です。

conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, bias=False)
conv1.weight.data = weight

output2 = conv1(input)

print('conv1.weight.shape=', conv1.weight.shape)
print('output2.shape= ', output2.shape)
-----------------------------------------------------------
conv1.weight.shape= torch.Size([16, 3, 3, 3])
output2.shape=  torch.Size([25, 16, 30, 30])

これらから(Unfold+matmul)で求めたoutput1とConv2Dで求めたoutput2を比較すると値は完全に一致しました。従ってConv2D=(Unfold+matmul)と演算的には等価であるのが確認できました。

output1:
tensor([[[[7.4075, 7.1269, 6.2595,  ..., 6.9860, 6.5256, 7.3597],
          [6.4978, 7.3303, 6.7621,  ..., 7.2054, 6.9357, 7.3798],
          [5.9309, 5.5016, 6.3321,  ..., 5.7143, 7.0358, 6.8819],
          ...,
          [6.0168, 6.9415, 7.5508,  ..., 5.4547, 4.7888, 6.0636],
          [5.0191, 7.0944, 7.0875,  ..., 3.9413, 4.1925, 5.5689],
          [6.2448, 6.4813, 5.5424,  ..., 4.2610, 5.8013, 5.3431]],
......
output2:
tensor([[[[7.4075, 7.1269, 6.2595,  ..., 6.9860, 6.5256, 7.3597],
          [6.4979, 7.3303, 6.7621,  ..., 7.2054, 6.9357, 7.3798],
          [5.9309, 5.5016, 6.3321,  ..., 5.7143, 7.0358, 6.8819],
          ...,
          [6.0168, 6.9415, 7.5508,  ..., 5.4547, 4.7888, 6.0636],
          [5.0191, 7.0944, 7.0874,  ..., 3.9413, 4.1925, 5.5689],
          [6.2448, 6.4813, 5.5424,  ..., 4.2610, 5.8013, 5.3431]],
......

Unfold関数の他の使い道

kernel_sizeとstrideが等しい時、Vision Transformerのパッチ分割に相当します。
まあ、別にパッチ分割はUnfold使わずともreshapeとtransposeで代用出来るんですが…。

input = torch.tensor(np.random.rand(25,3,224,224)).float()
x = torch.nn.Unfold(kernel_size=(14,14), stride=(14,14), padding=(0,0), dilation=(1,1))(input)
-----------------------------------------------------------
input.shape=   torch.Size([25, 3, 224, 224])
x.shape=       torch.Size([25, 588, 256]) #(25,3*14*14,16*16)

Vision Transformerが全くConv2Dを使っていないという話において、Attention重みとValueとの演算にmatmulが含まれるのでViTでもUnfold+matmulが結局Conv2D相当なのではと根拠のない妄想しました。

まとめ

Unfold関数はPytorchにおけるim2col関数であり、Conv2D=(Unfold+matmul)である。
またtensorflowではextract_image_patches関数である。

10
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
7