python3
TensorFlow

tf.reduce_meanの使い方と意味

はじめに

 TensorFlowモジュールにおいて、ニューラルネットを作成する際には損失関数などで必ず使用するであろうtf.reduce_meanの使い方について書いていきます。主にaxisがリストの場合やkeep_dimsがどういう意味かについて触れていこうと思います。自分はMobilenetのプログラムを紐解いているときにaxisがリストだったので、調べてみました。
 何故reduce関数の中に平均値を求めるものが含まれているのかについては項目「tf.reduce_meanの処理を数学的に追う」のex2で説明します。数学の記号についてはAppendixで書き記します。

tf.reduce_mean(input_tensor, axis=None, keep_dims=False, name=None, reduction_indices=None)

公式のtf.reduce_meanについての記述

 公式に関数説明が書かれていますので、簡単に引数と関数の役割についてまとめます。

tf.reduce_mean

 与えたリストに入っている数値の平均値を求める関数

Args(引数)

input_tensor

 input_tensorにはリストで値を渡してあげます。ようするに入力です。
ex) tf.reduce_mean([1, 3, 4, 5], ...)

axis

 axisは一昔前には最後に書いてある引数であるreduction_indicesとなっていました。理由はよくわかりませんが、今のバージョン(TensorFlow ver1.3)では変わっています。
axisには0や正の整数※1、またはリスト※1で値を渡してあげます。これはテンソルのランク※2をどのくらいに落とすかというものです。axisは軸という意味ですが、何故軸なのかということはAppendixで述べます。
ex) tf.reduce_mean([[1, 2], [4, 6]], 1, ...)
   tf.reduce_mean([[1, 2], [4, 6]], [0, 1], ...)

※1 値の範囲については後程記述しますので、こういう書き方をするんだなということだけ見ておいてください
※2 テンソルのランクの定義については別途投稿するので、そちらを参照してください
   https://qiita.com/uyuni/items/5659bb23219a00e7621d
   こちらのページのほうで解説していただいていたので、こちらを参照してください

keep_dims

 keep_dimsにはTrueまたはFalseを渡してあげます。略さないとkeep dimensions(次元の維持)です。Trueにするとテンソルの構造を維持したまま出力してくれます。
ex) tf.reduce_mean([1, 3, 4, 5], 0, keep_dims=True, ...)

name

 その名の通り名前ですが、指定した名前空間を借りて出力します。指定しても使われた名前空間には結果は代入されません。特に使う場面もないと思うので、基本無視でいいです。
ex) test = None
   tf.reduce_mean([1, 3, 4, 5], 0, keep_dims=True, name=test, ...)

reduce_indices

reduce_indicesはaxisと同じです。しかし、名残として残っているだけなので、使う必要は全くありません。

tf.reduce_meanの処理を数学的に追う

 上記のような引数を持ったtf.reduce_meanですが、詳しい処理が一体どうなっているのかということは公式では言及されていません。そこで、数式を用いてどのような処理が行われているのかということを見ていきます。
 ちなみにtensorflowのモジュールの関数を使うにはtf.Session()を使ってセッションを始めてあげなければいけません。わからなければすぐ出てきますので、ググってみてください。

ex1) axis=Noneの場合

 >>>sess = tf.Session() //次回から省略
 >>>net1 = tf.constant([1., 2., 3.]) //1.はfloat型という意味です。
 >>>mean1 = tf.reduce_mean(net1)
 >>>sess.run(mean1)
 2.0

 引数input_tensorにnet1だけを入れてみました。すると、どうでしょう。全ての要素を足し合わせたものの平均が出力されました。まさに平均値を求める関数ですね。
数式で示します。
$$mean1 = \frac{1}{2}\sum_{i=1}^{3}x_i$$
 今回はaxis=Noneだったわけですが、この場合は全ての要素を足し合わせて平均値を出力します。

