@takeshikondo

Notes on How Separable Convolution Is Computed


Key points

  • Check how separable convolution is computed.

References

1. Depthwise Separable Convolutions for Neural Machine Translation

Formulas

image.png

      (quoted from the referenced paper)
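The formula image is not available in this copy. As a rough reconstruction (an assumption based on the notation commonly used for these operations, so it may differ in detail from the image), the four operations can be written as follows, where W is a filter, y is the input, (i, j) is the output position, k and l index the filter window, and m indexes input channels:

```latex
\begin{align}
\mathrm{Conv}(W, y)_{(i,j)} &= \sum_{k,l,m}^{K,L,M} W_{(k,l,m)} \cdot y_{(i+k,\, j+l,\, m)} \\
\mathrm{PointwiseConv}(W, y)_{(i,j)} &= \sum_{m}^{M} W_{m} \cdot y_{(i,j,m)} \\
\mathrm{DepthwiseConv}(W, y)_{(i,j)} &= \sum_{k,l}^{K,L} W_{(k,l)} \odot y_{(i+k,\, j+l)} \\
\mathrm{SepConv}(W_p, W_d, y)_{(i,j)} &= \mathrm{PointwiseConv}_{(i,j)}\bigl(W_p,\ \mathrm{DepthwiseConv}_{(i,j)}(W_d, y)\bigr)
\end{align}
```

Here \odot denotes an element-wise product over channels: the depthwise step filters each channel on its own, and the pointwise step then mixes the channels at each position.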

Sample code

# Written for the TensorFlow 1.x API (uses tf.Session)
import numpy as np
import tensorflow as tf

x = tf.convert_to_tensor(np.arange(20).astype(np.float32))
x = tf.reshape(x, [1, 1, -1, 2])

w1 = tf.constant(value = 1.0, shape = [1, 3, 2, 2], \
                    dtype = tf.float32)
w2 = tf.constant(value = 1.0, shape = [1, 3, 2, 1], \
                    dtype = tf.float32)
w3 = tf.constant(value = 1.0, shape = [1, 1, 2, 2], \
                    dtype = tf.float32)

# y1 = y3 = y4 here, because every weight is 1.0
y1 = tf.nn.conv2d(x, w1, strides = [1, 1, 1, 1], \
                    padding = 'VALID')

y2 = tf.nn.depthwise_conv2d(x, w2, strides = [1, 1, 1, 1], \
                    padding = 'VALID')
y3 = tf.nn.conv2d(y2, w3, strides = [1, 1, 1, 1], \
                    padding = 'VALID')

y4 = tf.nn.separable_conv2d(x, w2, w3, \
                 strides = [1, 1, 1, 1,], padding = 'VALID')

print (x)
print (y1)
print (y2)
print (y3)
print (y4)

with tf.Session() as sess:
  print ('x: ')
  print (sess.run(x))
  print ('y1: ')
  print (sess.run(y1))
  print ('y2: ')
  print (sess.run(y2))
  print ('y3: ')
  print (sess.run(y3))
  print ('y4: ')
  print (sess.run(y4))

Results

Tensor("Reshape_19:0", shape=(1, 1, 10, 2), dtype=float32)
Tensor("Conv2D_21:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("depthwise_10:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("Conv2D_22:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("separable_conv2d_3:0", shape=(1, 1, 8, 2), dtype=float32)
x:
[[[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]
[ 8. 9.]
[10. 11.]
[12. 13.]
[14. 15.]
[16. 17.]
[18. 19.]]]]
y1:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]
y2:
[[[[ 6. 9.]
[12. 15.]
[18. 21.]
[24. 27.]
[30. 33.]
[36. 39.]
[42. 45.]
[48. 51.]]]]
y3:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]
y4:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]
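The values above can be checked by hand. The following is a minimal NumPy sketch, not the TensorFlow implementation: the helper names `depthwise_valid` and `pointwise` are hypothetical, and the sketch assumes stride 1, VALID padding, and a channel multiplier of 1, as in the sample.

```python
import numpy as np

# Same input as the sample: [batch, height, width, channels] = [1, 1, 10, 2]
x = np.arange(20, dtype=np.float32).reshape(1, 1, 10, 2)

# All-ones filters with the same shapes as w2 (depthwise) and w3 (pointwise)
w_depth = np.ones((1, 3, 2, 1), dtype=np.float32)  # [kh, kw, in_ch, multiplier]
w_point = np.ones((1, 1, 2, 2), dtype=np.float32)  # [1, 1, in_ch, out_ch]


def depthwise_valid(x, w):
    """Depthwise convolution (hypothetical helper): each channel is filtered separately."""
    n, h, wd, c = x.shape
    kh, kw, _, _ = w.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, c), dtype=x.dtype)
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = x[:, i:i + kh, j:j + kw, :]               # [n, kh, kw, c]
            out[:, i, j, :] = (patch * w[..., 0]).sum(axis=(1, 2))
    return out


def pointwise(x, w):
    """1x1 convolution (hypothetical helper): mixes channels at each position."""
    return np.einsum('nhwc,co->nhwo', x, w[0, 0])


y_depth = depthwise_valid(x, w_depth)  # corresponds to y2
y_sep = pointwise(y_depth, w_point)    # corresponds to y3 and y4

print(y_depth[0, 0, 0])  # first row of y2: 0+2+4 and 1+3+5
print(y_sep[0, 0, 0])    # first row of y3: 6+9 in both output channels
```

For the first output position, the depthwise step sums three neighboring values per channel (0+2+4 = 6 and 1+3+5 = 9), and the pointwise step adds the two channels (6+9 = 15), which matches y2 and y3 above.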

Sample code 2

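# Written for the TensorFlow 1.x API (uses tf.Session and tf.get_variable)
import numpy as np
import tensorflow as tf

x = tf.convert_to_tensor(np.arange(20).astype(np.float32))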
x = tf.reshape(x, [1, 1, -1, 2])

init = tf.truncated_normal_initializer(mean = 0.0, \
                     stddev = 1.0, dtype = tf.float32)
w_r = tf.get_variable('w_r', shape = [1, 3, 2, 2], \
                     initializer = init)
w_d = tf.get_variable('w_d', shape = [1, 3, 2, 1], \
                     initializer = init)
w_p = tf.get_variable('w_p', shape = [1, 1, 2, 2], \
                     initializer = init)

# regular
y_r = tf.nn.conv2d(x, w_r, strides = [1, 1, 1, 1], \
                     padding = 'VALID')

# depthwise + pointwise
y_d = tf.nn.depthwise_conv2d(x, w_d, strides = \
                    [1, 1, 1, 1],  padding = 'VALID')
y_d_p = tf.nn.conv2d(y_d, w_p, strides = [1, 1, 1, 1], \
                    padding = 'VALID')

# separable
y_s = tf.nn.separable_conv2d(x, w_d, w_p, strides = \
                    [1, 1, 1, 1,], padding = 'VALID')

print (x)
print (y_r)
print (y_d)
print (y_d_p)
print (y_s)

with tf.Session() as sess:

  sess.run(tf.global_variables_initializer())

  print ('x: ')
  print (sess.run(x))
  print ('regular: ')
  print (sess.run(y_r))
  print ('depthwise: ')
  print (sess.run(y_d))
  print ('depthwise + pointwise: ')
  print (sess.run(y_d_p))
  print ('separable: ')
  print (sess.run(y_s))

Results 2

Tensor("Reshape:0", shape=(1, 1, 10, 2), dtype=float32)
Tensor("Conv2D:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("depthwise:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("Conv2D_1:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("separable_conv2d:0", shape=(1, 1, 8, 2), dtype=float32)
x:
[[[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]
[ 8. 9.]
[10. 11.]
[12. 13.]
[14. 15.]
[16. 17.]
[18. 19.]]]]
regular:
[[[[ 1.9705558 0.33328748]
[ 1.8040164 -4.134241 ]
[ 1.6374781 -8.60177 ]
[ 1.4709401 -13.069299 ]
[ 1.3044 -17.536827 ]
[ 1.137861 -22.004358 ]
[ 0.97132176 -26.471891 ]
[ 0.80478483 -30.939413 ]]]]
depthwise:
[[[[-0.63552576 5.332875 ]
[ 2.240392 8.880491 ]
[ 5.11631 12.428108 ]
[ 7.9922285 15.975724 ]
[10.868145 19.523342 ]
[13.744063 23.070957 ]
[16.619982 26.618574 ]
[19.4959 30.166191 ]]]]
depthwise + pointwise:
[[[[ 0.6532894 -5.3626175]
[ -3.0205643 -10.077426 ]
[ -6.6944184 -14.792234 ]
[-10.368273 -19.507042 ]
[-14.042125 -24.221851 ]
[-17.715979 -28.936659 ]
[-21.389833 -33.651466 ]
[-25.063686 -38.366276 ]]]]
separable:
[[[[ 0.6532894 -5.3626175]
[ -3.0205643 -10.077426 ]
[ -6.6944184 -14.792234 ]
[-10.368273 -19.507042 ]
[-14.042125 -24.221851 ]
[-17.715979 -28.936659 ]
[-21.389833 -33.651466 ]
[-25.063686 -38.366276 ]]]]
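In this run, "regular" differs from "separable" because w_r, w_d, and w_p are initialized independently, while "depthwise + pointwise" and "separable" agree because they share the same weights. A regular convolution reproduces the separable result only when its kernel is the product of the depthwise and pointwise filters. The NumPy sketch below illustrates this under stated assumptions: stride 1, VALID padding, a channel multiplier of 1, a hypothetical helper `conv_valid`, and random weights from NumPy's generator rather than the values of the TensorFlow run above.

```python
import numpy as np

rng = np.random.default_rng(0)  # illustrative weights, not those of the run above
x = np.arange(20, dtype=np.float32).reshape(1, 1, 10, 2)
w_d = rng.standard_normal((1, 3, 2, 1)).astype(np.float32)  # depthwise filter
w_p = rng.standard_normal((1, 1, 2, 2)).astype(np.float32)  # pointwise filter


def conv_valid(x, w):
    """Regular convolution (hypothetical helper), stride 1, VALID padding."""
    n, h, wd, c = x.shape
    kh, kw, _, o = w.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, o), dtype=x.dtype)
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = x[:, i:i + kh, j:j + kw, :]                # [n, kh, kw, c]
            out[:, i, j, :] = np.einsum('nklc,klco->no', patch, w)
    return out


# Depthwise step: convolve each input channel with its own single-channel filter
y_d = np.stack(
    [conv_valid(x[..., c:c + 1], w_d[:, :, c:c + 1, :])[..., 0] for c in range(2)],
    axis=-1,
)
# Pointwise step: 1x1 convolution that mixes the channels
y_sep = conv_valid(y_d, w_p)

# Equivalent regular kernel: w_equiv[k, l, c, o] = w_d[k, l, c, 0] * w_p[0, 0, c, o]
w_equiv = w_d * w_p  # broadcasts [1, 3, 2, 1] * [1, 1, 2, 2] -> [1, 3, 2, 2]
y_reg = conv_valid(x, w_equiv)

print(np.allclose(y_reg, y_sep, rtol=1e-4, atol=1e-4))
```

The reverse does not hold in general: an arbitrary [1, 3, 2, 2] kernel such as w_r cannot always be factored this way, which is why separable convolution is a restricted form of regular convolution.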
