お悩み
numpyで特定の列に対して条件を指定して行全体を操作したいことがある。
例えば、機械学習のモデルに画像を通すと、オブジェクトクラスの値がnumpyの中に入って返ってくるけれど、特定のクラスだけフィルタかけたい時とか。本筋としてはモデル側で欲しいクラスだけを返すようにしろって話なんだけど、それができない時の話。
下の出力例だと6番目の値がオブジェクトクラスなので、例えば車(2)だけの結果を集めたいとか、人間(0)だけを集めたい時に、6列目の値で検索して該当する行だけ集めたい。
[[5.39797363e+02 2.71583344e+02 5.93862732e+02 3.26177856e+02
8.78980994e-01 2.00000000e+00]
[7.61696655e+02 2.84762238e+02 7.84034058e+02 3.39492310e+02
8.69529247e-01 0.00000000e+00]
[8.06120178e+02 2.78944641e+02 8.31428406e+02 3.36827484e+02
8.61427248e-01 0.00000000e+00]
[4.77429657e+02 2.68967316e+02 5.18531067e+02 3.06501892e+02
8.59408557e-01 2.00000000e+00]
[2.50081268e+02 2.73935608e+02 2.67830475e+02 3.19194946e+02
(snip)
ソリューション
numpyにはこれに相当するコマンドはなさそう。一度調べたい行/列だけを対象としてnp.whereをかけて、該当する行/列のリストを取得して、その列を全体から抜きだす必要がある。
# python3
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import numpy as np
>>> a = np.arange(20).reshape((4, 5)) # 行列を作成
>>> a
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
>>> a[:,3] # 列3(4列目)を抜き出し
array([ 3, 8, 13, 18])
>>> np.where(a[:,3]%2) # 列3が奇数な行の番号を取得。タプルで帰ってくるから…
(array([0, 2]),)
>>> np.where(a[:,3]%2)[0] # 列3が奇数な行の番号のリストを取得。
array([0, 2])
>>> a[np.where(a[:,3]%2)[0], :] # 上で入手したリストを元にa全体から行を抜き出し
array([[ 0, 1, 2, 3, 4],
[10, 11, 12, 13, 14]])
洗練されたソリューション
コメントでもっと洗練させれた書き方を教えてもらった。
>>> a[:, 4]%2 == 1
array([False, True, False, True])
>>> a[a[:, 4]%2==1]
array([[ 5, 6, 7, 8, 9],
[15, 16, 17, 18, 19]])
>>>
しかし、なぜ奇数を抜き出すような例にしたのか今となっては思い出せない。
結論
というわけで最初に挙げた機械学習の評価から必要なクラスのデータだけ取り出すには以下のようにすれば良い。
outputs[np.where(abs(outputs[:,5]) == 0 )[0],:]
outputs[np.isclose(outputs[:, 5], 0)]
これで対象のクラスだけにフィルターできる。実行速度は知らん。もしかしたらループまわしても変わらないかも…
なんでabs()/iscloseしているか
absで絶対値を取っているのは、値がfloatなので単純な比較ではマッチしないせいです。
abs(outputs[:,5])
参考
[Python] 浮動小数点数floatの比較は要注意!!
Pythonで浮動小数点数floatの誤差を考慮して比較