Posted at

tensorflowのsoftmax cross entropy関数の実装例

TensorFlowを用いたsoftmax cross entropy関数の実装方法は複数あります。なお、本記事のコードはTensorFlow 1.x系のAPI（tf.placeholder、tf.Session）を前提としています。

ここでは4つの実装方法についてまとめました。


実装


tf.nn.softmax()

# Inputs: per-class scores and per-class target weights, two classes per row.
logits = tf.placeholder(tf.float32, [None, 2])
labels = tf.placeholder(tf.float32, [None, 2])

# Cross entropy written out by hand: -sum_k labels[k] * log(softmax(logits)[k]).
# tf.nn.log_softmax is used instead of tf.log(tf.nn.softmax(logits)) because
# the latter yields -inf once a probability underflows to 0, and the
# following 0 * -inf product turns the loss into NaN.
log_softmax = tf.nn.log_softmax(logits)
output = -tf.reduce_sum(log_softmax * labels, axis=1)


tf.nn.softmax_cross_entropy_with_logits_v2()

# Inputs: per-class scores and per-class target weights, two classes per row.
logits = tf.placeholder(dtype=tf.float32, shape=[None, 2])
labels = tf.placeholder(dtype=tf.float32, shape=[None, 2])

# The op takes unnormalized scores directly; no explicit softmax is applied
# beforehand in this variant.
output = tf.nn.softmax_cross_entropy_with_logits_v2(
    logits=logits,
    labels=labels,
)


tf.nn.sparse_softmax_cross_entropy_with_logits()

# Inputs: per-class scores (two classes per row) and one integer label per row.
logits = tf.placeholder(dtype=tf.float32, shape=[None, 2])
sparse_labels = tf.placeholder(dtype=tf.int32, shape=[None])

# The sparse variant is fed integer labels instead of a per-class target row.
output = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=sparse_labels,
    logits=logits,
)


tf.one_hot

# Inputs: per-class scores (two classes per row) and one integer label per row.
logits = tf.placeholder(tf.float32, [None, 2])
sparse_labels = tf.placeholder(tf.int32, [None])

# tf.nn.log_softmax is used instead of tf.log(tf.nn.softmax(logits)) because
# the latter yields -inf once a probability underflows to 0, and the
# following 0 * -inf product turns the loss into NaN.
log_softmax = tf.nn.log_softmax(logits)

# Expand the integer labels to depth 2 so they can weight the log-probabilities.
one_hot = tf.one_hot(sparse_labels, 2)
output = -tf.reduce_sum(log_softmax * one_hot, axis=1)


まとめ

import numpy as np

import tensorflow as tf

# Graph inputs: per-class scores, per-class targets, and integer labels.
logits = tf.placeholder(tf.float32, [None, 2])
labels = tf.placeholder(tf.float32, [None, 2])
sparse_labels = tf.placeholder(tf.int32, [None])

# Use softmax function
# tf.nn.log_softmax is used instead of tf.log(tf.nn.softmax(logits)) because
# the latter yields -inf once a probability underflows to 0, and the
# following 0 * -inf product turns the loss into NaN.
log_softmax = tf.nn.log_softmax(logits)
output1 = -tf.reduce_sum(log_softmax * labels, axis=1)

# Use softmax_cross_entropy_with_logits function
output2 = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=labels,
    logits=logits)

# Use sparse_softmax_cross_entropy_with_logits function
output3 = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits,
    labels=sparse_labels)

# Use one_hot function
one_hot = tf.one_hot(sparse_labels, 2)
output4 = -tf.reduce_sum(log_softmax * one_hot, axis=1)

# Sample batch of two rows; `label` and `sparse_label` encode the same targets
# (class 0 for the first row, class 1 for the second).
logit = np.array([[0.7, 0.3],
                  [0.2, 0.8]])

label = np.array([[1, 0],
                  [0, 1]])

sparse_label = np.array([0,
                         1])

with tf.Session() as sess:
    # Evaluate everything in a single run. The one-hot result gets its own
    # name so the graph tensor `one_hot` is not overwritten by a NumPy array.
    y1, y2, y3, y4, one_hot_value = sess.run(
        [output1, output2, output3, output4, one_hot],
        feed_dict={
            logits: logit,
            labels: label,
            sparse_labels: sparse_label,
        },
    )

print("y1:", y1)
print("y2:", y2)
print("y3:", y3)
print("y4:", y4)


Reference

TensorFlowのクロスエントロピー関数の動作