1
@takeshikondo

Separable Convolution の計算方法に関するメモ

More than 1 year has passed since last update.

ポイント

• Separable Convolution の計算方法を確認。

レファレンス

1. Depthwise Separable Convolutions for Neural Machine Translation

（参照論文より引用）

サンプルコード

```python
x = tf.convert_to_tensor(np.arange(20).astype(np.float32))
# Lay the 20 values out as one NHWC image: batch 1, height 1, width 10,
# 2 channels -> shape (1, 1, 10, 2).
x = tf.reshape(x, [1, 1, -1, 2])

# All filters are filled with 1.0 so every output can be checked by hand.
# Regular filter: [filter_height, filter_width, in_channels, out_channels].
w1 = tf.constant(value=1.0, shape=[1, 3, 2, 2], dtype=tf.float32)
# Depthwise filter: [filter_height, filter_width, in_channels, channel_multiplier].
w2 = tf.constant(value=1.0, shape=[1, 3, 2, 1], dtype=tf.float32)
# Pointwise (1x1) filter: [1, 1, in_channels * channel_multiplier, out_channels].
w3 = tf.constant(value=1.0, shape=[1, 1, 2, 2], dtype=tf.float32)

# y1 = y3 = y4
# With all-ones weights, a regular conv output sums x over the 3 positions
# and both channels; depthwise sums each channel over the 3 positions and
# the all-ones pointwise conv then sums the channels -- the same total.
# This equality does not hold for arbitrary weights.
# 'VALID' padding: width 10 with a width-3 filter gives width 8.
y1 = tf.nn.conv2d(x, w1, strides=[1, 1, 1, 1], padding='VALID')

# Separable convolution done by hand: depthwise, then pointwise (1x1).
y2 = tf.nn.depthwise_conv2d(x, w2, strides=[1, 1, 1, 1], padding='VALID')
y3 = tf.nn.conv2d(y2, w3, strides=[1, 1, 1, 1], padding='VALID')

# The same two steps fused into one op.
y4 = tf.nn.separable_conv2d(x, w2, w3,
                            strides=[1, 1, 1, 1], padding='VALID')

# Graph-mode tensors: printing shows names and static shapes, not values.
print(x)
print(y1)
print(y2)
print(y3)
print(y4)

# Evaluate the graph (only constants, so no variable initialization needed).
with tf.Session() as sess:
    print('x: ')
    print(sess.run(x))
    print('y1: ')
    print(sess.run(y1))
    print('y2: ')
    print(sess.run(y2))
    print('y3: ')
    print(sess.run(y3))
    print('y4: ')
    print(sess.run(y4))
``````

結果

Tensor("Reshape_19:0", shape=(1, 1, 10, 2), dtype=float32)
Tensor("Conv2D_21:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("depthwise_10:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("Conv2D_22:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("separable_conv2d_3:0", shape=(1, 1, 8, 2), dtype=float32)
x:
[[[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]
[ 8. 9.]
[10. 11.]
[12. 13.]
[14. 15.]
[16. 17.]
[18. 19.]]]]
y1:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]
y2:
[[[[ 6. 9.]
[12. 15.]
[18. 21.]
[24. 27.]
[30. 33.]
[36. 39.]
[42. 45.]
[48. 51.]]]]
y3:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]
y4:
[[[[15. 15.]
[27. 27.]
[39. 39.]
[51. 51.]
[63. 63.]
[75. 75.]
[87. 87.]
[99. 99.]]]]

サンプルコード２

```python
x = tf.convert_to_tensor(np.arange(20).astype(np.float32))
# Lay the 20 values out as one NHWC image: batch 1, height 1, width 10,
# 2 channels -> shape (1, 1, 10, 2).
x = tf.reshape(x, [1, 1, -1, 2])

# Random weights (truncated normal, mean 0, stddev 1).
# NOTE: tf.get_variable raises if 'w_r' etc. already exist in the default
# graph, so re-running this cell needs a fresh graph.
init = tf.truncated_normal_initializer(mean=0.0, stddev=1.0,
                                       dtype=tf.float32)
