yおよびy_hatのサンプル数(shapeの None
)= 3, config.N_CLASSES
= 4 として実験してみました。
y_hatについて、argmaxを取ったときにshapeが "潰れて" しまいますね。すなわち、@leooon さんは(None, 15587)
-> (None, 1)
になることを期待しているのだと思いますが、実際は (None)
になってしまっています。
Python 3.9.10 (main, Feb 4 2022, 11:07:01)
[Clang 13.0.0 (clang-1300.0.29.30)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> tf.__version__
'2.8.0'
>>> n_classes = 4
>>> y_hat = tf.constant([[0, 2, 3, 1], [4, 7, 6, 5], [11, 8, 9, 10]])
>>> y_hat.shape
TensorShape([3, 4])
>>> tf.math.argmax(y_hat, axis=1).shape
TensorShape([3])
>>> tf.one_hot(tf.math.argmax(y_hat, axis=1), n_classes).shape
TensorShape([3, 4])
yの方は想定通りですね。
>>> y = tf.constant([[3], [0], [1]])
>>> y.shape
TensorShape([3, 1])
>>> tf.one_hot(y, n_classes).shape
TensorShape([3, 1, 4])
解決策として例えば、潰れてしまったaxisを増やしてやることが考えられます。
>>> pred = tf.math.argmax(y_hat, axis=1)
>>> pred.shape
TensorShape([3])
>>> pred = tf.expand_dims(pred, axis=-1)
>>> pred.shape
TensorShape([3, 1])
>>> tf.one_hot(pred, n_classes).shape
TensorShape([3, 1, 4])
numpyだとkeepdims
というオプション引数があり、出力にてaxisが潰れないようにできるのですが、tf.math.argmaxには同様のオプションがなさそうでした。
逆に、yのaxis=1を潰してしまうのもいいですね。
>>> y.shape
TensorShape([3, 1])
>>> y = tf.squeeze(y, axis=1)
>>> y.shape
TensorShape([3])
>>> tf.one_hot(y, n_classes).shape
TensorShape([3, 4])