8
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

TensorFlow (1.0系) での Global Average Pooling の汎用的な書き方

Last updated at Posted at 2018-02-04

TensorFlow 2.0系では tf.keras.layers.GlobalAveragePooling2D() で完了します。
一応 2.0 系でも似たような書き方はできますが、よほど特殊なレイヤーでなければ必要ないはずです。

TensorFlow での Global Average Pooling の汎用的な書き方

検索しても見つけられなかったので書きました。参考になっていただければ幸いです。
Global Average Pooling に関しては、ネット上に秀逸な記事が豊富なので説明は省きます。

入力と出力のテンソル形状が判明している場合は (MNIST など)

x = tf.nn.avg_pool(x, ksize=[1,7,7,1], strides=[1,1,1,1], padding="VALID")

のような書き方がありますね。(変数 x は処理するテンソルです)

しかし入力画像がフリーサイズの場合、上記の方法だとカーネルサイズを合わせることができずにうまく行きません。
そこで、個人的に以下の書き方で Global Average Pooling を実装しているので紹介します。
他にも色々な方法があると思いますが、その一つだと思ってくれれば幸いです。

書き方

  • 想定する入力の shape: [バッチサイズ, 高さ, 幅, チャンネル数]
  • 想定する出力の shape: [バッチサイズ, チャンネル数]

以下の関数で記述できる。

def global_average_pooling(x):
    for _ in range(2):
        x = tf.reduce_mean(x, axis=1)
    return x

解説

tf.reduce_mean によってだるま落としのようにテンソルを削り、平均を取る。

例えば、Cifar-10 のような 10クラス分類に使う場合、Global Average Pooling の前に カーネルサイズ 1x1Convolution などの方法で 10チャンネルにしましょう。
その後、Global Average Pooling をすることでテンソルの形状が [バッチサイズ, 10] になります。
後は Softmax 等にかけて出力しましょう。

おわりに

説明不足であればコメントで対応します。
実装におかしな点があればこちらもコメントで対応します。

8
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?