1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

NumPyで大きな配列に対してPiece-wiseな関数を効率良く実行する方法

Last updated at Posted at 2025-07-17

モチベ

NumPyでフィルタリング操作をすることはよくあると思いますが,あるMECEな条件によって配列を分割し,分割された先で異なる操作をしたいときに高速な方法は何か調査したいと思います.

今回の例では,以下のような関数に対して,ある大きな1次元配列 A が与えられたときに,高速に f(A) を計算できる方法を考えます.また,条件の分割結果として,A はbinに分割されるものとします.つまり,各条件は A の要素がある有界閉区間に属するというものに絞ります.

例えば,以下のような操作が今回の記事の対象になります.


def f(x: float) -> float:
    assert x >= 0
    if x < 0.25:
        return x
    elif x < 0.5:
        return x**2
    elif x < 0.75:
        return x**3
    else:
        return x**4

方法論

以下,配列のサイズを $N$, 関数の piece数 (binサイズと等価) をKとします.
高速化の余地があるのは以下の2箇所:

  • A のフィルタリング方法
  • A のフィルタリング作成方法

まず,フィルタリング方針としては以下の3つが考えられます.

  • bool配列: A と同じサイズの bool 配列を $K$ 個用意して,各binに対して bool 配列を利用した抜き出しをする.
  • (部分)ソート + 二分探索スライシング: A をソートして,ソートされた A に対して二分探索で各binの始点と終点を求め,これによってスライシングする.
  • bin index配列: A の各要素に対応する bin のindex配列を作成し,index配列を利用して抜き出す.
方針 長所 短所
bool配列 理解しやすい. 抜き出す度にN要素の走査を行う.
(部分)ソート + 二分探索スライシング 抜き出しはbinサイズ要素の走査で済む. ソートの計算量が多い.最後にもとの順番に戻す走査が必要になる.
bin index配列 抜き出しはbinサイズ要素の走査で済む. 少し読みづらいかも?

上記の表からわかるように,基本的にはindex配列もしくはスライシングができる方が良いです.
index配列に関しては,np.nonzero (np.whereは非推奨) によって取得できます.
また,NumPy Backendとの通信回数が少なくなるような方針を取ると高速になることが考えられます.

bool配列とソート+二分探索に関しては作成が比較的直観的なので省略して,bin index配列作成方法を考えます.

  • np.digitize を利用する方法: $O(N \log K)$
  • A に対して,直接比較演算を適用し,np.count_nonzeroで集計: $O(NK)$

今回は $K$ の値が小さいので,おそらく後者が速くなることが考えられます.
また,with_count_nonzero_and_vectorized_whereにあるように,bin index配列を作成した後のindex配列の作成に加えて,直接比較 + np.nonzero + 二分探索を駆使すると,もとの配列 A をbinに分割した部分ソートを実装できます.
ただし,今回の実験結果によると,部分ソートを利用するよりも,直接index配列をK回作成するほうが実行時間が短くなるようです.

実験

少なくとも自分の環境では with_count_nonzero_and_individual_where が最速でした.
NumPy標準の np.digitize を利用するよりは vectorize で比較演算をまとめて,np.count_nonzero を利用する方法が高速なようです.
ちなみに np.digitize は内部で np.searchsorted(bins, x) をしているため,計算量が $O(N \log K)$ となりますが,np.count_nonzero を利用する方法では O(NK) となるため,大きな K に対しては np.digitize が高速になる可能性が高いです (参考).

import numpy as np


ReturnType = tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]


def with_digitized(a: np.ndarray) -> ReturnType:
    assert len(a.shape) == 1
    bin_inds = np.digitize(a, [0, 0.25, 0.5, 0.75, 1])
    out = np.empty_like(a)
    if (target_inds := np.nonzero(bin_inds == 1)[0]).size:
        out[target_inds] = a[target_inds]
    if (target_inds := np.nonzero(bin_inds == 2)[0]).size:
        out[target_inds] = a[target_inds] ** 2
    if (target_inds := np.nonzero(bin_inds == 3)[0]).size:
        out[target_inds] = a[target_inds] ** 3
    if (target_inds := np.nonzero(bin_inds == 4)[0]).size:
        out[target_inds] = a[target_inds] ** 4

    return out


def with_argsort(a: np.ndarray) -> ReturnType:
    assert len(a.shape) == 1
    order = np.argsort(a)
    a_order = a[order]
    boundary_inds = np.searchsorted(a_order, [0.25, 0.5, 0.75])
    out = np.empty_like(a)
    if boundary_inds[0] > 0:
        out[:boundary_inds[0]] = a_order[:boundary_inds[0]]
    if boundary_inds[1] - boundary_inds[0] > 0:
        out[boundary_inds[0]:boundary_inds[1]] = a_order[boundary_inds[0]:boundary_inds[1]] ** 2
    if boundary_inds[2] - boundary_inds[1] > 0:
        out[boundary_inds[1]:boundary_inds[2]] = a_order[boundary_inds[1]:boundary_inds[2]] ** 3
    if a.size - boundary_inds[2] > 0:
        out[boundary_inds[2]:] = a_order[boundary_inds[2]:] ** 4
    inv = np.empty_like(order)
    inv[order] = np.arange(order.size)
    return out[inv]


