使い方
PythonでNumpyを使っている時,多次元配列に対してargmax
やargmin
を使いたい時があります.
例えば,
$$
\begin{bmatrix}
11 & 22 & 33 \\
77 & 88 & 99 \
44 & 55 & 66
\end{bmatrix}
$$
であったら,99が位置している (1, 2)
が欲しいということですね.
そんな時は,numpy.unravel_index
を使えば1行で取得することができるようです.
np.unravel_index(np.argmax(array), array.shape)
>> (1, 2)
こんなかんじ.
以下全体.
import numpy as np
array = np.array([
[11, 22, 33],
[77, 88, 99],
[44, 55, 66]]
)
np.unravel_index(np.argmax(array), array.shape)
>> (1, 2)
## 解説
簡単に解説すると,np.argmax(array)
は,array
を1次元配列と解釈したときの最大値があるインデックスを返却します.今回であれば99は6つ目なので,5
が返ってきます.ちなみに,shape
は(3, 3)
ですね.
np.argmax(array)
>> 5
array.shape
>> (3, 3)
つぎにnumpy.unravel_index(indices, dims)
の挙動ですが,ここでindeces
は前から数えた時の数を表す整数で,dims
は次元を表すタプルです.
要するに,dims
を次元としてもつ行列で,前から数えてindeces
番目のインデックスを「多次元で」返してくれます.
99は前から5番目なのでindeces
は5
,行列は$3\times3$ですのでdims
は(3, 3)
ですね.
日本語でいうと,
$3\times3$の行列において,頭から5番目の要素のインデックスはいくつですか?
ということです.
np.unravel_index(5, (3, 3))
>> (1, 2)
はい,(1, 2)
ですね!
> 公式リファレンンス > [numpy.unravel_index](https://docs.scipy.org/doc/numpy-1.15.0/reference/generated/numpy.unravel_index.html)