2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

NumPyの数列から、隣り合う2項の平均値の数列を計算する最も速い方法

Last updated at Posted at 2021-06-07

NumPyの数列から、隣り合う2項の平均値の数列を計算する最も速い方法

数列 $[a_1, a_2, a_3, ..., a_i]$ から、数列 $[\frac{a_1+a_2}{2}, \frac{a_2+a_3}{2}, \frac{a_3+a_4}{2}, ..., \frac{a_{i-1}+a_i}{2}]$ を得たいとする。

In [1]: import numpy as np

In [2]: a = np.random.default_rng(0).integers(0, 10, 10)
   ...: a
Out[2]: array([8, 6, 5, 2, 3, 0, 0, 0, 1, 8], dtype=int64)

In [3]: func_hoge(a)
Out[3]: array([7. , 5.5, 3.5, 2.5, 1.5, 0. , 0. , 0.5, 4.5])

どのようにすればよいだろうか。

動機

Qiitaの『Pythonでデータの挙動を見やすくする可視化ツールを作成してみた(ヒストグラム・確率分布編)』という記事を見て、以下のコードが気になった。

# フィッティングの残差平方和を計算 (参考https://rmizutaa.hatenablog.com/entry/2020/02/24/191312)
hist_y, hist_x = np.histogram(x, bins=20, density=True)  # ヒストグラム化して標準化
hist_x = (hist_x + np.roll(hist_x, -1))[:-1] / 2.0  # ヒストグラムのxの値をビンの左端→中央に移動

この部分は、コメント行にある通り、『あてはまりのよい確率分布を探したい - rmizutaの日記』に掲載されている以下のコードの引用のようである。

y, x = np.histogram(data.iloc[:,1], bins=20, density=True)
#xの値はヒストグラムの左端の値なので中心点に修正
x = (x + np.roll(x, -1))[:-1] / 2.0

一般的NumPyユーザーとしての直感では、np.roll()というのは非常に遅い、つまり(大した処理じゃないのに)やたらと時間のかかる鈍い関数である。しかも、特に、ここで行っている処理は、配列における隣り合った2項の平均をとるというものであるから、np.roll()を使う必然性を感じないので、このコードにはひどく違和感を覚えた。

方法

この処理は、以下のようにスライシングして行うのが、スタンダードだと思う。

# xの値はヒストグラムの左端の値なので中心点に修正
x = (x[1:] + x[:-1]) / 2

NumPyのヘビーユーザーならnp.correlate()/np.convolve()を用いる方法も考えられると思う。あまり知られていないが、一般的NumPyユーザーとしての直感では、この関数はとても速い有能な関数である。

x = np.correlate(x, [.5, .5], 'valid')

今回、先に引用したコードの場合、xはビンのエッジ配列だから、等差数列である。冒頭で述べた「数列 $[a_1, a_2, a_3, ..., a_i]$ から、数列 $[\frac{a_1+a_2}{2}, \frac{a_2+a_3}{2}, \frac{a_3+a_4}{2}, ..., \frac{a_{n-1}+a_n}{2}]$ を得たい」という目的からは外れてしまうが、今回のような条件のときにのみ適用されるよりよい方法を考えると、公差の半分を配列に足せばいいのだと気づく。

x = x[:-1] + (x[1] - x[0]) / 2

似た要領で、np.linspace()を使って1から作るというやり方でもいける。

x = np.linspace(x[:2].mean(), x[-2:].mean(), x.size-1)

(2021/06/16追加)
スライドウィンドウを使った方法も考えられる。

x = np.lib.stride_tricks.sliding_window_view(x, (2,)).mean(1)

(2021/06/16追加)
純粋なforループをnumbaで加速させる方法も考えられる。

@njit
def use_numba(x):
    s = x.size
    out = np.empty(s-1)
    for i in range(s):
        out[i] = (x[i] + x[i+1]) / 2
    return out

x = use_numba(x)

いろいろな方法をあげたが、結局どれが速いのか。どうやらもとの配列xの長さによって変わるようである。

計測してみた

比較するコードは以下の通り。

import numpy as np
from numba import njit


def np_roll(x):
    return (x + np.roll(x, -1))[:-1] / 2


def slice_mean(x):
    return (x[1:] + x[:-1]) / 2


def np_correlate(x):
    return np.correlate(x, [.5, .5], 'valid')


def diff_add(x):
    return x[:-1] + (x[1] - x[0]) / 2


def np_linspace(x):
    return np.linspace(x[:2].mean(), x[-2:].mean(), x.size-1)


# 以下2021/06/16追加
def window(x):
    return np.lib.stride_tricks.sliding_window_view(x, (2,)).mean(1)


@njit
def use_numba(x):
    s = x.size
    out = np.empty(s-1)
    for i in range(s):
        out[i] = (x[i] + x[i+1]) / 2
    return out

benchitを用いて計測する。

import benchit

funcs = [np_roll, slice_mean, np_correlate, diff_add, np_linspace, window, use_numba]
inputs = [np.linspace(1, n/9, n) for n in 10 ** np.arange(1, 8)]  # <- xの中身は適当

# すべての関数が同じ結果を返すことを確認
all(np.allclose(funcs[i](inputs[-1]), funcs[i+1](inputs[-1])) for i in range(len(funcs)-1))
# -> True

t = benchit.timings(funcs, inputs)
t.plot(figsize=(8, 8), logx=True)
print(t)

1.png

t
Functions   np_roll  slice_mean  np_correlate  diff_add  np_linspace    window     use_numba
Len                                                                                         
10         0.000020    0.000002      0.000003  0.000002     0.000056  0.000045  6.622680e-07
100        0.000021    0.000002      0.000003  0.000002     0.000056  0.000047  6.986812e-07
1000       0.000023    0.000004      0.000004  0.000003     0.000060  0.000062  8.710666e-07
10000      0.000039    0.000014      0.000011  0.000007     0.000072  0.000178  3.680266e-06
100000     0.000182    0.000115      0.000075  0.000043     0.000187  0.001326  3.946866e-05
1000000    0.010731    0.007300      0.003628  0.003622     0.005245  0.015443  3.497354e-03
10000000   0.118733    0.077196      0.038641  0.037447     0.053853  0.152398  3.715267e-02

配列xの長さにかかわらず、公差を求めて配列に足す方法が最も速いことがわかる。ただし、これはxが等差数列のときにのみ使える手である。
それを除外すると、xの長さが3000未満程度ならば、スタンダードと呼んだスライシングして計算する方法がやはり速い。xの長さが3000以上の場合はnp.correlate()を用いた方法が他を圧倒しており、なんと、xの長さが100万になると「xが等差数列のときにのみ使える裏技」をも凌駕している。対して、np.roll()を用いた方法はきわめて遅い。


おおむね予想通りの結果となったが、np.correlate()の有能さにはもっと注目を浴びせたい。そしてやはり、こんなときにnp.roll()は使うべきではない。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?