目的
「ゼロから作る Deep Learning」 を読んでいて、argmaxの振る舞いが直感的に理解できなかったので整理しました。
argmax とは
最大値の"index"を取得します。
第一引数が対象の配列、そして理解しにくいのが第二引数です。
公式を見てもあまりよくわからず。
一般的な使い方: np.argmax(data)
これはdataを渡せば、その中の最大値のインデックスを取得します。
例えば、
data = np.array([[1, 2, 3],[3,4,5]])
の2×3の行列だと、結果はインデックス5が返ります。
axis=0の指定: np.argmax(data, axis=0)
これは「列指定」になります。
data = np.array([[1, 2, 3],[3,4,5]])
の2×3の行列だと、列で見ます。
[1,2,3] [3,4,5]
だと、[1,3], [2,4], [3,5]の組み合わせで見るので、結果は[1,1,1]のインデックスが返ります。
axis=1の指定: np.argmax(data, axis=1)
これは「行指定」になります。
data = np.array([[1, 2, 3],[3,4,5]])
の2×3の行列だと、行で見ます。
[1,2,3] [3,4,5]
だと、[1,2,3], [3,4,5]の組み合わせで見るので、結果は[2,2]のインデックスが返ります。
ポイント
- axisの指定(0がどっちだっけ..)になりやすいこと → 定数で記録しておくとよい
- その場合の振る舞いがわかりにくい → 簡単な例で振る舞いを理解しておけば複雑になっても理解しやすい
だと思います。