isin使えばおkと教えていただいたのでこの記事はobsoleteされました。ただしちょっとisinに関しての探究があります。
はじめに
あるNumPy配列
a = np.arange(10)
a = np.concatenate((a, a))
np.random.shuffle(a)
array([7, 5, 2, 3, 0, 3, 7, 9, 4, 0, 6, 5, 8, 8, 1, 1, 6, 4, 2, 9])
の中の値が1であるものの位置を示した配列1
a == 1
array([False, False, False, False, False, False, False, False, False,
False, False, False, False, False, True, True, False, False,
False, False])
ふと「これ、1の位置だけじゃなくて2や5の位置も取りたい場合どう書けばいいんだろ」と思いました。
(a == 1) | (a == 2) | (a == 5)
array([False, True, True, False, False, False, False, False, False,
False, False, True, False, False, True, True, False, False,
True, False])
と書けばいいわけですがダサいし、そもそも1と2と5は単なる例で「配列から探す対象を可変個指定」したいわけです。
ないので作りました
SQL的な感覚で言うとa in [1, 2, 5]
でいけるんじゃね?と期待するところですが、Pythonではこれは「リスト[1, 2, 5]にa(が指すもの)が含まれるか」という意味(リストの__contains__
メソッドが呼ばれる)なので解釈違いも甚だしいです。2
というわけで自分で作りました。
def sql_like_in(a, labels):
result = np.zeros_like(a, dtype=np.bool)
for v in labels:
result |= (a == v)
return result
sql_like_in(a, [1, 2, 5])
array([False, True, True, False, False, False, False, False, False,
False, False, True, False, False, True, True, False, False,
True, False])
速度計測
もちろん(?)実際に適用したい配列は20要素なんかじゃありません。もっとたくさんの要素で速度比較をしてみましょう。3
b = np.random.randint(0, 10, 60000)
まずはベタなやつ
%%timeit
(b == 1) | (b == 2) | (b == 5)
99.7 µs ± 298 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
作ったやつ
%%timeit
sql_like_in(b, [1, 2, 5])
108 µs ± 243 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
orが一個増える分遅いですね(誤差ですが)
labelsが1個以上渡されるという前提が必要ですが以下のようにしてみました。
def sql_like_in2(a, labels):
result = (a == labels[0])
for v in labels[1:]:
result |= (a == v)
return result
まあベタなやつを上回ることはないです。
%%timeit
sql_like_in2(b, [1, 2, 5])
99 µs ± 169 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
余談
以上、配列の中から探す値を可変個指定したい、けどないから作ったというお話でした。
実際にはNumPy配列ではなくTorchテンソルに対して使いたかったのですが同じような操作はできるようなのでPyTorchの場合にも適用できると思います。
追記:isin
コメントで「isinってのがありますよ」と教えていただきました。pandasでは使ったことあるのになぜ気づかなかった。
個人的に、探す値ごとに配列ができるのが気に食わなかったのでよっしこれで極限まで速くなる!とtimeitしてみる。。。
%%timeit
np.isin(b, [1, 2, 5])
115 µs ± 527 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
変わんねえじゃん。
というわけでソースを見てみます。本体はin1dでその中に
if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:
if invert:
mask = np.ones(len(ar1), dtype=bool)
for a in ar2:
mask &= (ar1 != a)
else:
mask = np.zeros(len(ar1), dtype=bool)
for a in ar2:
mask |= (ar1 == a)
return mask
マジックナンバーぇと言ったところですが大体7乗根ぐらいですかね。念のため今回のケース(ar1が60000)で計算してみると49.29829863311494でした。
結局自力でやる場合と速度的には変わらない(やってること同じなのだから当たり前)ということですね。まあ自分で書かなくてもよくなるしinvert(not演算)も使えたりしてよりよいですが。
なお、上記の条件が成立しない場合(ar1がそれほど大きくない場合)は魔法のようなことが行われてますがやってることとしては以下になります。
- 「値があるか調べたい配列」と「調べる値の配列」をuniqueしたうえでconcatする
- concatした配列をソートし「値が隣接するか」を調べる
- 元の「値があるか調べたい配列」に対応するboolean配列にTrue/Falseを設定する(これがかなり魔法)
これ速いのだろうか。ar1が小さくても素直にやればいいような。。。