Python
numpy

[Python] Numpyで多次元配列のargmaxをとりたい

使い方

PythonでNumpyを使っている時,多次元配列に対してargmaxargminを使いたい時があります.
例えば,

$$
\begin{bmatrix}
11 & 22 & 33 \\
77 & 88 & 99 \\
44 & 55 & 66
\end{bmatrix}
$$

であったら,99が位置している (1, 2) が欲しいということですね.
そんな時は,numpy.unravel_indexを使えば1行で取得することができるようです.

np.unravel_index(np.argmax(array), array.shape)
>> (1, 2)

こんなかんじ.

以下全体.

import numpy as np
array = np.array([
    [11, 22, 33],
    [77, 88, 99],
    [44, 55, 66]]
)

np.unravel_index(np.argmax(array), array.shape)
>> (1, 2)


解説

簡単に解説すると,np.argmax(array)は,arrayを1次元配列と解釈したときの最大値があるインデックスを返却します.今回であれば99は6つ目なので,5が返ってきます.ちなみに,shape(3, 3)ですね.

np.argmax(array)
>> 5

array.shape
>> (3, 3)

つぎにnumpy.unravel_index(indices, dims)の挙動ですが,ここでindecesは前から数えた時の数を表す整数で,dimsは次元を表すタプルです.
要するに,dimsを次元としてもつ行列で,前から数えてindeces番目のインデックスを「多次元で」返してくれます.

99は前から5番目なのでindeces5,行列は$3\times3$ですのでdims(3, 3)ですね.
日本語でいうと,

$3\times3$の行列において,頭から5番目の要素のインデックスはいくつですか?

ということです.

np.unravel_index(5, (3, 3))
>> (1, 2)

はい,(1, 2) ですね!


公式リファレンンス
numpy.unravel_index