@takeshikondo

Notes on How Separable Convolution Is Computed

More than 1 year has passed since last update.

Key points

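• Verify how Separable Convolution is computed, and compare it with a regular convolution and with a depthwise convolution followed by a pointwise convolution.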

References

1. Depthwise Separable Convolutions for Neural Machine Translation

(Quoted from the referenced paper.)
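For reference, the operations compared below can be written roughly as follows. This is a sketch in notation similar to the referenced paper's, not a reproduction of the excerpt quoted above: $(i, j)$ is the output position, $(k, l)$ the kernel offset, $m$ the input channel, the output-channel index of the regular and pointwise convolutions is omitted, and $\odot$ is element-wise multiplication across channels.

```math
\begin{aligned}
\mathrm{Conv}(W, y)_{(i,j)} &= \sum_{k,l,m} W_{(k,l,m)} \cdot y_{(i+k,\, j+l,\, m)} \\
\mathrm{PointwiseConv}(W, y)_{(i,j)} &= \sum_{m} W_{m} \cdot y_{(i,j,m)} \\
\mathrm{DepthwiseConv}(W, y)_{(i,j)} &= \sum_{k,l} W_{(k,l)} \odot y_{(i+k,\, j+l)} \\
\mathrm{SepConv}(W_p, W_d, y)_{(i,j)} &= \mathrm{PointwiseConv}_{(i,j)}\big(W_p,\ \mathrm{DepthwiseConv}(W_d, y)\big)
\end{aligned}
```

Substituting the depthwise step into the pointwise step shows that a separable convolution is a regular convolution whose kernel factors as $W_{(k,l,m)} = (W_d)_{(k,l,m)} \cdot (W_p)_m$ for each output channel.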

Sample code

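```python
import numpy as np
# TensorFlow 1.x API (tf.Session). On TF 2.x, use instead:
#   import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
import tensorflow as tf

# Input in NHWC layout: [batch=1, height=1, width=10, channels=2]
x = tf.convert_to_tensor(np.arange(20).astype(np.float32))
x = tf.reshape(x, [1, 1, -1, 2])

# All-ones filters
# w1: regular filter [height, width, in_channels, out_channels]
w1 = tf.constant(value=1.0, shape=[1, 3, 2, 2], dtype=tf.float32)
# w2: depthwise filter [height, width, in_channels, channel_multiplier]
w2 = tf.constant(value=1.0, shape=[1, 3, 2, 1], dtype=tf.float32)
# w3: pointwise (1x1) filter [1, 1, in_channels * channel_multiplier, out_channels]
w3 = tf.constant(value=1.0, shape=[1, 1, 2, 2], dtype=tf.float32)

# y1 = y3 = y4
# regular convolution
y1 = tf.nn.conv2d(x, w1, strides=[1, 1, 1, 1], padding='VALID')

# depthwise convolution, then pointwise (1x1) convolution
y2 = tf.nn.depthwise_conv2d(x, w2, strides=[1, 1, 1, 1], padding='VALID')
y3 = tf.nn.conv2d(y2, w3, strides=[1, 1, 1, 1], padding='VALID')

# separable convolution as a single op
y4 = tf.nn.separable_conv2d(x, w2, w3,
                            strides=[1, 1, 1, 1], padding='VALID')

# Graph tensors: these print names and static shapes, not values
print(x)
print(y1)
print(y2)
print(y3)
print(y4)

with tf.Session() as sess:
    print('x: ')
    print(sess.run(x))
    print('y1: ')
    print(sess.run(y1))
    print('y2: ')
    print(sess.run(y2))
    print('y3: ')
    print(sess.run(y3))
    print('y4: ')
    print(sess.run(y4))
```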
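To make explicit what each TensorFlow op computes, here is a NumPy-only sketch of the same three computations (stride 1, `'VALID'` padding, NHWC layout, and no kernel flip, as with `tf.nn.conv2d`). The helper names are hypothetical, and the depthwise output-channel order (input channel × multiplier + m) is assumed to follow TensorFlow's documented layout.

```python
import numpy as np

def conv2d_valid(x, w):
    """Regular conv (cross-correlation). x: [n, h, w, c_in], w: [kh, kw, c_in, c_out]."""
    n, h, wd, _ = x.shape
    kh, kw, _, c_out = w.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, c_out), dtype=x.dtype)
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = x[:, i:i + kh, j:j + kw, :]  # [n, kh, kw, c_in]
            # every output channel sums over the window AND over all input channels
            out[:, i, j, :] = np.einsum('nhwc,hwco->no', patch, w)
    return out

def depthwise_conv2d_valid(x, w):
    """Depthwise conv. w: [kh, kw, c_in, multiplier]; output channel = c * multiplier + m."""
    n, h, wd, c_in = x.shape
    kh, kw, _, mult = w.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, c_in * mult), dtype=x.dtype)
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = x[:, i:i + kh, j:j + kw, :]
            # each input channel is filtered only by its own kernel (no sum over channels)
            out[:, i, j, :] = np.einsum('nhwc,hwcm->ncm', patch, w).reshape(n, -1)
    return out

def separable_conv2d_valid(x, w_d, w_p):
    """Depthwise conv followed by a 1x1 (pointwise) regular conv."""
    return conv2d_valid(depthwise_conv2d_valid(x, w_d), w_p)

x = np.arange(20, dtype=np.float32).reshape(1, 1, -1, 2)
w1 = np.ones((1, 3, 2, 2), dtype=np.float32)
w2 = np.ones((1, 3, 2, 1), dtype=np.float32)
w3 = np.ones((1, 1, 2, 2), dtype=np.float32)

y1 = conv2d_valid(x, w1)
y2 = depthwise_conv2d_valid(x, w2)
y4 = separable_conv2d_valid(x, w2, w3)
print(y1[0, 0, :, 0])  # [15. 27. 39. 51. 63. 75. 87. 99.]
```

These values should agree with the TensorFlow results for `y1`, `y2`, and `y4`.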

Results

```
Tensor("Reshape_19:0", shape=(1, 1, 10, 2), dtype=float32)
Tensor("Conv2D_21:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("depthwise_10:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("Conv2D_22:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("separable_conv2d_3:0", shape=(1, 1, 8, 2), dtype=float32)
x:
[[[[ 0. 1.]
   [ 2. 3.]
   [ 4. 5.]
   [ 6. 7.]
   [ 8. 9.]
   [10. 11.]
   [12. 13.]
   [14. 15.]
   [16. 17.]
   [18. 19.]]]]
y1:
[[[[15. 15.]
   [27. 27.]
   [39. 39.]
   [51. 51.]
   [63. 63.]
   [75. 75.]
   [87. 87.]
   [99. 99.]]]]
y2:
[[[[ 6. 9.]
   [12. 15.]
   [18. 21.]
   [24. 27.]
   [30. 33.]
   [36. 39.]
   [42. 45.]
   [48. 51.]]]]
y3:
[[[[15. 15.]
   [27. 27.]
   [39. 39.]
   [51. 51.]
   [63. 63.]
   [75. 75.]
   [87. 87.]
   [99. 99.]]]]
y4:
[[[[15. 15.]
   [27. 27.]
   [39. 39.]
   [51. 51.]
   [63. 63.]
   [75. 75.]
   [87. 87.]
   [99. 99.]]]]
```
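Why the three results coincide with all-ones filters can be checked by hand at the first output position. A minimal sketch in plain NumPy, dropping the batch and height axes since both are 1:

```python
import numpy as np

x = np.arange(20, dtype=np.float32).reshape(10, 2)  # [width, channels]
window = x[0:3]                  # first 3-wide window: [[0, 1], [2, 3], [4, 5]]

# regular conv, all-ones [3, 2, 2] kernel: each output channel sums the whole window
y1_0 = window.sum()              # 0+1+2+3+4+5 = 15

# depthwise conv, all-ones [3, 2, 1] kernel: each channel is summed separately
y2_0 = window.sum(axis=0)        # [0+2+4, 1+3+5] = [6, 9]

# pointwise conv, all-ones [1, 1, 2, 2] kernel: each output channel sums the depthwise channels
y3_0 = y2_0.sum()                # 6 + 9 = 15
print(y1_0, y2_0, y3_0)
```

With all-ones filters, "sum each channel, then sum across channels" gives the same total as the regular convolution's single sum over the window, which is why `y1`, `y3`, and `y4` match. With general weights this only holds if the regular kernel happens to factor into a depthwise part times a pointwise part.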

Sample code 2

```python
import numpy as np
# TensorFlow 1.x API (tf.Session, tf.get_variable). On TF 2.x, use instead:
#   import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
import tensorflow as tf

x = tf.convert_to_tensor(np.arange(20).astype(np.float32))
x = tf.reshape(x, [1, 1, -1, 2])

# Random filters drawn from a truncated normal distribution.
# Re-running this in the same graph fails because the variable names already exist.
init = tf.truncated_normal_initializer(mean=0.0, stddev=1.0, dtype=tf.float32)
w_r = tf.get_variable('w_r', shape=[1, 3, 2, 2], initializer=init)  # regular
w_d = tf.get_variable('w_d', shape=[1, 3, 2, 1], initializer=init)  # depthwise
w_p = tf.get_variable('w_p', shape=[1, 1, 2, 2], initializer=init)  # pointwise

# regular
y_r = tf.nn.conv2d(x, w_r, strides=[1, 1, 1, 1], padding='VALID')

# depthwise + pointwise
y_d = tf.nn.depthwise_conv2d(x, w_d, strides=[1, 1, 1, 1], padding='VALID')
y_d_p = tf.nn.conv2d(y_d, w_p, strides=[1, 1, 1, 1], padding='VALID')

# separable
y_s = tf.nn.separable_conv2d(x, w_d, w_p,
                             strides=[1, 1, 1, 1], padding='VALID')

print(x)
print(y_r)
print(y_d)
print(y_d_p)
print(y_s)

with tf.Session() as sess:
    # variables must be initialized before they can be evaluated
    sess.run(tf.global_variables_initializer())

    print('x: ')
    print(sess.run(x))
    print('regular: ')
    print(sess.run(y_r))
    print('depthwise: ')
    print(sess.run(y_d))
    print('depthwise + pointwise: ')
    print(sess.run(y_d_p))
    print('separable: ')
    print(sess.run(y_s))
```
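The filter shapes in this sample also show the usual motivation for separable convolutions: fewer weights. A quick count, added here for illustration (it is not part of the original sample):

```python
import math

# Filter shapes from sample code 2: [height, width, in_channels, out_channels or multiplier]
w_r_shape = (1, 3, 2, 2)  # regular
w_d_shape = (1, 3, 2, 1)  # depthwise
w_p_shape = (1, 1, 2, 2)  # pointwise

regular = math.prod(w_r_shape)                           # 1*3*2*2 = 12
separable = math.prod(w_d_shape) + math.prod(w_p_shape)  # 6 + 4 = 10
print(regular, separable)  # 12 10
```

In general, a regular filter has $k_h k_w C_{in} C_{out}$ weights, while depthwise (channel multiplier 1) plus pointwise has $k_h k_w C_{in} + C_{in} C_{out}$, so the savings grow with the kernel size and the number of channels. `math.prod` requires Python 3.8 or later.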

Results 2

```
Tensor("Reshape:0", shape=(1, 1, 10, 2), dtype=float32)
Tensor("Conv2D:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("depthwise:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("Conv2D_1:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("separable_conv2d:0", shape=(1, 1, 8, 2), dtype=float32)
x:
[[[[ 0. 1.]
   [ 2. 3.]
   [ 4. 5.]
   [ 6. 7.]
   [ 8. 9.]
   [10. 11.]
   [12. 13.]
   [14. 15.]
   [16. 17.]
   [18. 19.]]]]
regular:
[[[[ 1.9705558 0.33328748]
   [ 1.8040164 -4.134241 ]
   [ 1.6374781 -8.60177 ]
   [ 1.4709401 -13.069299 ]
   [ 1.3044 -17.536827 ]
   [ 1.137861 -22.004358 ]
   [ 0.97132176 -26.471891 ]
   [ 0.80478483 -30.939413 ]]]]
depthwise:
[[[[-0.63552576 5.332875 ]
   [ 2.240392 8.880491 ]
   [ 5.11631 12.428108 ]
   [ 7.9922285 15.975724 ]
   [10.868145 19.523342 ]
   [13.744063 23.070957 ]
   [16.619982 26.618574 ]
   [19.4959 30.166191 ]]]]
depthwise + pointwise:
[[[[ 0.6532894 -5.3626175]
   [ -3.0205643 -10.077426 ]
   [ -6.6944184 -14.792234 ]
   [-10.368273 -19.507042 ]
   [-14.042125 -24.221851 ]
   [-17.715979 -28.936659 ]
   [-21.389833 -33.651466 ]
   [-25.063686 -38.366276 ]]]]
separable:
[[[[ 0.6532894 -5.3626175]
   [ -3.0205643 -10.077426 ]
   [ -6.6944184 -14.792234 ]
   [-10.368273 -19.507042 ]
   [-14.042125 -24.221851 ]
   [-17.715979 -28.936659 ]
   [-21.389833 -33.651466 ]
   [-25.063686 -38.366276 ]]]]
```
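Here "depthwise + pointwise" and "separable" agree exactly, while "regular" differs because `w_r` is drawn independently. A separable convolution does equal a regular convolution whose kernel is the product `W[k, c, o] = w_d[k, c] * w_p[c, o]`. The NumPy sketch below checks this with its own random weights (not the TensorFlow values above); batch and height axes are dropped, so the TF shapes `[1, 3, 2, 1]` and `[1, 1, 2, 2]` become `[3, 2]` and `[2, 2]`. The helper names are hypothetical.

```python
import numpy as np

def conv1d_valid(x, w):
    """Regular conv, stride 1, VALID, no kernel flip. x: [width, c_in], w: [k, c_in, c_out]."""
    k = w.shape[0]
    return np.stack([np.einsum('kc,kco->o', x[j:j + k], w)
                     for j in range(x.shape[0] - k + 1)])

def separable1d_valid(x, w_d, w_p):
    """Depthwise (multiplier 1) then pointwise. w_d: [k, c_in], w_p: [c_in, c_out]."""
    k = w_d.shape[0]
    depth = np.stack([np.einsum('kc,kc->c', x[j:j + k], w_d)
                      for j in range(x.shape[0] - k + 1)])
    return depth @ w_p  # pointwise step: mix channels at each position

rng = np.random.default_rng(0)
x = np.arange(20, dtype=np.float64).reshape(10, 2)
w_d = rng.standard_normal((3, 2))
w_p = rng.standard_normal((2, 2))
w_r = rng.standard_normal((3, 2, 2))

# regular kernel built from the two separable factors
w_factored = np.einsum('kc,co->kco', w_d, w_p)

y_sep = separable1d_valid(x, w_d, w_p)
print(np.allclose(y_sep, conv1d_valid(x, w_factored)))  # True
print(np.allclose(y_sep, conv1d_valid(x, w_r)))         # typically False
```

So a separable convolution is a regular convolution restricted to kernels of this factored form, which is why it needs fewer parameters but cannot express every regular filter.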