ex2) axis=0の場合

 >>>net2 = tf.constant([[1., 2.], [4., 3.]])
 >>>mean2 = tf.reduce_mean(net2, 0)
 >>>sess.run(mean2)
 array([2.5, 2.5])

 rank2のテンソルをinputに入れてみると、今回はrank1のテンソルが出力されました。
数式が先ほどと変わるので示します。
$$mean2 = \frac{1}{2}\sum_{i=1}^{2}x_{ij}$$
$x_i$から$x_{ij}$に変わりました。これはテンソルの構造を表示してみるとわかります。

 >>>sess.run(net2)
 array([[1., 2.],
    [4., 3.]])

テンソルはこのような構造になっており、$x$を用いて要素として書いてあげると

 array([[$x_{11}$, $x_{12}$],
    [$x_{21}$, $x_{22}$]])

このように表せます。テンソルのrankによって次元(数学的意味のほう)が変化し、Σの扱い方が変わってきます。これこそがテンソルのランクを削減(reduction)するということであり、tf.reduce_meanがreduce関数に入っている意味なのです。この概念がtf.reduce_meanの理解をすることにおいての根幹といっても過言ではないので、適切に使いたい場合はしっかり理解したほうがいいと思います。

ex3) axis=1の場合

 >>>net3 = tf.constant([[1., 2.], [4., 3.]])
 >>>mean3 = tf.reduce_mean(net3, 1)
 >>>sess.run(mean3)
 array([1.5, 3.5])

 先ほどは$i$要素の総和でした。今回はご想像の通り$j$要素の総和です。
以下に数式を示します。
$$mean3 = \frac{1}{2}\sum_{j=1}^{2}x_{ij}$$
 ここまでのことをまとめますとaxisは0から始まり、テンソルのランク数の数まで増やせることになります。

 さて、axisが整数の場合は以上のようになりますが、リストの場合はどうでしょうか。
以下に示します。

ex4) axis=[0, 1]の場合

 >>>net4 = tf.constant([[1., 2.], [4., 3.]])
 >>>mean4 = tf.reduce_mean(net4, [0, 1])
 >>>sess.run(mean4)
 5

 総和の平均になっていますね。勘の良い方は察したんじゃないんでしょうか。
リストに入っている値はaxisを組み合わせたものです。数式を示します。
$$mean4 = \frac{1}{2・2}\sum_{i=1}^{2}\sum_{j=1}^{2}x_{ij}$$
 axis=0とaxis=1を組み合わせたものと同義です。このようにして自由自在にテンソルの構造を変えられるわけです。
余談ですが、axis=[1, 1]などの場合は後ろの1は無視されます。axis=1と同義です。

keep_dimsについて

これはおまけ要素として見ていただければ結構だと思います。

 >>>net5 = tf.constant([[[1., 2.], [4., 3.]], [[3., 2.], [4., 2.]]])
 >>>mean5 = tf.reduce_mean(net5, 0)
 >>>sess.run(mean5)
 array([2., 2.], [4., 2.5])

 >>>mean5 = tf.reduce_mean(net5, 0, keep_dims=True)
 >>>sess.run(mean5)
 array([2., 2.],
    [4., 2.5])

これだけです。
数式は以下の通りです。
$$mean5 = \frac{1}{2}\sum_{i=1}^{2}x_{ijk}$$

おわりに

 以上でtf.reduce_meanの説明は終わりになります。初めての投稿だったので、冗長な書き方をしているかもしれないのですが、大目に見ていただけると幸いです。
何かご質問や、ここが間違っているということがあればご意見をいただければと思います。
あとはAppendixになりますので、よくΣについて知らないという人がいれば見ていってください。(とはいっても他ページの貼り付けになります...)
ここまで見ていただき、ありがとうございました。

Appendix

数学記号のΣ(シグマ)について

Deep Learningの設計をやろうと思っている人で、線形代数を学んだことがない人がいれば学んでおいて損はないので、やっておくことをお勧めします。
以下は私が使っていた本です。線形空間の手前くらいまでやればDeep Learningの理解がしやすくなると思います。
やさしく学べる線形代数