More than 5 years have passed since last update.

Implementation examples of the softmax cross entropy function in TensorFlow


There are several ways to implement the softmax cross entropy function with TensorFlow.
This article summarizes four of them.
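For one-hot labels, all four implementations compute the same per-example quantity. For a logit vector z over C classes and a target distribution t (a one-hot row in the examples that follow), or an integer class label y in the sparse case, the loss is:

```latex
p_k = \frac{e^{z_k}}{\sum_{j=1}^{C} e^{z_j}}, \qquad
L(z, t) = -\sum_{k=1}^{C} t_k \log p_k
        = \log \sum_{j=1}^{C} e^{z_j} - \sum_{k=1}^{C} t_k z_k
        \quad \text{if } \textstyle\sum_{k} t_k = 1,
\\
L(z, y) = -\log p_y \ \text{ for an integer class label } y,
\qquad
z = (0.7,\ 0.3),\ y = 0:\ L = \log\!\left(1 + e^{-0.4}\right) \approx 0.5130 .
```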

Implementation

tf.nn.softmax()

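import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TensorFlow 2.x)

logits = tf.placeholder(tf.float32, [None, 2])  # raw scores, shape [batch, 2]
labels = tf.placeholder(tf.float32, [None, 2])  # one-hot labels, shape [batch, 2]

# Apply softmax, take the log, then sum over the class axis
softmax = tf.nn.softmax(logits)
output = -tf.reduce_sum(tf.log(softmax) * labels, axis=1)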
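The version above applies `tf.log` to the output of `tf.nn.softmax`. If a probability underflows to 0, its log is `-inf`, and the loss can become `inf` or `nan`. The following is a minimal pure-Python sketch (no TensorFlow; the helper names `naive_xent` and `stable_xent` are hypothetical) that contrasts softmax-then-log with a log-sum-exp formulation of log-softmax.

```python
import math


def softmax(z):
    # Subtract the max logit so exp() cannot overflow
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(z):
    # log softmax(z)_k = z_k - logsumexp(z); never takes the log of a probability
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


def naive_xent(z, t):
    # Same order of operations as -reduce_sum(log(softmax(z)) * t)
    return -sum(ti * math.log(p) for p, ti in zip(softmax(z), t))


def stable_xent(z, t):
    # Uses log-softmax directly instead of log(softmax)
    return -sum(ti * lp for lp, ti in zip(log_softmax(z), t))


print(naive_xent([0.7, 0.3], [1, 0]), stable_xent([0.7, 0.3], [1, 0]))

try:
    naive_xent([0.0, -1000.0], [0, 1])
except ValueError:
    # exp(-1000) underflows to 0.0, and Python's math.log(0.0) raises ValueError
    print("naive version failed: log(0)")

print(stable_xent([0.0, -1000.0], [0, 1]))  # 1000.0
```

For ordinary logits both versions agree; the difference only appears when one class probability underflows.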

tf.nn.softmax_cross_entropy_with_logits_v2()

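import tensorflow as tf  # TensorFlow 1.x API

logits = tf.placeholder(tf.float32, [None, 2])
labels = tf.placeholder(tf.float32, [None, 2])

# Takes raw logits (not softmax output); softmax and log are fused internally.
# In TensorFlow 2.x this op is named tf.nn.softmax_cross_entropy_with_logits.
output = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=labels,
    logits=logits)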
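`tf.nn.softmax_cross_entropy_with_logits_v2` expects unnormalized logits, and each row of `labels` should be a valid probability distribution, so soft targets work as well as one-hot rows. Below is a minimal pure-Python sketch of the per-row value, using the identity loss = logsumexp(z) - sum_k t_k * z_k. The function name `xent_with_logits` is hypothetical, and this is not TensorFlow's actual implementation.

```python
import math


def xent_with_logits(labels, logits):
    """Per-row softmax cross entropy computed from raw logits.

    Uses loss = logsumexp(z) - sum_k t_k * z_k, which equals
    -sum_k t_k * log(softmax(z)_k) when each label row sums to 1.
    """
    losses = []
    for t, z in zip(labels, logits):
        m = max(z)  # shift by the max logit for numerical stability
        lse = m + math.log(sum(math.exp(v - m) for v in z))
        losses.append(lse - sum(tk * zk for tk, zk in zip(t, z)))
    return losses


logits = [[0.7, 0.3], [0.2, 0.8]]
hard = [[1.0, 0.0], [0.0, 1.0]]  # one-hot targets
soft = [[0.9, 0.1], [0.5, 0.5]]  # soft targets; each row still sums to 1

print(xent_with_logits(hard, logits))  # approximately [0.5130, 0.4375]
print(xent_with_logits(soft, logits))
```

One practical point: unlike the older `softmax_cross_entropy_with_logits`, the `_v2` op also backpropagates into `labels`. If the labels should be treated as constants, pass them through `tf.stop_gradient` first.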

tf.nn.sparse_softmax_cross_entropy_with_logits()

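import tensorflow as tf  # TensorFlow 1.x API

logits = tf.placeholder(tf.float32, [None, 2])
sparse_labels = tf.placeholder(tf.int32, [None])  # class indices, shape [batch]

# Labels are integer class IDs instead of one-hot vectors
output = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits,
    labels=sparse_labels)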
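In the sparse variant, `labels` holds one integer class index per row (shape `[batch]`, dtype `int32` or `int64`) instead of a one-hot row. The loss therefore reduces to picking out the true class's log-probability. Below is a minimal pure-Python sketch. The function name is hypothetical, and the explicit range check is an assumption added for illustration, reflecting that indices must lie in `[0, num_classes)`.

```python
import math


def sparse_xent_with_logits(labels, logits):
    """Sketch: labels are integer class indices, one per row of logits."""
    losses = []
    for k, z in zip(labels, logits):
        if not 0 <= k < len(z):
            raise ValueError(f"label {k} out of range for {len(z)} classes")
        m = max(z)
        lse = m + math.log(sum(math.exp(v - m) for v in z))
        # Only the true class's logit contributes: -log softmax(z)[k]
        losses.append(lse - z[k])
    return losses


print(sparse_xent_with_logits([0, 1], [[0.7, 0.3], [0.2, 0.8]]))  # approximately [0.5130, 0.4375]
```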

tf.one_hot()

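import tensorflow as tf  # TensorFlow 1.x API

logits = tf.placeholder(tf.float32, [None, 2])
sparse_labels = tf.placeholder(tf.int32, [None])

# Convert class indices to one-hot vectors (depth = number of classes),
# then apply the same formula as the tf.nn.softmax() version
softmax = tf.nn.softmax(logits)
one_hot = tf.one_hot(sparse_labels, 2)
output = -tf.reduce_sum(tf.log(softmax) * one_hot, axis=1)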
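`tf.one_hot(indices, depth)` turns each index into a row of length `depth`, with 1 at that index and 0 elsewhere (the default `on_value` and `off_value`). This lets sparse labels reuse the dense formula. Below is a pure-Python sketch of that pipeline. The helper names are hypothetical, and unlike the snippet above it uses a stable log-softmax rather than `log(softmax)`.

```python
import math


def one_hot(indices, depth):
    # 1.0 at each index, 0.0 elsewhere (assumes 0 <= index < depth)
    return [[1.0 if j == i else 0.0 for j in range(depth)] for i in indices]


def log_softmax(z):
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


def xent_via_one_hot(sparse_labels, logits, depth):
    # Expand indices to one-hot rows, then apply -sum_k t_k * log p_k per row
    rows = one_hot(sparse_labels, depth)
    return [-sum(t * lp for t, lp in zip(row, log_softmax(z)))
            for row, z in zip(rows, logits)]


print(one_hot([0, 1], 2))  # [[1.0, 0.0], [0.0, 1.0]]
print(xent_via_one_hot([0, 1], [[0.7, 0.3], [0.2, 0.8]], 2))
```

The `depth` argument should equal the number of classes (the last dimension of `logits`, 2 here) so that each one-hot row lines up with its logits.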

Summary

import numpy as np
import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, 2])
labels = tf.placeholder(tf.float32, [None, 2])
sparse_labels = tf.placeholder(tf.int32, [None])

# Use softmax function
softmax = tf.nn.softmax(logits)
output1 = -tf.reduce_sum(tf.log(softmax) * labels, axis=1)

# Use softmax_cross_entropy_with_logits_v2 function (takes raw logits)
output2 = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=labels,
    logits=logits)

# Use sparse_softmax_cross_entropy_with_logits function
output3 = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits,
    labels=sparse_labels)

# Use one_hot function
one_hot = tf.one_hot(sparse_labels, 2)
output4 = -tf.reduce_sum(tf.log(softmax) * one_hot, axis=1)

logit = np.array([[0.7, 0.3],
                  [0.2, 0.8]])

label = np.array([[1, 0],
                  [0, 1]])

sparse_label = np.array([0,
                         1])

with tf.Session() as sess:
    feed = {logits: logit,
            labels: label,
            sparse_labels: sparse_label}
    y1, y2, y3, y4 = sess.run([output1, output2, output3, output4],
                              feed_dict=feed)

    # Evaluate the one-hot tensor without overwriting the one_hot graph node
    one_hot_value = sess.run(one_hot, feed_dict={sparse_labels: sparse_label})

    print("y1:", y1)
    print("y2:", y2)
    print("y3:", y3)
    print("y4:", y4)
    print("one_hot:", one_hot_value)
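Because the example has only two classes, the expected values can be checked by hand. For two classes, -log softmax(z)[k] = log(1 + exp(z_j - z_k)), where j is the other class. The sketch below computes that closed form in plain Python for the same inputs (the helper name `two_class_xent` is hypothetical). y1 through y4 should all agree with these values up to float32 rounding.

```python
import math

# Same inputs as the TensorFlow script above
logit = [[0.7, 0.3], [0.2, 0.8]]
sparse_label = [0, 1]


def two_class_xent(z, k):
    # For two classes: -log softmax(z)[k] = log(1 + exp(z_other - z_k))
    other = z[1 - k]
    return math.log1p(math.exp(other - z[k]))


expected = [two_class_xent(z, k) for z, k in zip(logit, sparse_label)]
print(expected)  # approximately [0.5130, 0.4375]
```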

Reference

Behavior of TensorFlow's cross-entropy functions (original title: TensorFlowのクロスエントロピー関数の動作)
