LoginSignup
0

More than 5 years have passed since last update.

TheanoでCNNを構築するために、Theanoの2次元の畳み込み関数theano.tensor.nnet.conv()について何点か調査しました。信号処理でおそらくよく使われているであろうN次元の畳み込み関数scipy.signal.fftconvolve()と比較しました。

2次元配列同士の畳み込み

まずは単純な2次元配列同士で畳み込みをしてみます。

import theano
import theano.tensor as T
import theano.tensor.signal as signal 
import scipy.signal as s

m = T.matrix()
w = T.matrix()

#ランク4にする必要がある。
o_full = nnet.conv.conv2d(m[None,None,:,:], w[None, None,:,:],
                          border_mode='full')
o_valid = nnet.conv.conv2d(m[None,None,:,:], w[None, None,:,:],
                          border_mode='valid')

m_arr = arange(25.).reshape((5,5)).astype(float32)
w_arr = ones((3,3)).astype(float32)
print("m_arr =")
print(m_arr)
print("w_arr =")
print(w_arr)

print("Output for Theano.")
print("full:")
print(o_full.eval({m:m_arr, w:w_arr}).round().astype(int))
print("valid:")
print(o_valid.eval({m:m_arr, w:w_arr}).round().astype(int))

print("Output for scipy.")
print("full:")
print(s.fftconvolve(m_arr, w_arr, "full").round().astype(int))
print("valid:")
print(s.fftconvolve(m_arr, w_arr, "valid").round().astype(int))