def with_count_nonzero_and_individual_where(a: np.ndarray) -> ReturnType:
    assert len(a.shape) == 1
    bin_inds = np.count_nonzero(a >= [[0.25], [0.5], [0.75]], axis=0)
    out = np.empty_like(a)
    if (target_inds := np.nonzero(bin_inds == 0)[0]).size:
        out[target_inds] = a[target_inds]
    if (target_inds := np.nonzero(bin_inds == 1)[0]).size:
        out[target_inds] = a[target_inds] ** 2
    if (target_inds := np.nonzero(bin_inds == 2)[0]).size:
        out[target_inds] = a[target_inds] ** 3
    if (target_inds := np.nonzero(bin_inds == 3)[0]).size:
        out[target_inds] = a[target_inds] ** 4
    return out


def with_count_nonzero_and_individual_bool_array(a: np.ndarray) -> ReturnType:
    assert len(a.shape) == 1
    bin_inds = np.count_nonzero(a >= [[0.25], [0.5], [0.75]], axis=0)
    out = np.empty_like(a)
    case1 = bin_inds == 0
    case2 = bin_inds == 1
    case3 = bin_inds == 2
    case4 = bin_inds == 3
    if (a1 := a[case1]).size:
        out[case1] = a1
    if (a2 := a[case2]).size:
        out[case2] = a2 ** 2
    if (a3 := a[case3]).size:
        out[case3] = a3 ** 3
    if (a4 := a[case4]).size:
        out[case4] = a4 ** 4
    return out


def with_count_nonzero_and_vectorized_where(a: np.ndarray) -> ReturnType:
    assert len(a.shape) == 1
    bin_inds = np.count_nonzero(a >= [[0.25], [0.5], [0.75]], axis=0)
    target_bin_inds, inds_in_each_bin = np.nonzero(bin_inds == [[i] for i in range(4)])
    boundary_inds = np.searchsorted(target_bin_inds, list(range(1, 4)))
    a_part = a[inds_in_each_bin]
    out = np.empty_like(a)
    if boundary_inds[0] > 0:
        out[:boundary_inds[0]] = a_part[:boundary_inds[0]]
    if boundary_inds[1] - boundary_inds[0] > 0:
        out[boundary_inds[0]:boundary_inds[1]] = a_part[boundary_inds[0]:boundary_inds[1]] ** 2
    if boundary_inds[2] - boundary_inds[1] > 0:
        out[boundary_inds[1]:boundary_inds[2]] = a_part[boundary_inds[1]:boundary_inds[2]] ** 3
    if a.size - boundary_inds[2] > 0:
        out[boundary_inds[2]:] = a_part[boundary_inds[2]:] ** 4
    inv = np.empty_like(inds_in_each_bin)
    inv[inds_in_each_bin] = np.arange(inds_in_each_bin.size)
    return out[inv]


def with_individual_bool_array(a: np.ndarray) -> ReturnType:
    assert len(a.shape) == 1
    out = np.empty_like(a)
    case1 = a < 0.25
    case2 = (a >= 0.25) & (a < 0.5)
    case3 = (a >= 0.5) & (a < 0.75)
    case4 = a >= 0.75
    if (a1 := a[case1]).size:
        out[case1] = a1
    if (a2 := a[case2]).size:
        out[case2] = a2 ** 2
    if (a3 := a[case3]).size:
        out[case3] = a3 ** 3
    if (a4 := a[case4]).size:
        out[case4] = a4 ** 4

    return out


if __name__ == "__main__":
    import time
    import pandas as pd
    rng = np.random.RandomState(0)
    target_funcs = [
        with_digitized,
        with_argsort,
        with_count_nonzero_and_individual_where,
        with_count_nonzero_and_individual_bool_array,
        with_count_nonzero_and_vectorized_where,
        with_individual_bool_array,
    ]
    runtime_table = {f.__name__: [] for f in target_funcs}
    size_list = [100, 300, 1000, 3000, 10000, 30000, 100000, 300000]
    for size in size_list:
        print(f"Try {size=}")
        runtimes = [0.0] * len(target_funcs)
        for i in range(100):
            a = rng.random(size)
            results = [None] * len(target_funcs)
            for i, func in enumerate(target_funcs):
                start = time.time()
                results[i] = func(a)
                runtimes[i] += time.time() - start

            for vs in results[1:]:
                np.all(vs == results[0])

        for func, runtime in zip(target_funcs, runtimes):
            runtime_table[func.__name__].append(runtime * 1000)

    print(pd.DataFrame(runtime_table, index=size_list))
実験コード
        with_digitized  with_argsort  with_count_nonzero_and_individual_where  with_count_nonzero_and_individual_bool_array  with_count_nonzero_and_vectorized_where  with_individual_bool_array
100           1.203299      0.940800                                 1.212835                                      1.265287                                 1.523018                    0.969410
300           1.662731      1.549959                                 1.516581                                      1.950026                                 2.284288                    1.440763
1000          3.052235      2.952814                                 2.402306                                      4.219055                                 4.176617                    3.509283
3000          6.706953      6.883860                                 4.391193                                     10.186672                                 8.826971                    9.245157
10000        19.091845     21.739483                                11.170626                                     30.589819                                25.300503                   28.718472
30000        54.786444     69.526434                                31.057835                                     90.868950                                90.751648                   84.001541
100000      195.596457    282.873154                               106.905222                                    313.409567                               342.260838                  288.238525
300000      624.125719   1052.558422                               391.471148                                    988.689661                              1109.498978                  896.911860
結果

P.S.

np.selectも試しましたが,結局遅かったので実験には載せていません.

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?