概要
numpyを使っていると任意の格子点の要素を抜き出して配列を作りたい場合がある。そんなときに便利なのが、numpy.ix_
である。このnumpy.ix_
を使った方法を紹介する。
実装
Google Colabで作成した本記事のコードは、こちらにあります。
np.ix_
の入力は、1次元のシーケンスで各シーケンスは整数型またはbool型である必要がある。bool型シーケンスの場合は、np.nonzero(boolean_sequence)
に等しく、つまり、シーケンスのTrue
に対応したindexの格子点が抜きだされる。
np.ix_
の出力は、N次元の配列で、Nは入力シーケンスの数に対応する。
使い方は公式ドキュメントによると
Using ix_ one can quickly construct index arrays that will index the cross product. a[np.ix_([1,3],[2,5])] returns the array [[a[1,2] a[1,5]], [a[3,2] a[3,5]]].
NumPy numpy.ix_ version 1.23より引用
と書かれていて、外積をとった配列のindexのリストを返すことで、任意の格子点のindexの配列を作ることができる。
整数型の例
視覚的にも分かりやすくするために、抜き出したindexを赤色でプロットし2次元ヒートマップに表示する。左図がarray
で、右図が左図の赤色のプロットしたindexを抜き出して配列にしたmesh_aray
である。
import numpy as np
import matplotlib.pyplot as plt
array = np.arange(6*10).reshape(6, 10)
x = [0, 1, 3, 5, 9]
y = [0, 1, 4, 5]
mesh_array = array[np.ix_(y, x)]
print('array:\n', array)
print('mesh_array:\n', mesh_array)
"""2次元ヒートマップで表示"""
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 12))
ax1.imshow(array, vmin=np.min(array), vmax=np.max(array))
ax1.set_title('array')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
# 格子点をプロット
for x_elem in x:
for y_elem in y:
ax1.scatter(x_elem, y_elem, c='r')
ax2.imshow(mesh_array, vmin=np.min(array), vmax=np.max(array))
ax2.set_title('mesh_array')
ax2.set_xlabel('x')
ax2.set_ylabel('y')
plt.show()
array:
[[ 0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]
[50 51 52 53 54 55 56 57 58 59]]
mesh_array:
[[ 0 1 3 5 9]
[10 11 13 15 19]
[40 41 43 45 49]
[50 51 53 55 59]]
bool型の例
bool型でも整数型の例と同様の結果を得ることができるので紹介する。
import numpy as np
import matplotlib.pyplot as plt
array = np.arange(6*10).reshape(6, 10)
x = [True, True, False, True, False, True, False, False, False, True] # 変更箇所
# np.nonzero(x) -> (array([0, 1, 3, 5, 9]),)
y = [True, True, False, False, True, True] # 変更箇所
# np.nonzero(y) -> (array([0, 1, 4, 5]),)
mesh_array = array[np.ix_(y, x)]
print('array:\n', array)
print('mesh_array:\n', mesh_array)
"""2次元ヒートマップで表示"""
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 12))
ax1.imshow(array, vmin=np.min(array), vmax=np.max(array))
ax1.set_title('array')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
# 格子点をプロット
for x_elem in np.nonzero(x)[0]:
for y_elem in np.nonzero(y)[0]:
ax1.scatter(x_elem, y_elem, c='r')
ax2.imshow(mesh_array, vmin=np.min(array), vmax=np.max(array))
ax2.set_title('mesh_array')
ax2.set_xlabel('x')
ax2.set_ylabel('y')
plt.show()
出力結果は、整数型の例と同様である。
まとめ
numpy.ix_
の使い方を紹介しました。結構マイナーだと思いますが、使い所によっては便利なので是非取り入れて快適なnumpy環境を作っていきましょう。
参考資料