TensorFlow 2.0系では
tf.keras.layers.GlobalAveragePooling2D()
で完了します。
一応 2.0 系でも似たような書き方はできますが、よほど特殊なレイヤーでなければ必要ないはずです。
TensorFlow での Global Average Pooling の汎用的な書き方
検索しても見つけられなかったので書きました。参考になっていただければ幸いです。
Global Average Pooling
に関しては、ネット上に秀逸な記事が豊富なので説明は省きます。
入力と出力のテンソル形状が判明している場合は (MNIST など)
x = tf.nn.avg_pool(x, ksize=[1,7,7,1], strides=[1,1,1,1], padding="VALID")
のような書き方がありますね。(変数 x
は処理するテンソルです)
しかし入力画像がフリーサイズの場合、上記の方法だとカーネルサイズを合わせることができずにうまく行きません。
そこで、個人的に以下の書き方で Global Average Pooling
を実装しているので紹介します。
他にも色々な方法があると思いますが、その一つだと思ってくれれば幸いです。
書き方
- 想定する入力の shape: [バッチサイズ, 高さ, 幅, チャンネル数]
- 想定する出力の shape: [バッチサイズ, チャンネル数]
以下の関数で記述できる。
def global_average_pooling(x):
for _ in range(2):
x = tf.reduce_mean(x, axis=1)
return x
解説
tf.reduce_mean
によってだるま落としのようにテンソルを削り、平均を取る。
例
例えば、Cifar-10 のような 10クラス分類に使う場合、Global Average Pooling
の前に カーネルサイズ 1x1
の Convolution
などの方法で 10チャンネルにしましょう。
その後、Global Average Pooling をすることでテンソルの形状が [バッチサイズ, 10]
になります。
後は Softmax
等にかけて出力しましょう。
おわりに
説明不足であればコメントで対応します。
実装におかしな点があればこちらもコメントで対応します。