畳み込まれる配列m_arrと畳み込む窓関数(orカーネルorフィルタ)w_arrtheano.tensor.nnet.conv.conv2d()とscipy.signal.fftconvolve()```
にそれぞれ流しています。ここで、

#ランク4にする必要がある。
o_full = nnet.conv.conv2d(m[None,None,:,:], w[None, None,:,:],
                          border_mode='full')
o_valid = nnet.conv.conv2d(m[None,None,:,:], w[None, None,:,:],
                          border_mode='valid')

のように、m[None,None,:,:], w[None, None,:,:]としているのは、入力とカーネルの配列の形式が[画像枚数、チャンネル数、高さ、幅]になっているからです。m,wはランク2のT.matrix()として定義したので、[None, None,:,:]のようにすることで、上位ランクを2つ増やしています。このブロードキャストはNumpyのそれと同じなので個人的にとても使いやすいです。

出力は以下のようになります。

m_arr =
[[  0.   1.   2.   3.   4.]
 [  5.   6.   7.   8.   9.]
 [ 10.  11.  12.  13.  14.]
 [ 15.  16.  17.  18.  19.]
 [ 20.  21.  22.  23.  24.]]
w_arr =
[[ 1.  1.  1.]
 [ 1.  1.  1.]
 [ 1.  1.  1.]]
Output for Theano.
full:
[[[[  0   1   3   6   9   7   4]
   [  5  12  21  27  33  24  13]
   [ 15  33  54  63  72  51  27]
   [ 30  63  99 108 117  81  42]
   [ 45  93 144 153 162 111  57]
   [ 35  72 111 117 123  84  43]
   [ 20  41  63  66  69  47  24]]]]
valid:
[[[[ 54  63  72]
   [ 99 108 117]
   [144 153 162]]]]
Output for scipy.
full:
[[  0   1   3   6   9   7   4]
 [  5  12  21  27  33  24  13]
 [ 15  33  54  63  72  51  27]
 [ 30  63  99 108 117  81  42]
 [ 45  93 144 153 162 111  57]
 [ 35  72 111 117 123  84  43]
 [ 20  41  63  66  69  47  24]]
valid:
[[ 54  63  72]
 [ 99 108 117]
 [144 153 162]]

畳み込みの出力は見やすいように四捨五入した後intに変換しています。
theano.tensor.nnet.conv()ではborder_modeという引数がありました。これにはfullvalidが選択可能です。畳み込みは画像に対しフィルタを動かしながら掛け算して総和を取るのですが、fullは画像に対しフィルタがはみ出していても、要素の最低一つが重なっている状態の結果を出力に含むモード、validは画像に対しフィルタがはみ出していない状態の結果のみ出力とするモードです。ある軸の画像とフィルタのサイズがそれぞれ$M,m$の時、出力される配列のある軸のサイズはfullでは$M+(m-1)$、validでは$M-(m-1)$になります。上の例では高さ(or幅)が$M=5,m=3$なので、fullの時7、validの時3になっています。

出力を確認すると、theano.tensor.nnet.conv()scipy.signal.fftconvolve()で(配列のランクを除いて)等しいことが確認できます。

ただし両者の出力は意味合いが異なっています。scipy.signal.fftconvolve()が純粋なN次元の畳み込みの結果を返すのに対して、theano.tensor.nnet.conv()は画像枚数毎、フィルタ毎の畳み込みの結果を返します。出力配列は[画像枚数、ファイル枚数、高さ、幅]です。また、後述しますがtheano.tensor.nnet.conv()はチャネル数が画像とフィルタとで等しくないといけません。

画像枚数、チャネル数の次元を加えた場合のコンボリューション

次は画像枚数とチャネル数の次元を加えたコンボリューションを行います。画像枚数2枚、チャネル数3で5x5の画像に対し、画像枚数1、チャネル数3で3x3のフィルタを畳み込みます。プログラムは以下のようになります。

m = T.tensor4()
w = T.tensor4()

#ランク4にする必要がある。
o_full = nnet.conv.conv2d(m, w,
                          border_mode='full')
o_valid = nnet.conv.conv2d(m, w,
                          border_mode='valid')

m_arr = arange(2*3*5*5).reshape((2, 3, 5, 5)).astype(float32)
w_arr = ones((1,3,3,3)).astype(float32)
print("m_arr =")
print(m_arr)
print("w_arr =")
print(w_arr)

print("Output for Theano.")
print("full:")
print(o_full.eval({m:m_arr, w:w_arr}).round().astype(int))
print("valid:")
print(o_valid.eval({m:m_arr, w:w_arr}).round().astype(int))

print("Output for scipy.")
print("full:")
print(s.fftconvolve(m_arr, w_arr, "full").round().astype(int))
print("valid:")
print(s.fftconvolve(m_arr, w_arr, "valid").round().astype(int))

ランク4のテンソルを設定するためにmwT.tensor4()を設定しています。

m_arr =
[[[[   0.    1.    2.    3.    4.]
   [   5.    6.    7.    8.    9.]
   [  10.   11.   12.   13.   14.]
   [  15.   16.   17.   18.   19.]
   [  20.   21.   22.   23.   24.]]

  [[  25.   26.   27.   28.   29.]
   [  30.   31.   32.   33.   34.]
   [  35.   36.   37.   38.   39.]
   [  40.   41.   42.   43.   44.]
   [  45.   46.   47.   48.   49.]]

  [[  50.   51.   52.   53.   54.]
   [  55.   56.   57.   58.   59.]
   [  60.   61.   62.   63.   64.]
   [  65.   66.   67.   68.   69.]
   [  70.   71.   72.   73.   74.]]]


 [[[  75.   76.   77.   78.   79.]
   [  80.   81.   82.   83.   84.]
   [  85.   86.   87.   88.   89.]
   [  90.   91.   92.   93.   94.]
   [  95.   96.   97.   98.   99.]]

  [[ 100.  101.  102.  103.  104.]
   [ 105.  106.  107.  108.  109.]
   [ 110.  111.  112.  113.  114.]
   [ 115.  116.  117.  118.  119.]
   [ 120.  121.  122.  123.  124.]]

  [[ 125.  126.  127.  128.  129.]
   [ 130.  131.  132.  133.  134.]
   [ 135.  136.  137.  138.  139.]
   [ 140.  141.  142.  143.  144.]
   [ 145.  146.  147.  148.  149.]]]]
w_arr =
[[[[ 1.  1.  1.]
   [ 1.  1.  1.]
   [ 1.  1.  1.]]

  [[ 1.  1.  1.]
   [ 1.  1.  1.]
   [ 1.  1.  1.]]

  [[ 1.  1.  1.]
   [ 1.  1.  1.]
   [ 1.  1.  1.]]]]
Output for Theano.
full:
[[[[  75  153  234  243  252  171   87]
   [ 165  336  513  531  549  372  189]
   [ 270  549  837  864  891  603  306]
   [ 315  639  972  999 1026  693  351]
   [ 360  729 1107 1134 1161  783  396]
   [ 255  516  783  801  819  552  279]
   [ 135  273  414  423  432  291  147]]]


 [[[ 300  603  909  918  927  621  312]
   [ 615 1236 1863 1881 1899 1272  639]
   [ 945 1899 2862 2889 2916 1953  981]
   [ 990 1989 2997 3024 3051 2043 1026]
   [1035 2079 3132 3159 3186 2133 1071]
   [ 705 1416 2133 2151 2169 1452  729]
   [ 360  723 1089 1098 1107  741  372]]]]
valid:
[[[[ 837  864  891]
   [ 972  999 1026]
   [1107 1134 1161]]]


 [[[2862 2889 2916]
   [2997 3024 3051]
   [3132 3159 3186]]]]
Output for scipy.
full:
[[[[   0    1    3    6    9    7    4]
   [   5   12   21   27   33   24   13]
   [  15   33   54   63   72   51   27]
   [  30   63   99  108  117   81   42]
   [  45   93  144  153  162  111   57]
   [  35   72  111  117  123   84   43]
   [  20   41   63   66   69   47   24]]

  [[  25   52   81   87   93   64   33]
   [  60  124  192  204  216  148   76]
   [ 105  216  333  351  369  252  129]
   [ 135  276  423  441  459  312  159]
   [ 165  336  513  531  549  372  189]
   [ 120  244  372  384  396  268  136]
   [  65  132  201  207  213  144   73]]

  [[  75  153  234  243  252  171   87]
   [ 165  336  513  531  549  372  189]
   [ 270  549  837  864  891  603  306]
   [ 315  639  972  999 1026  693  351]
   [ 360  729 1107 1134 1161  783  396]
   [ 255  516  783  801  819  552  279]
   [ 135  273  414  423  432  291  147]]

  [[  75  152  231  237  243  164   83]
   [ 160  324  492  504  516  348  176]
   [ 255  516  783  801  819  552  279]
   [ 285  576  873  891  909  612  309]
   [ 315  636  963  981  999  672  339]
   [ 220  444  672  684  696  468  236]
   [ 115  232  351  357  363  244  123]]

  [[  50  101  153  156  159  107   54]
   [ 105  212  321  327  333  224  113]
   [ 165  333  504  513  522  351  177]
   [ 180  363  549  558  567  381  192]
   [ 195  393  594  603  612  411  207]
   [ 135  272  411  417  423  284  143]
   [  70  141  213  216  219  147   74]]]


 [[[  75  151  228  231  234  157   79]
   [ 155  312  471  477  483  324  163]
   [ 240  483  729  738  747  501  252]
   [ 255  513  774  783  792  531  267]
   [ 270  543  819  828  837  561  282]
   [ 185  372  561  567  573  384  193]
   [  95  191  288  291  294  197   99]]

  [[ 175  352  531  537  543  364  183]
   [ 360  724 1092 1104 1116  748  376]
   [ 555 1116 1683 1701 1719 1152  579]
   [ 585 1176 1773 1791 1809 1212  609]
   [ 615 1236 1863 1881 1899 1272  639]
   [ 420  844 1272 1284 1296  868  436]
   [ 215  432  651  657  663  444  223]]

  [[ 300  603  909  918  927  621  312]
   [ 615 1236 1863 1881 1899 1272  639]
   [ 945 1899 2862 2889 2916 1953  981]
   [ 990 1989 2997 3024 3051 2043 1026]
   [1035 2079 3132 3159 3186 2133 1071]
   [ 705 1416 2133 2151 2169 1452  729]
   [ 360  723 1089 1098 1107  741  372]]

  [[ 225  452  681  687  693  464  233]
   [ 460  924 1392 1404 1416  948  476]
   [ 705 1416 2133 2151 2169 1452  729]
   [ 735 1476 2223 2241 2259 1512  759]
   [ 765 1536 2313 2331 2349 1572  789]
   [ 520 1044 1572 1584 1596 1068  536]
   [ 265  532  801  807  813  544  273]]

  [[ 125  251  378  381  384  257  129]
   [ 255  512  771  777  783  524  263]
   [ 390  783 1179 1188 1197  801  402]
   [ 405  813 1224 1233 1242  831  417]
   [ 420  843 1269 1278 1287  861  432]
   [ 285  572  861  867  873  584  293]
   [ 145  291  438  441  444  297  149]]]]
valid:
[[[[ 837  864  891]
   [ 972  999 1026]
   [1107 1134 1161]]]


 [[[2862 2889 2916]
   [2997 3024 3051]
   [3132 3159 3186]]]]

長くて比較しづらいのですが、validは同じですが、fullは両者で結果が異なっています。そこで、出力後の配列のshapeを見てみます。

print("Output for Theano.")
print("full:")
print(o_full.eval({m:m_arr, w:w_arr}).round().astype(int).shape)
print("valid:")
print(o_valid.eval({m:m_arr, w:w_arr}).round().astype(int).shape)

print("Output for scipy.")
print("full:")
print(s.fftconvolve(m_arr, w_arr, "full").round().astype(int).shape)
print("valid:")
print(s.fftconvolve(m_arr, w_arr, "valid").round().astype(int).shape)
Output for Theano.
full:
(2, 1, 7, 7)
valid:
(2, 1, 3, 3)
Output for scipy.
full:
(2, 5, 7, 7)
valid:
(2, 1, 3, 3)

これは、scipy.signal.fftconvolve()が畳み込み操作を画像枚数及びチャネル数の軸に関しても行っているのに対して、theano.tensor.nnet.conv()では画像の幅、高さの次元でしか行わず、画像枚数とチャネル数に対しては独立に処理していることに起因します。そして、theano.tensor.nnet.conv()の出力は[画像枚数、フィルタ枚数、高さ、幅]なので、shapeの2番目は1になっています。また、theano.tensor.nnet.conv()は上述の様にチャネル数を画像とフィルタとで合わせる必要があります。例えば、

m_arr = arange(2*3*5*5).reshape((2, 3, 5, 5)).astype(float32)
w_arr = ones((1,1,3,3)).astype(float32)

のように画像のチャネル数が3に対して、フィルタのチャネル数が1の場合はtheano.tensor.nnet.conv()では以下のエラーを出力します。

ValueError: GpuDnnConv images and kernel must have the same stack size

ただし、scipy.signal.fftconvolve()では配列のshapeが

Output for scipy.
full:
(2, 3, 7, 7)
valid:
(2, 3, 3, 3)

の結果を返します。fullでは$M+(m-1)$、validでは$M-(m-1)$のとおりになっています。

最後に、画像枚数2、フィルタ枚数3、チャネル数1で試してみます。また、配列の要素数を減らしました。

m = T.tensor4()
w = T.tensor4()

#ランク4にする必要がある。
o_full = nnet.conv.conv2d(m, w,
                          border_mode='full')
o_valid = nnet.conv.conv2d(m, w,
                          border_mode='valid')

m_arr = arange(2*1*3*3).reshape((2, 1, 3, 3)).astype(float32)
w_arr = ones((3,1,1,1)).astype(float32)
print("m_arr =")
print(m_arr)
print("w_arr =")
print(w_arr)

print("Output for Theano.")
print("full:")
print(o_full.eval({m:m_arr, w:w_arr}).round().astype(int))
print("valid:")
print(o_valid.eval({m:m_arr, w:w_arr}).round().astype(int))

print("Output for scipy.")
print("full:")
print(s.fftconvolve(m_arr, w_arr, "full").round().astype(int))
print("valid:")
print(s.fftconvolve(m_arr, w_arr, "valid").round().astype(int))
m_arr =
[[[[  0.   1.   2.]
   [  3.   4.   5.]
   [  6.   7.   8.]]]


 [[[  9.  10.  11.]
   [ 12.  13.  14.]
   [ 15.  16.  17.]]]]
w_arr =
[[[[ 1.]]]


 [[[ 1.]]]


 [[[ 1.]]]]
Output for Theano.
full:
[[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]

  [[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]

  [[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]]


 [[[ 9 10 11]
   [12 13 14]
   [15 16 17]]

  [[ 9 10 11]
   [12 13 14]
   [15 16 17]]

  [[ 9 10 11]
   [12 13 14]
   [15 16 17]]]]
valid:
[[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]

  [[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]

  [[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]]


 [[[ 9 10 11]
   [12 13 14]
   [15 16 17]]

  [[ 9 10 11]
   [12 13 14]
   [15 16 17]]

  [[ 9 10 11]
   [12 13 14]
   [15 16 17]]]]
Output for scipy.
full:
[[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]]


 [[[ 9 11 13]
   [15 17 19]
   [21 23 25]]]


 [[[ 9 11 13]
   [15 17 19]
   [21 23 25]]]


 [[[ 9 10 11]
   [12 13 14]
   [15 16 17]]]]
valid:
ValueError: For 'valid' mode, one must be at least as large as the other in every dimension

scipy.signal.fftconvolve()のvalidがエラーになってしまいました。validの場合、画像とフィルタのいずれかが片方よりもすべての次元で大きくないといけないようです。

配列のshapeは以下になります。

Output for Theano.
full:
(2, 3, 3, 3)
valid:
(2, 3, 3, 3)
Output for scipy.
full:
(4, 1, 3, 3)
valid:

 theano.tensor.nnet.conv()はshapeの1,2番目がそれぞれ画像枚数、フィルタ枚数で、残りがfullでは$M+(m-1)$、validでは$M-(m-1)$になっており、scipy.signal.fftconvolve()はすべての軸に対してfullでは$M+(m-1)$になっていることがわかります。

ストライド

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0