# Regular filter: [filter_height, filter_width, in_channels, out_channels].
w_r = tf.get_variable('w_r', shape=[1, 3, 2, 2], initializer=init)
# Depthwise filter: [filter_height, filter_width, in_channels, channel_multiplier].
w_d = tf.get_variable('w_d', shape=[1, 3, 2, 1], initializer=init)
# Pointwise (1x1) filter: [1, 1, in_channels * channel_multiplier, out_channels].
w_p = tf.get_variable('w_p', shape=[1, 1, 2, 2], initializer=init)

# regular
# Uses its own independent weights, so in general it differs from the
# separable result below.
y_r = tf.nn.conv2d(x, w_r, strides=[1, 1, 1, 1], padding='VALID')

# depthwise + pointwise
y_d = tf.nn.depthwise_conv2d(x, w_d, strides=[1, 1, 1, 1], padding='VALID')
y_d_p = tf.nn.conv2d(y_d, w_p, strides=[1, 1, 1, 1], padding='VALID')

# separable
# Same w_d / w_p as the two-step version above, so y_s should equal y_d_p.
y_s = tf.nn.separable_conv2d(x, w_d, w_p,
                             strides=[1, 1, 1, 1], padding='VALID')

# Graph-mode tensors: printing shows names and static shapes, not values.
print(x)
print(y_r)
print(y_d)
print(y_d_p)
print(y_s)

with tf.Session() as sess:

    # Variables created by tf.get_variable must be initialized before use.
    sess.run(tf.global_variables_initializer())

    print('x: ')
    print(sess.run(x))
    print('regular: ')
    print(sess.run(y_r))
    print('depthwise: ')
    print(sess.run(y_d))
    print('depthwise + pointwise: ')
    print(sess.run(y_d_p))
    print('separable: ')
    print(sess.run(y_s))
``````

結果２

Tensor("Reshape:0", shape=(1, 1, 10, 2), dtype=float32)
Tensor("Conv2D:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("depthwise:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("Conv2D_1:0", shape=(1, 1, 8, 2), dtype=float32)
Tensor("separable_conv2d:0", shape=(1, 1, 8, 2), dtype=float32)
x:
[[[[ 0. 1.]
[ 2. 3.]
[ 4. 5.]
[ 6. 7.]
[ 8. 9.]
[10. 11.]
[12. 13.]
[14. 15.]
[16. 17.]
[18. 19.]]]]
regular:
[[[[ 1.9705558 0.33328748]
[ 1.8040164 -4.134241 ]
[ 1.6374781 -8.60177 ]
[ 1.4709401 -13.069299 ]
[ 1.3044 -17.536827 ]
[ 1.137861 -22.004358 ]
[ 0.97132176 -26.471891 ]
[ 0.80478483 -30.939413 ]]]]
depthwise:
[[[[-0.63552576 5.332875 ]
[ 2.240392 8.880491 ]
[ 5.11631 12.428108 ]
[ 7.9922285 15.975724 ]
[10.868145 19.523342 ]
[13.744063 23.070957 ]
[16.619982 26.618574 ]
[19.4959 30.166191 ]]]]
depthwise + pointwise:
[[[[ 0.6532894 -5.3626175]
[ -3.0205643 -10.077426 ]
[ -6.6944184 -14.792234 ]
[-10.368273 -19.507042 ]
[-14.042125 -24.221851 ]
[-17.715979 -28.936659 ]
[-21.389833 -33.651466 ]
[-25.063686 -38.366276 ]]]]
separable:
[[[[ 0.6532894 -5.3626175]
[ -3.0205643 -10.077426 ]
[ -6.6944184 -14.792234 ]
[-10.368273 -19.507042 ]
[-14.042125 -24.221851 ]
[-17.715979 -28.936659 ]
[-21.389833 -33.651466 ]
[-25.063686 -38.366276 ]]]]

