LoginSignup
12
7

More than 5 years have passed since last update.

tf.metricsを使ってみる

Last updated at Posted at 2017-07-18

きっかけ

tf.metrics.accuracyを使おうとしたら,返り値が2つ (accuracy, update_op) あったり値が普通の正答率じゃなかったりして悩まされました.
tf.metrics.recallやtf.metrics.precisionも同様でした.
ざっと見た感じ,現時点でこれに関する日本語記事はほぼ無いようなので,とりあえずメモ.

tf.metricsの挙動

名前の通り,正答率をはじめとする各種metricsを算出します.

ですが,名前だけ見た人は,

# labels: 正答ラベルの1次元テンソル
# predictions: 予測されたラベルの1次元テンソル

accuracy, update_op = tf.metrics.accuracy(labels, predictions)
accuracy = tf.reduce_mean(tf.cast(predictions == labels, tf.float32))

これら2つのaccuracyが同じ値になることを期待するのではないでしょうか.
また,update_opとはなんぞやと思うのではないでしょうか.

結論から言うと,tf.metrics.accuracyは,過去の全ての値を保持しているかのように振る舞います.
(実際には,過去の合計正答数totalとデータ数countを保持して,「合計÷個数」しているだけ).

つまり,一度目のエポックで全問正解し,二度目のエポックで全問間違ったならば (そして各エポックのバッチサイズが常に同じならば),一回目のaccuracyは1.00,二回目のaccuracyは0.50になります.
もし三度目のエポックで全問正解したならば,三回目のaccuracyは約0.67です.

この挙動に関しては,tensorflowのissueなどを見た限りでも困惑している人が多いようですね.「非直感的だ」「tf.metrics.streaming_accuracyのほうがこの関数に相応しい名前なのではないか」のような意見が見られます.

ちなみに,ある回答者曰く

  • 『streaming_◯◯』と命名されていないのは,どのみちすべてのtf.metricsがstreamingであるからだ.
  • non-streamingなmetricsを作らない理由は,それらの実装がどのみち容易いことだからだ.
  • 1バッチ単位の正答率に関心があるわけではないので,non-streamingなmetricsには興味がないことが多い.

とのこと.なるほど腑に落ちた.確かに便利そうに思えます.

tf.metricsの使い方

tf.metricsは2つの返り値を持ちます.accuracyとupdate_opです.

update_opを呼ぶと正答率が更新されます.accuracyは最後に算出した正答率 (初期値は0) を保持します.

要するに,こんな感じです.

import numpy as np
import tensorflow as tf

labels = tf.placeholder(tf.float32, [None])
predictions = tf.placeholder(tf.float32, [None])
accuracy, update_op = tf.metrics.accuracy(labels, predictions)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(accuracy))  # 初期値0

    # 1回目 (全問正解)
    sess.run(update_op, feed_dict={
        labels: np.array([1, 1, 1]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 3 / 3 = 1

    # 2回目 (全問間違い)
    sess.run(update_op, feed_dict={
        labels: np.array([0, 0, 0]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 3 / 6 = 0.5

    # 3回目 (全問正解)
    sess.run(update_op, feed_dict={
        labels: np.array([1, 1, 1]),
        predictions: np.array([1, 1, 1])
    })
    print(sess.run(accuracy))  # 6 / 9 = 約0.67

tf.metricsを使った実装

これが良いのかどうかは分かりませんが,例えばこんな感じですかね.
他に良い方法があれば教えてください.

def create_metrics(labels, predictions, register_to_summary=True):
    update_op, metrics_op = {}, {}

    # accuracy, recall, precisionの算出にはtf.metricsを使用
    for key, func in zip(('accuracy', 'recall', 'precision'),
                         (tf.metrics.accuracy, tf.metrics.recall, tf.metrics.precision)):
        metrics_op[key], update_op[key] = func(labels, predictions, name=key)

    # f1_scoreは自力で計算
    metrics_op['f1_score'] = tf.divide(
        2 * metrics_op['precision'] * metrics_op['recall'],
        metrics_op['precision'] + metrics_op['recall'] + 1e-8,
        name='f1_score'
    )  # 1e-8はゼロ除算対策

    entire_update_op = tf.group(*update_op.values())

    if register_to_summary:  # あとでtf.summary.merge_all()できるように
        for k, v in metrics_op.items():
            tf.summary.scalar(k, v)

    return metrics_op, entire_update_op

metrics_op, entire_update_op = create_metrics(labels, predictions)
merged = tf.summary.merge_all()

何が言いたい・やりたいのかというと,要するに

  • 複数のmetricsを併用する場合は,それぞれのupdate_opをtf.groupでまとめておくと楽.
  • tf.summaryでログ取るなら各metricsを後からmerge_allすると楽.
  • f1_scoreはtf.metrics内に見当たらないから自作するしかないかも.

ということです.

備考

ちなみに,これらmetricsはglobal variablesではなくlocal variablesなので,

local_init_op = tf.local_variables_initializer()
sess.run(local_init_op)

する必要があります.

12
7
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
7