備忘録。
通常のargmaxを使用すると、singleのindexしか取得する事が出来ない。
そこで、argmaxになるような複数のindexを取得出来るようにする。
import tensorflow as tf
a = tf.constant(np.array([1,2,9,2,1,9]))
sess = tf.Session()
sess.run(tf.squeeze(tf.where(tf.equal(a,tf.reduce_max(a,0)))))
簡単に解説
tf.reduce_maxでarray内のmaxを取得する
# 9
tf.equalで、array内のmaxの値とarrayの値をそれぞれ比較する([False,False,True,False,False,True]になる)
# array([False, False, True, False, False, True], dtype=bool)
tf.whereで、Trueであるindicesを取得する
# array([[2],[5]])
tf.squeezeで、dimensionを減らす
# array([2, 5])
ちなみにnumpyだと
import numpy as np
a = np.array([1,2,9,2,1,9])
print np.squeeze(np.where(np.equal(a,np.max(a,0))))