34
31

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

tensorflow.nn.conv2dとconv2d_transoposeのフィルタサイズの設定方法

Last updated at Posted at 2017-05-27

Tensorflowのconv2dとconv2d_transposeの使い方で迷ったので調べた。

なお、紛らわしいですが下記で扱うのはtf.nn.conv2dおよびtf.nn.conv2d_transposeで、
tf.layers.conv2dおよびtf.layers.conv2d_transposeではないのでご注意ください。
普通に使う分にはkernel sizeと出力次元を指定するだけのtf.layers.conv2d(_transpose)のほうが楽かも。

tensorflow.ntt.conv2d の入力次元

まず、conv2dのIFは公式ドキュメントによると、

conv2d(
    input,
    filter,
    strides,
    padding,
    use_cudnn_on_gpu=None,
    data_format=None,
    name=None
)

と定義されている。

知りたいこと: inputは良いとして、fileter, strideにどのようなTensorを指定すれば良いのだろうか?

例えば、下記のようなコードを想定する。

import tensorflow as tf
sess = tf.Session()
batch_size = 3

l = tf.constant(0.1, shape=[batch_size, 32, 32, 4])
w = tf.constant(0.1, shape=[7, 7, 4, 128])
strides = [1, 2, 2, 1]

h1 = tf.nn.conv2d(l, w, strides=strides, padding='SAME')
print sess.run(h1).shape # (3, 16, 16, 128)

入力のinputが[3, 32, 32, 4]だったとする。これが「NHWC」形式で並んでいるとすると、高さ32 x 幅32 x 色チャネル4 x 3枚の画像、を意味することになる。

また、filterが[7, 7, 4, 128]であったとする。
公式によると、filterは[height, width,in_channel, out_channels]で指定するそうなので、これはkernelが高さ7 x 幅7で、入力チャネル数4、出力チャネル数128という意味になる。

strideは長さ4のint型のリストで与えられる。[1, 2, 2, 1]と与えたとすれば、これは入力inputの各次元に対するstride幅を示すということなので、NHWC系であれば「各画像1枚1枚に対して、幅方向2、高さ方向2、チャネル方向1のストライドを行う」という意味になる。

outputであるh1の形状は(3, 16, 16, 128)、つまり3枚x 高さ16 x 幅16 x 128チャネルの出力が得られることになる。

つまり、ざっくり言えばconv2dでは
NHWC系では

filterには [ (kernelの高さ) x (kernelの幅) x (入力ch数) x (出力ch数) ]
stridesには [1 x (上下stride幅) x (左右stride幅) x 1]

を指定すれば良く、
NCHW系では

filterには [ (kernelの高さ) x (kernelの幅) x (入力ch数) x (出力ch数) ]
stridesには [1 x 1 x (上下stride幅) x (左右stride幅) ]

を指定すれば良い。
(※ ChainerはNCHW系でデータを扱うことが多かった気がするのですが、TensorflowはNHWC系で扱うことが多いような...なぜだろう)


tensorflow.ntt.conv2d_transpose の入力次元

次はconv2d_transposeについて。この関数は生成系NNで言われるところの"Deconvolution"(逆畳み込み) に対応する。

まずはconv2d_transposeのIFの公式ドキュメントを確認する。

conv2d_transpose(
    value,
    filter,
    output_shape,
    strides,
    padding='SAME',
    data_format='NHWC',
    name=None
)

valueはconv2dと同じく4次元のTensor。
filterは[height, width, output_channels, in_channels] の4次元テンソルで指定しろ、とのこと。
output_shapeを指定する必要のある部分がconv2dと異なっており、ここには出力の形状を1次元Tensorで指定するとのこと。
stridesはconv2dと同じく、valueのdata_formatを踏まえた指定となる。

知りたいこと: fileterとoutput_shapeはどのように指定すれば良いのか?
例えば、conv2dと同じような気持ちでこのようなコードを想定したとする。

sess = tf.Session()
batch_size = 3
output_shape = [batch_size, 8, 8, 128]
strides = [1, 2, 2, 1]

l = tf.constant(0.1, shape=[batch_size, 32, 32, 4])
w = tf.constant(0.1, shape=[7, 7, 128, 4])

h1 = tf.nn.conv2d_transpose(l, w, output_shape=output_shape, strides=strides, padding='SAME')
print sess.run(h1) # Error

つまり、3枚 x 高さ32 x 幅32 x ch4の入力に対して、
高さ7 x 幅7の大きさの入力4ch、出力ch128のフィルターをかけて、
3枚 x 高さ8 x 幅8 x 128chの出力を得ようとしている。

入力、出力のch数はあっており、一見うまく行きそうだが、これはoutputの形状不正でエラーとなる。

conv2d_transposeはconv2dのbackwardと対応しており、正しい形状を得るためにはまずはconv2dを考えてその逆を設定すれば良い。
つまり、output_shapeのTensorを入力として同じfilter/strideを用いたconv2dの出力のTensor形状がconv2d_transposeの適正な入力となる。

import tensorflow as tf
sess = tf.Session()
batch_size = 3
output_shape = [batch_size, 8, 8, 128]
strides = [1, 2, 2, 1]

w = tf.constant(0.1, shape=[7, 7, 128, 4])

output = tf.constant(0.1, shape=output_shape)
expected_l = tf.nn.conv2d(output, w, strides=strides, padding='SAME')

l = tf.constant(0.1, shape=expected_l.get_shape())

h1 = tf.nn.conv2d_transpose(l, w, output_shape=output_shape, strides=strides, padding='SAME')
print sess.run(h1).shape #(3, 8, 8, 128)

このために、

conv2dのfilter: [filter_height, filter_width, in_channels, out_channels]
conv2d_transposeのfilter: [height, width, output_channels, in_channels]

というようにinとoutが逆に設定されており、同じフィルター設定で逆操作として使うことができるようになっている。

まとめ

最後に、conv2dとconv2d_transposeのinput / output形状、value、filter、strideの関係を図にまとめた。

スクリーンショット 2017-05-27 17.57.19.png スクリーンショット 2017-05-27 17.58.02.png

※不思議に思ったこと: 他のDNNフレームワークだと、パディング無しの場合kernel sizeもoutputのサイズに影響する気がするが、Tensorflowの場合はkernel sizeがoutput sizeに影響を与えていない模様。この理解があってるのか、なぜこうなってるのか知りたい

こちらのStackOverflowの回答を参考にしました。
https://stackoverflow.com/questions/35488717/confused-about-conv2d-transpose/38059483#38059483

34
31
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
34
31

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?