1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【numpy】np.ix_を使った任意の格子点の要素を抜き出して配列を作成する方法

Posted at

概要

numpyを使っていると任意の格子点の要素を抜き出して配列を作りたい場合がある。そんなときに便利なのが、numpy.ix_である。このnumpy.ix_を使った方法を紹介する。

実装

Google Colabで作成した本記事のコードは、こちらにあります。

np.ix_の入力は、1次元のシーケンスで各シーケンスは整数型またはbool型である必要がある。bool型シーケンスの場合は、np.nonzero(boolean_sequence)に等しく、つまり、シーケンスのTrueに対応したindexの格子点が抜きだされる。
np.ix_の出力は、N次元の配列で、Nは入力シーケンスの数に対応する。
使い方は公式ドキュメントによると

Using ix_ one can quickly construct index arrays that will index the cross product. a[np.ix_([1,3],[2,5])] returns the array [[a[1,2] a[1,5]], [a[3,2] a[3,5]]].
NumPy numpy.ix_ version 1.23より引用

と書かれていて、外積をとった配列のindexのリストを返すことで、任意の格子点のindexの配列を作ることができる。

整数型の例

視覚的にも分かりやすくするために、抜き出したindexを赤色でプロットし2次元ヒートマップに表示する。左図がarrayで、右図が左図の赤色のプロットしたindexを抜き出して配列にしたmesh_arayである。

整数型の例
import numpy as np
import matplotlib.pyplot as plt

array = np.arange(6*10).reshape(6, 10)
x = [0, 1, 3, 5, 9]
y = [0, 1, 4, 5]
mesh_array = array[np.ix_(y, x)]

print('array:\n', array)
print('mesh_array:\n', mesh_array)

"""2次元ヒートマップで表示"""
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 12))

ax1.imshow(array, vmin=np.min(array), vmax=np.max(array))
ax1.set_title('array')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
# 格子点をプロット
for x_elem in x:
    for y_elem in y:
        ax1.scatter(x_elem, y_elem, c='r')

ax2.imshow(mesh_array, vmin=np.min(array), vmax=np.max(array))
ax2.set_title('mesh_array')
ax2.set_xlabel('x')
ax2.set_ylabel('y')

plt.show()
出力結果
array:
 [[ 0  1  2  3  4  5  6  7  8  9]
 [10 11 12 13 14 15 16 17 18 19]
 [20 21 22 23 24 25 26 27 28 29]
 [30 31 32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47 48 49]
 [50 51 52 53 54 55 56 57 58 59]]
mesh_array:
 [[ 0  1  3  5  9]
 [10 11 13 15 19]
 [40 41 43 45 49]
 [50 51 53 55 59]]

出力結果
image.png

bool型の例

bool型でも整数型の例と同様の結果を得ることができるので紹介する。

bool型の例
import numpy as np
import matplotlib.pyplot as plt

array = np.arange(6*10).reshape(6, 10)
x = [True, True, False, True, False, True, False, False, False, True]  # 変更箇所
# np.nonzero(x) -> (array([0, 1, 3, 5, 9]),)
y = [True, True, False, False, True, True]  # 変更箇所
# np.nonzero(y) -> (array([0, 1, 4, 5]),)
mesh_array = array[np.ix_(y, x)]

print('array:\n', array)
print('mesh_array:\n', mesh_array)

"""2次元ヒートマップで表示"""
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 12))

ax1.imshow(array, vmin=np.min(array), vmax=np.max(array))
ax1.set_title('array')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
# 格子点をプロット
for x_elem in np.nonzero(x)[0]:
    for y_elem in np.nonzero(y)[0]:
        ax1.scatter(x_elem, y_elem, c='r')

ax2.imshow(mesh_array, vmin=np.min(array), vmax=np.max(array))
ax2.set_title('mesh_array')
ax2.set_xlabel('x')
ax2.set_ylabel('y')

plt.show()

出力結果は、整数型の例と同様である。

まとめ

numpy.ix_の使い方を紹介しました。結構マイナーだと思いますが、使い所によっては便利なので是非取り入れて快適なnumpy環境を作っていきましょう。

参考資料

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?