Help us understand the problem. What is going on with this article?

TensorFlowのクロスエントロピー関数の動作

More than 1 year has passed since last update.

はじめに

TensorFlowでは多クラス分類をするためにクロスエントロピーを計算する関数が用意されていますが、内部でソフトマックス関数を実行しているため、動作がわかりにくくなっています。サンプルコードを見れば使い方はなんとなくわかりますが、自分の理解があっているのか確認したかったので、そのためのコードを書いてみました。
(この記事はTensorFlow ver 1.x向けに書かれた記事です。ver 2.xでは変更されている可能性があります。)

ソフトマックス関数

ソフトマックス関数はニューラルネットワークが出力する値$x_i$を確率に変更する活性化関数の一種で次のように定義されます。ただし、分類するクラス数を$K$として、ニューラルネットワークは$x_1,x_2,...,x_K$の値を出力しているものとします。

softmax(x_i) := \frac{\exp(x_i)}{\sum_{k=1}^K \exp(x_k)}

ソフトマックス関数の出力はすべて$0$から$1$の範囲に収まり、合計は$1$になります。

多クラス分類ではニューラルネットワークの出力をこの関数に入力して得られた結果を分類の推定確率とします。

クロスエントロピー

クロスエントロピーは多クラス分類でよく使われる誤差関数の一つで、$d$番目のデータに対して、クラス$i$である推定確率$x_{d,i}$と真の確率$y_{d,i}$の誤差を計算します。
定義は次のようになります。

crossentropy(\boldsymbol{x}, \boldsymbol{y}) := - \sum_{d=1}^D\sum_{i=1}^K y_{d,i} \log(x_{d,i})

ディープラーニングの多クラス分類では$x_{d,i}$にソフトマックス関数の出力を、$y_{d,i}$に教師データをそれぞれ入力します。

TensorFlowでは

ソフトマックス関数はtf.nn.softmaxという関数で計算できます。

クロスエントロピーについてはTensorFlowでは単体の関数は用意されていないようなので、ソフトマックス関数とクロスエントロピーをセットにした、tf.nn.softmax_cross_entropy_with_logits_v2tf.nn.sparse_softmax_cross_entropy_with_logitsを使います。"with_logits"というのは、この関数は内部でソフトマックスも計算するから、ニューラルネットワークの出力をそのまま入力してね、という意味です。
(古いバージョンでは_v2という名前ではなかったのですが、最新版では_v2のみ残っています。また、sparse_で始まる方は_v2はないようです。)

注意点として、TensorFlowの関数はクロスエントロピーの式の2つある$\sum$のクラスに対する和$\sum_{i=1}^K$しか計算しないので、バッチ内データの和$\sum_{d=1}^D$は自分で計算する必要があり、reduce_sumを使うか、tf.lossesに定義されている関数を使うことで計算できます。

tf.nn.softmax_cross_entropy_with_logits_v2tf.nn.sparse_softmax_cross_entropy_with_logitsは教師データのデータ形式で使い分けます。
前者の教師データにはone hot(一つだけ1で他は0の配列)や各クラスに属する確率の配列を使用します。
後者は教師データとしてクラスのインデックスを整数で指定します。

確認用コード

TensorFlowでクロスエントロピーを計算するコードを書いてみました。クロスエントロピーを手動で計算する方法(y1)とtf.nn.softmax_cross_entropy_with_logits_v2(y2)、tf.nn.sparse_softmax_cross_entropy_with_logits(y3)で結果を比較しています。

下のコードではxが実際のニューラルネットワークの出力(ソフトマックス関数などの活性化関数を通す前)の値になります。
labelsoftmax_cross_entropy_with_logits用の教師データでクラス数と同じ数の要素を持つ配列になります。要素の値はそれぞれの確率になります。
sparse_labelsparse_softmax_cross_entropy_with_logits用の教師データで一番確率の高いクラスのインデックスを指定します。

y1y2は同じ結果になります。(多少の誤差があるようですが。)
y2y3については、y2用の教師データがone hotの時は同じ結果になり、確率分布で与えた場合は異なる結果になります。

TensorFlowのsoftmax_cross_entropy_with_logitsのリファレンスにも、教師データが一つのクラスだけを指定する場合はsparse_softmax_cross_entropy_with_logitsが使えるよ、みたいなことが書いてあります。

import tensorflow as tf

### グラフ定義
x = tf.placeholder(tf.float32, [None,3])
label = tf.placeholder(tf.float32, [None,3])
sparse_label = tf.placeholder(tf.int32, [None])

# 個別に計算
#   xをソフトマックスしてからlabelとのクロスエントロピーを計算
soft = tf.nn.softmax(x)                            # ソフトマックス: exp(x) / sum(exp(x))
y1 = - tf.reduce_sum(tf.log(soft) * label, axis=1) # クロスエントロピー(クラス方向の和のみ計算)

# softmax_cross_entropy_with_logits
#   ソフトマックスとクロスエントロピーを同時に実行してくれる
#   labelは確率分布(one hotでなくてもよい)
# y2 = tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=x) # 古いバージョン
y2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=x)

# sparse_softmax_cross_entropy_with_logits
#   ソフトマックスとクロスエントロピーを同時に実行してくれる
#   labelは分類先のインデックスを指定する
y3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sparse_label, logits=x)

### 計算実行
sess = tf.Session()
x_ = [
    [6,5,4],
    [2,5,4],
    [3,1,6],
    [3,1,6]
]
l = [
    [1,0,0],
    [0,1,0],
    [0,0,1],
    [0.3,0.1,0.6]
]
sl = [
    0,
    1,
    2,
    2
]
y = sess.run([y1, y2, y3], feed_dict={x: x_, label: l, sparse_label: sl})

# y1とy2は同じ値になる
# y2とy3はlabelがone hotの時に同じ値になる
for i in range(len(y[0])):
    print('y1: %.8f, y2: %.8f, y3: %.8f' % (y[0][i], y[1][i], y[2][i]))

結果

y1: 0.40760601, y2: 0.40760595, y3: 0.40760595
y1: 0.34901217, y2: 0.34901217, y3: 0.34901217
y1: 0.05498519, y2: 0.05498521, y3: 0.05498521
y1: 1.45498538, y2: 1.45498538, y3: 0.05498521

平均化するための関数

tf.nn.softmax_cross_entropy_with_logitstf.nn.sparse_softmax_cross_entropy_with_logitsは各データのクロスエントロピーを出力するので、学習時の損失関数としてはバッチ内の各データのクロスエントロピーの平均を取る必要があります。
tf.lossesにはそのための関数が用意されており、それぞれ
tf.losses.softmax_cross_entropytf.losses.sparse_softmax_cross_entropyが対応します。(with_logitsがなくなっていることに注意)

exy81
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away