同じ値が連続する場合に除去したい
ある経路に対して垂直方向を計算したいときに、同じ座標が連続していて、
垂直方向が計算できず。
連続した値を除去する方法、検索方法が悪いのか、いいのがなかなか見つけられず(Pandasはあったけど)
Numbaなし版
def remove_dpl_no_numba(x):
return x[np.append(True, np.diff(x, axis=0).sum(axis=1) != 0)]
Numbaあり版
from numba import jit, f8
@jit(f8[:,:](f8[:,:]))
def remove_dpl(x):
d = x[1:, :] - x[:-1, :]
s = d.sum(axis=1)
f = np.append(True, s != 0)
return x[f]
numpy.diffのaxisがNumbaで使えないので、式に展開(はじめ何事かと思った)
比較してみた
リストの連続して重複する値を削除 と比較してみた
from itertools import groupby
hoge1 = np.c_[x, y]
hoge2 = hoge1.tolist()
%timeit remove_dpl1(hoge2)
%timeit remove_dpl2(hoge2)
%timeit remove_dpl3(hoge2)
%timeit [next(g) for _,g in groupby(hoge2)]
%timeit remove_dpl(hoge1)
%timeit remove_dpl_no_numba(hoge1)
リスト版はじめ30µsだったのに、何回か実行しているとなぜか20µsになってきた。
とは言え、元データはnumpyなので、tolist分は遅いし。
リスト版でnumba.jitできればよかったけど、reflected listとかよくわからん。
まあ、そこまで速度気にする必要はないんだけど、for文使いたくなかったので・・・
折角for文使わなかったのがどうなのか調べたく・・・
19.9 µs ± 167 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
20.4 µs ± 410 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
27.7 µs ± 191 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
29.7 µs ± 390 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
3.86 µs ± 67.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
18 µs ± 378 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)