LoginSignup
0
0

Numpyで特定の列/行に条件を指定してマッチする行/列全体を抜きだしたい

Last updated at Posted at 2023-02-04
お悩み

numpyで特定の列に対して条件を指定して行全体を操作したいことがある。
例えば、機械学習のモデルに画像を通すと、オブジェクトクラスの値がnumpyの中に入って返ってくるけれど、特定のクラスだけフィルタかけたい時とか。本筋としてはモデル側で欲しいクラスだけを返すようにしろって話なんだけど、それができない時の話。

下の出力例だと6番目の値がオブジェクトクラスなので、例えば車(2)だけの結果を集めたいとか、人間(0)だけを集めたい時に、6列目の値で検索して該当する行だけ集めたい。

とある出力例
[[5.39797363e+02 2.71583344e+02 5.93862732e+02 3.26177856e+02
  8.78980994e-01 2.00000000e+00]
 [7.61696655e+02 2.84762238e+02 7.84034058e+02 3.39492310e+02
  8.69529247e-01 0.00000000e+00]
 [8.06120178e+02 2.78944641e+02 8.31428406e+02 3.36827484e+02
  8.61427248e-01 0.00000000e+00]
 [4.77429657e+02 2.68967316e+02 5.18531067e+02 3.06501892e+02
  8.59408557e-01 2.00000000e+00]
 [2.50081268e+02 2.73935608e+02 2.67830475e+02 3.19194946e+02
(snip)
ソリューション

numpyにはこれに相当するコマンドはなさそう。一度調べたい行/列だけを対象としてnp.whereをかけて、該当する行/列のリストを取得して、その列を全体から抜きだす必要がある。

4列目が奇数の行を抜き出す例
# python3
Python 3.9.5 (default, Nov 23 2021, 15:27:38) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import numpy as np
>>> a = np.arange(20).reshape((4, 5)) # 行列を作成
>>> a   
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])
>>> a[:,3]   # 列3(4列目)を抜き出し
array([ 3,  8, 13, 18])
>>> np.where(a[:,3]%2)  # 列3が奇数な行の番号を取得。タプルで帰ってくるから…
(array([0, 2]),)
>>> np.where(a[:,3]%2)[0] # 列3が奇数な行の番号のリストを取得。
array([0, 2])
>>> a[np.where(a[:,3]%2)[0], :] # 上で入手したリストを元にa全体から行を抜き出し
array([[ 0,  1,  2,  3,  4],
       [10, 11, 12, 13, 14]])
洗練されたソリューション

コメントでもっと洗練させれた書き方を教えてもらった。

4列目が奇数の行を抜き出す洗練された例
>>> a[:, 4]%2 == 1
array([False,  True, False,  True])
>>> a[a[:, 4]%2==1]
array([[ 5,  6,  7,  8,  9],
       [15, 16, 17, 18, 19]])
>>> 

しかし、なぜ奇数を抜き出すような例にしたのか今となっては思い出せない。

結論

というわけで最初に挙げた機械学習の評価から必要なクラスのデータだけ取り出すには以下のようにすれば良い。

6列目に入ったオブジェクトクラスが0のものだけを抜き出す例
 outputs[np.where(abs(outputs[:,5]) == 0 )[0],:]
6列目に入ったオブジェクトクラスが0のものだけを抜き出す洗練された例
 outputs[np.isclose(outputs[:, 5], 0)]

これで対象のクラスだけにフィルターできる。実行速度は知らん。もしかしたらループまわしても変わらないかも…

なんでabs()/iscloseしているか

absで絶対値を取っているのは、値がfloatなので単純な比較ではマッチしないせいです。

abs()
abs(outputs[:,5])

参考
[Python] 浮動小数点数floatの比較は要注意!!
Pythonで浮動小数点数floatの誤差を考慮して比較

0
0
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0