はじめに
入力配列にある最も近い値を取得するコードはこれにあるように差分が最小になる値を取得すれば良い。だが、探索したい値が配列の場合にはどうするのだろうと思い考えてみた。
もちろんここにある関数をリスト内包表記で計算すればよいのだが、ここではあくまでnumpy
で行う事を考えてみた。
コード
import numpy as np
def search_nearestvalues(a, v):
ids = np.argmin(((np.tile(a, (len(v), 1))) - v[:, np.newaxis])**2, axis=1)
return a[ids]
np.random.seed(0)
a = 10*np.random.rand(10)
v = np.array([8, 10, 3])
print(search_nearestvalues(a, v)) # [7.15189366 9.63662761 3.83441519]
解説
ポイントはv[:, np.newaxis]
でv
を(3)から(3, 1)の配列にして差分を取ることで各値との差を求めることができる。もちろんnp.rehape(-1, 1)
でも可能である。ちなみにnp.newaxis
はNone
のエイリアスみたい。
締め
配列をnumpy.tile
関数で複製して探索しているのでメモリに余裕があるときしか動かないからあまり賢い方法には思えないかな。他に良い方法があれば(あるはず)コメント頂けたら幸いです。
追記
@nkayさんに、np.tileはなしでも自動でブロードキャストしてくれると教えていただきました。また、2乗よりも絶対値のほうが速いみたいです。ありがとうございます!
def search_nearestvalues(a, v):
ids = np.argmin(abs(a - v[:, None]), axis=1)
return a[ids]