LoginSignup
5
4

More than 5 years have passed since last update.

tensorflowでargmaxになる複数のインデックスを取得する

Last updated at Posted at 2016-05-13


備忘録。
通常のargmaxを使用すると、singleのindexしか取得する事が出来ない。
そこで、argmaxになるような複数のindexを取得出来るようにする。

import tensorflow as tf
a = tf.constant(np.array([1,2,9,2,1,9]))

sess = tf.Session()
sess.run(tf.squeeze(tf.where(tf.equal(a,tf.reduce_max(a,0)))))

簡単に解説
tf.reduce_maxでarray内のmaxを取得する
# 9

tf.equalで、array内のmaxの値とarrayの値をそれぞれ比較する([False,False,True,False,False,True]になる)
# array([False, False, True, False, False, True], dtype=bool)

tf.whereで、Trueであるindicesを取得する
# array([[2],[5]])

tf.squeezeで、dimensionを減らす
# array([2, 5])

ちなみにnumpyだと

import numpy as np
a = np.array([1,2,9,2,1,9])
print np.squeeze(np.where(np.equal(a,np.max(a,0))))

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4