Notes on How Separable Convolution Is Computed

Posted at 2018-06-04

Key points

  • Verify how separable convolution is computed, and check that it equals a depthwise convolution followed by a pointwise convolution.

References

1. Depthwise Separable Convolutions for Neural Machine Translation

Formulas

[Image: formulas for the convolution operations (quoted from the referenced paper)]
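Since the formula image is not reproduced here, the following is a sketch in our own notation (not necessarily the paper's exact formulas) for the 1-D case used in the samples: input $x_{i,m}$ with position $i$ and channel $m = 1, \dots, M$, kernel width $K$, output channel $n$, and a depthwise channel multiplier of 1. Like TensorFlow, it treats convolution as cross-correlation (the kernel is not flipped).

$$
\begin{aligned}
\text{Regular:}\quad & y^{r}_{i,n} = \sum_{k=0}^{K-1} \sum_{m=1}^{M} W^{r}_{k,m,n}\, x_{i+k,m} \\
\text{Depthwise:}\quad & d_{i,m} = \sum_{k=0}^{K-1} W^{d}_{k,m}\, x_{i+k,m} \\
\text{Pointwise:}\quad & y^{s}_{i,n} = \sum_{m=1}^{M} W^{p}_{m,n}\, d_{i,m}
  = \sum_{k=0}^{K-1} \sum_{m=1}^{M} \bigl(W^{d}_{k,m}\, W^{p}_{m,n}\bigr)\, x_{i+k,m}
\end{aligned}
$$

So a separable convolution behaves like a regular convolution whose kernel is restricted to the factored form $W^{r}_{k,m,n} = W^{d}_{k,m} W^{p}_{m,n}$. With all-ones filters both kernels consist entirely of ones, which is why the regular and separable results coincide in the first sample.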

Sample code

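import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (graph mode, tf.Session)

# Input: the values 0..19 reshaped to NHWC [1, 1, 10, 2]
# (batch 1, height 1, width 10, 2 channels)
x = tf.convert_to_tensor(np.arange(20).astype(np.float32))
x = tf.reshape(x, [1, 1, -1, 2])

# All-ones filters, so the results can be checked by hand
# w1: regular filter [height, width, in_channels, out_channels]
w1 = tf.constant(1.0, shape=[1, 3, 2, 2], dtype=tf.float32)
# w2: depthwise filter [height, width, in_channels, channel_multiplier]
w2 = tf.constant(1.0, shape=[1, 3, 2, 1], dtype=tf.float32)
# w3: pointwise 1x1 filter [1, 1, in_channels * channel_multiplier, out_channels]
w3 = tf.constant(1.0, shape=[1, 1, 2, 2], dtype=tf.float32)

# Expected: y1 = y3 = y4
y1 = tf.nn.conv2d(x, w1, strides=[1, 1, 1, 1], padding='VALID')

# Depthwise convolution, then a pointwise (1x1) convolution
y2 = tf.nn.depthwise_conv2d(x, w2, strides=[1, 1, 1, 1], padding='VALID')
y3 = tf.nn.conv2d(y2, w3, strides=[1, 1, 1, 1], padding='VALID')

# Built-in separable convolution (depthwise + pointwise in one call)
y4 = tf.nn.separable_conv2d(x, w2, w3, strides=[1, 1, 1, 1], padding='VALID')

# In graph mode these print symbolic tensors with their static shapes
print(x)
print(y1)
print(y2)
print(y3)
print(y4)

# Evaluate the tensors to get the actual values
with tf.Session() as sess:
  print('x: ')
  print(sess.run(x))
  print('y1: ')
  print(sess.run(y1))
  print('y2: ')
  print(sess.run(y2))
  print('y3: ')
  print(sess.run(y3))
  print('y4: ')
  print(sess.run(y4))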

Results

Tensor("Reshape_19:0", shape=(1, 1, 10, 2), dtype=float32)
Tensor("Conv2D_21:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("depthwise_10:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("Conv2D_22:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("separable_conv2d_3:0", shape=(1, 1, 8, 2), dtype=float32)
x:
[[[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]
[ 8. 9.]
[10. 11.]
[12. 13.]
[14. 15.]
[16. 17.]
[18. 19.]]]]
y1:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]
y2:
[[[[ 6. 9.]
[12. 15.]
[18. 21.]
[24. 27.]
[30. 33.]
[36. 39.]
[42. 45.]
[48. 51.]]]]
y3:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]
y4:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]
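These numbers can be checked by hand. For example, the first value of y1 sums the first three positions over both channels (0 + 1 + 2 + 3 + 4 + 5 = 15), while y2 keeps the channels separate (0 + 2 + 4 = 6 and 1 + 3 + 5 = 9), and the pointwise step then adds them (6 + 9 = 15). Below is a minimal NumPy sketch that re-implements the three operations for this 1-D case (height 1, stride 1, 'VALID' padding), assuming cross-correlation as in TensorFlow. The helper names `conv1d_valid`, `depthwise1d_valid`, and `pointwise` are hypothetical, not part of any library.

```python
import numpy as np

def conv1d_valid(x, w):
    """Regular 'VALID' conv. x: [width, in_ch], w: [k, in_ch, out_ch]."""
    k = w.shape[0]
    out_w = x.shape[0] - k + 1
    # Each output position sums over the kernel window AND all input channels
    return np.stack([np.einsum('km,kmn->n', x[i:i + k], w) for i in range(out_w)])

def depthwise1d_valid(x, w):
    """Depthwise conv with channel multiplier 1. w: [k, in_ch]; channels are not mixed."""
    k = w.shape[0]
    out_w = x.shape[0] - k + 1
    return np.stack([np.sum(x[i:i + k] * w, axis=0) for i in range(out_w)])

def pointwise(x, w):
    """1x1 conv: mixes channels at each position. w: [in_ch, out_ch]."""
    return x @ w

# Same input as the sample: 10 positions x 2 channels
x = np.arange(20, dtype=np.float32).reshape(10, 2)

y1 = conv1d_valid(x, np.ones((3, 2, 2), dtype=np.float32))
y2 = depthwise1d_valid(x, np.ones((3, 2), dtype=np.float32))
y3 = pointwise(y2, np.ones((2, 2), dtype=np.float32))

print(y1[:, 0].tolist())  # [15.0, 27.0, 39.0, 51.0, 63.0, 75.0, 87.0, 99.0]
print(y2[:, 0].tolist())  # [6.0, 12.0, 18.0, 24.0, 30.0, 36.0, 42.0, 48.0]
print(y2[:, 1].tolist())  # [9.0, 15.0, 21.0, 27.0, 33.0, 39.0, 45.0, 51.0]
print(np.array_equal(y1, y3))  # True
```

The output width is 10 - 3 + 1 = 8, matching the shape (1, 1, 8, 2) reported by TensorFlow.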

Sample code 2

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.get_variable, tf.Session)

# Same input as before: NHWC [1, 1, 10, 2]
x = tf.convert_to_tensor(np.arange(20).astype(np.float32))
x = tf.reshape(x, [1, 1, -1, 2])

# Random filters drawn from a truncated normal distribution
init = tf.truncated_normal_initializer(mean=0.0, stddev=1.0, dtype=tf.float32)
# Regular filter: [height, width, in_channels, out_channels]
w_r = tf.get_variable('w_r', shape=[1, 3, 2, 2], initializer=init)
# Depthwise filter: [height, width, in_channels, channel_multiplier]
w_d = tf.get_variable('w_d', shape=[1, 3, 2, 1], initializer=init)
# Pointwise 1x1 filter: [1, 1, in_channels * channel_multiplier, out_channels]
w_p = tf.get_variable('w_p', shape=[1, 1, 2, 2], initializer=init)

# regular
y_r = tf.nn.conv2d(x, w_r, strides=[1, 1, 1, 1], padding='VALID')

# depthwise + pointwise
y_d = tf.nn.depthwise_conv2d(x, w_d, strides=[1, 1, 1, 1], padding='VALID')
y_d_p = tf.nn.conv2d(y_d, w_p, strides=[1, 1, 1, 1], padding='VALID')

# separable (uses the same w_d and w_p, so it should match y_d_p)
y_s = tf.nn.separable_conv2d(x, w_d, w_p, strides=[1, 1, 1, 1], padding='VALID')

print(x)
print(y_r)
print(y_d)
print(y_d_p)
print(y_s)

with tf.Session() as sess:
  # Variables must be initialized before they can be read
  sess.run(tf.global_variables_initializer())

  print('x: ')
  print(sess.run(x))
  print('regular: ')
  print(sess.run(y_r))
  print('depthwise: ')
  print(sess.run(y_d))
  print('depthwise + pointwise: ')
  print(sess.run(y_d_p))
  print('separable: ')
  print(sess.run(y_s))

Results 2

Tensor("Reshape:0", shape=(1, 1, 10, 2), dtype=float32)
Tensor("Conv2D:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("depthwise:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("Conv2D_1:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("separable_conv2d:0", shape=(1, 1, 8, 2), dtype=float32)
x:
[[[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]
[ 8. 9.]
[10. 11.]
[12. 13.]
[14. 15.]
[16. 17.]
[18. 19.]]]]
regular:
[[[[ 1.9705558 0.33328748]
[ 1.8040164 -4.134241 ]
[ 1.6374781 -8.60177 ]
[ 1.4709401 -13.069299 ]
[ 1.3044 -17.536827 ]
[ 1.137861 -22.004358 ]
[ 0.97132176 -26.471891 ]
[ 0.80478483 -30.939413 ]]]]
depthwise:
[[[[-0.63552576 5.332875 ]
[ 2.240392 8.880491 ]
[ 5.11631 12.428108 ]
[ 7.9922285 15.975724 ]
[10.868145 19.523342 ]
[13.744063 23.070957 ]
[16.619982 26.618574 ]
[19.4959 30.166191 ]]]]
depthwise + pointwise:
[[[[ 0.6532894 -5.3626175]
[ -3.0205643 -10.077426 ]
[ -6.6944184 -14.792234 ]
[-10.368273 -19.507042 ]
[-14.042125 -24.221851 ]
[-17.715979 -28.936659 ]
[-21.389833 -33.651466 ]
[-25.063686 -38.366276 ]]]]
separable:
[[[[ 0.6532894 -5.3626175]
[ -3.0205643 -10.077426 ]
[ -6.6944184 -14.792234 ]
[-10.368273 -19.507042 ]
[-14.042125 -24.221851 ]
[-17.715979 -28.936659 ]
[-21.389833 -33.651466 ]
[-25.063686 -38.366276 ]]]]
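Here "depthwise + pointwise" and "separable" agree exactly, while "regular" differs. That is expected: w_r is drawn independently of w_d and w_p, and a separable convolution reproduces a regular one only when the regular kernel has the factored form W_r[k, m, n] = W_d[k, m] * W_p[m, n]. The NumPy sketch below (our own check, not taken from the article or the paper) builds that factored kernel from random depthwise and pointwise weights, confirms the two results match, and compares parameter counts for these shapes. It assumes the same 1-D, stride-1, 'VALID', cross-correlation setting as the samples.

```python
import numpy as np

rng = np.random.default_rng(0)
k, m, n = 3, 2, 2  # kernel width, input channels, output channels (as in the sample)

x = np.arange(20, dtype=np.float64).reshape(10, m)
w_d = rng.standard_normal((k, m))  # depthwise filter (channel multiplier 1)
w_p = rng.standard_normal((m, n))  # pointwise 1x1 filter

out_w = x.shape[0] - k + 1

# Separable: depthwise (per-channel sliding sum), then pointwise (channel mixing)
y_d = np.stack([np.sum(x[i:i + k] * w_d, axis=0) for i in range(out_w)])
y_s = y_d @ w_p

# Regular convolution with the factored kernel W_r[k, m, n] = W_d[k, m] * W_p[m, n]
w_r = w_d[:, :, None] * w_p[None, :, :]
y_r = np.stack([np.einsum('km,kmn->n', x[i:i + k], w_r) for i in range(out_w)])

print(np.allclose(y_s, y_r))            # True
print(w_r.size, w_d.size + w_p.size)    # 12 10
```

For these tiny shapes the saving is small (12 vs 10 weights), but in general a regular filter needs K * M * N weights while the separable pair needs only K * M + M * N.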
