5
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

お題は不問!Qiita Engineer Festa 2023で記事投稿!

numpyでgroupbyの集計を行うとpandasより速いのかを検証

Last updated at Posted at 2023-07-05

はじめに

 特徴量エンジニアリングやデータ解析のための前処理で pandas の集計(groupbyしてsum, mean, std, median とか)を使用していると「遅い」と感じることがしばしばありました。
 検索してみると、pandas を用いずに numpy だけで集計している方々がいるようだったので、実際に速いのか試してみました。しかも、numpy であれば Numba を用いることができるため、合わせて検証をしてみることにしました。

【参考】だから僕はpandasを辞めた【NumPyだけでgroupby.mean()する3つの方法 篇】

また、この記事では Numba の解説はしません。
とりあえず、読んで頂くだけであれば、python の処理を高速にするモジュールぐらいの感覚でもよいです。下記リンクに Numba を解説くださっている方の記事がありますので、気になる方は読んでみてください。

【参考】Python を Numba で高速化するときの間違えやすいポイントまとめ

今回作成したコードは下記にアップロードしています。
自由にご使用ください。
(GitHub)ソースコード

前提

  1. python バージョン:3.10.6
  2. pandas バージョン:2.0.1
  3. numpy バージョン:1.23.5
  4. Numba バージョン:0.57.1
  5. 調査対象の集計関数は、pandasで使用できる size, sum, mean, std, median, max, min
  6. groupby で指定できる列(カラム)数は1カラムとする
     ※複数カラムに対応すると実装が複雑になるため
  7. 使用するアルゴリズムは改善の余地あり(の視点でみてください)
  8. ベンチマークは benchit : 0.0.6 ライブラリを使用
  9. 訳合ってグラフの描画には matplotlib : 3.7.1 ライブラリを使用

検証に使用するデータ

無料で使用できるデータセット(以下、df に読込 )を利用します。
集計対象のグループ( category )列があって、数値( value )列があり、100万以上のデータ量のデータセットを使用しました。
※本記事では、データの中身については細かく言及しません。

調査&結果

データの準備 ※以降、このデータを使用します

dfcategory列順にソート
categoriesvaluesはデータセット( df )から下記のように準備

numpy用データの準備
sorted_df = df.sort_values('category')
values = sorted_df['value'].to_numpy()
categories = pd.factorize(sorted_df['category'], sort=True)[0]

1. size() の比較

まずは、シンプルに size の集計で調査

①pandasのみ

例えば、こんなコード
※ベンチマーク測定用の benchit で呼び出すために関数にしてます。

import pandas as pd
# ①pandasのみ
def pandas_counts(df):
  return df.groupby('category').size()

②pandas + numpy

ネットでこんな書き方も見つけました。
明らかに遅そうですが、とりあえず比較対象に

import pandas as pd
import numpy as np
# ②pandas + numpy
def pandas_np_counts(df):
  return df.groupby('category').apply(lambda x: np.size([*x], axis=0))

③numpyのみ

先駆者様を参考にするとこのように書けます。

import numpy as np
# ③numpyのみ
def grouped_counts_bincount(categories):
  return np.bincount(categories)

【参考】だから僕はpandasを辞めた【NumPyだけでgroupby.mean()する3つの方法 篇】

④numpy + Numba

Numba 版はこちら
※使用する際は、デコレータの内容を適宜変更してください。

import numpy as np
import numba as nb
from numba import njit
# ④numpy + Numba
@njit("i8[:](i4[:],i8[:])", cache=True)
def grouped_counts_bincount_jit(categories):
  return np.bincount(categories)

◆比較結果

処理速度(=より早く処理が完了)
③≒④>①>② の順
つまり、numpy で集計した方が圧倒的に速く、特に10の5乗オーダーまでは顕著に差がありました。
size()_比較

2. sum() の比較

sum による集計を同様に関数として書くと以下のような感じになります。
bincount の重みvalueを使用すると、上手い感じにグループごとの合計が求まります。

import pandas as pd
import numpy as np
import numba as nb
from numba import njit

# ①pandasのみ
def pandas_sum(df):
  return df.groupby('category')['value'].sum()

# ②pandas + numpy
def pandas_np_sum(df):
  return df.groupby('category')['value'].apply(lambda x: np.sum([*x], axis=0))

# ③numpyのみ
def grouped_sum_bincount(values, categories):
  return np.bincount(categories, values)

# ④numpy + Numba
@njit("f8[:](i4[:], i8[:])", cache=True)
def grouped_sum_bincount_jit(values, categories):
  return np.bincount(categories, values)

◆比較結果

④>≒③>①>② の順

わずかに ④numpy+Numba③numpyのみ より速い

sum()_比較

3. mean() の比較

mean による集計の比較は以下の通りになりました。

import pandas as pd
import numpy as np
import numba as nb
from numba import njit

# ①pandasのみ
def pandas_mean(df):
  return df.groupby('category')['value'].mean()

# ②pandas + numpy
def pandas_np_mean(df):
  return df.groupby('category')['value'].apply(lambda x: np.mean([*x], axis=0))

# ③numpyのみ
def grouped_mean_bincount(values, categories):
  counts = grouped_counts_bincount(values, categories)
  sums = grouped_sum_bincount(values, categories)
  return sums / counts

# ④numpy + Numba
@njit("f8[:](i4[:], i8[:])", cache=True)
def grouped_mean_jit_bitcount(values, categories):
  counts = grouped_counts_bincount_jit(values, categories)
  sums = grouped_sum_bincount_jit(values, categories)
  return sums / counts

◆比較結果

【2023/07/06修正】
すみません、ベンチマークしたときのコードが誤っていて、間違った結果を貼っていました。
下記の通り、結果のテキストと画像を修正しました。

④>③>①>② の順
明らかに、「④numpy+Numba」が最も速い
④>≒③>①>② の順
わずかに ④numpy+Numba③numpyのみ より速い
mean()_比較

4. std() の比較

import pandas as pd
import numpy as np
import numba as nb
from numba import njit

# ①pandasのみ
def pandas_std(df):
  return df.groupby('category')['value'].std()

# ②pandas + numpy
def pandas_np_std(df):
  return df.groupby('category')['value'].apply(lambda x: np.std([*x], axis=0))

# ③numpyのみ
def grouped_std_bincount(values, categories):
  counts = grouped_counts_bincount(values, categories)
  sums = grouped_sum_bincount(values, categories)
  means = sums / counts
  return np.sqrt(np.bincount(categories, (means[categories] - values)**2) / counts)

# ④numpy + Numba
@njit("f8[:](i4[:],i8[:])", cache=True)
def grouped_std_bincount_jit(values, categories):
  counts = grouped_counts_bincount_jit(values, categories)
  sums = grouped_sum_bincount_jit(values, categories)
  means = sums / counts
  return np.sqrt(np.bincount(categories, (means[categories] - values)**2) / counts)

◆比較結果

④>≒③>①>② の順

わずかに ④numpy+Numba③numpyのみ より速い
std()_比較

5. median() の比較

グループごとの中央値を求めるアルゴリズムの方針

 「グループ(今回はcategories)ごとの中央値」を numpy だけで求めようとすると、bincount では集計できません。中央値を求めるには、values をグループごとにソートされている形にする必要があります。
 categoriesは、各カテゴリが factorize() によって0からの連番になっていること、既にグループごとにソートされていることを利用し、各グループの開始・終了インデックスを探索し、そのインデックスでvaluesをスライスすると速そうでした。
 numpy.where() を使用して、各グループのvaluesに絞る方法もありますが、毎回全探索になってしまうため遅かったです。
 中央値の算出自体は、numpy.median() を使用しました。中央値の算出は自前でも十分実装できますが、探索アルゴリズムが ndarray に最適化されているためか numpy.median() が一番速かったです。

 以降のベンチマークは「グループごとの開始・終了インデックスを用いたアルゴリズム」で実施しました。

import pandas as pd
import numpy as np
import numba as nb
from numba import njit

# ①pandasのみ
def pandas_median(df):
  return df.groupby('category')['value'].median()

# ②pandas + numpy
def pandas_np_median(df):
  return df.groupby('category')['value'].apply(lambda x: np.median([*x], axis=0))

# ③numpyのみ
def grouped_median(values, categories):
    # グループ数を計算
    n_groups = np.max(categories) + 1 # catetoriesは0からの連番(整数・重複あり)
    # 各グループの中央値を格納
    medians = np.zeros(n_groups, dtype=np.float64)
    # グループの開始インデックス初期化
    start = 0

    # グループごとに処理
    for current_group in range(n_groups):
        # 現在グループの終了インデックスを取得
        end = start
        while end < len(categories) and categories[end] == current_group:
            end += 1

        # 現在グループの中央値を取得
        grouped_value = values[start:end]
        medians[current_group] = np.median(grouped_value)
        # 次グループの開始インデックスに更新
        start = end

    return medians

# ④numpy + Numba
@njit("f8[:](i4[:],i8[:])", cache=True)
def grouped_median_jit(values, categories):
    # グループ数を計算
    n_groups = np.max(categories) + 1 # catetoriesは0からの連番(整数・重複あり)

    # ~以下省略~ grouped_median()に同じ

    return medians

◆比較結果

10の3乗オーダーまで:④>③>①>② の順
10の3乗オーダー以降:④>①>③>≒② の順
※実際に集計を行う際は、10の3乗オーダー以上のデータを扱うことが殆どかと思います。

10の5乗オーダーまでは、明らかに ④numpy+Numba が速い
大規模データセットのときは ①pandas も悪くないが、それでも が速い

median()_比較

6. min() の比較

中央値と同様のアルゴリズムを使用します。
※③④の np.median()np.min() に変更するだけ

import pandas as pd
import numpy as np
import numba as nb
from numba import njit

# ①pandasのみ
def pandas_min(df):
  return df.groupby('category')['value'].min()

# ②pandas + numpy
def pandas_np_min(df):
  return df.groupby('category')['value'].apply(lambda x: np.min([*x], axis=0))

# ③numpyのみ
def grouped_min(values, categories):
    # グループの総数を計算
    n_groups = np.max(categories) + 1 # catetoriesは0からの連番(整数・重複あり)
    # 各グループの最小値を格納
    mins = np.zeros(n_groups, dtype=np.float64)
    # グループの開始インデックス初期化
    start = 0

    ## グループごとに処理
    for current_group in range(n_groups):
        # 現在グループの終了インデックスを取得
        end = start
        while end < len(categories) and categories[end] == current_group:
            end += 1

        # 現在グループの最小値を取得
        grouped_value = values[start:end]
        mins[current_group] = np.min(grouped_value)
        # 次グループの開始インデックスに更新
        start = end

    return mins

# ④numpy + Numba
@njit("f8[:](i4[:],i8[:])", cache=True)
def grouped_min_jit(values, categories):
    # グループの総数を計算
    n_groups = np.max(categories) + 1 # catetoriesは0からの連番(整数・重複あり)

    # ~以下省略~ grouped_min()に同じ

    return mins

◆比較結果

10の3乗オーダーまで:④>③>①>② の順
10の3乗オーダー以降:④>①>③>≒② の順
※実際に集計を行う際は、10の3乗オーダー以上のデータを扱うことが殆どかと思います。

中央値と同じ傾向でした。
ただ、最小値の方が中央値を探索するよりシンプルな線形探索になるため気持ち速い気がします。
※中央値はクイックセレクトで探索すれば計算時間を線形にできる模様

【参考】「選択アルゴリズム」と「中央値の中央値」
※外部サイト
min()_比較

7. max() の比較

③④の np.min()np.max() に変更しただけ

import pandas as pd
import numpy as np
import numba as nb
from numba import njit

# ①pandasのみ
def pandas_max(df):
  return df.groupby('category')['value'].max()

# ②pandas + numpy
def pandas_np_max(df):
  return df.groupby('category')['value'].apply(lambda x: np.max([*x], axis=0))

# ③numpyのみ
def grouped_max(values, categories):
    # グループの総数を計算
    n_groups = np.max(categories) + 1 # catetoriesは0からの連番(整数・重複あり)
    # 各グループの最大値を格納
    maxs = np.zeros(n_groups, dtype=np.float64)
    # グループの開始インデックス初期化
    start = 0

    # グループごとに処理
    for current_group in range(n_groups):
        # 現在グループの終了インデックスを取得
        end = start
        while end < len(categories) and categories[end] == current_group:
            end += 1

        # 現在グループの最大値を取得
        grouped_value = values[start:end]
        maxs[current_group] = np.max(grouped_value)
        # 次グループの開始インデックスに更新
        start = end

    return maxs

# ④numpy + Numba
@njit("f8[:](i4[:],i8[:])", cache=True)
def grouped_max_jit(values, categories):
    # グループの総数を計算
    n_groups = np.max(categories) + 1 # catetoriesは0からの連番(整数・重複あり)

    # ~以下省略~ grouped_max()に同じ

    return maxs

◆比較結果

10の3乗オーダーまで:④>③>①>② の順
10の3乗オーダー以降:④>①>③>≒② の順
※実際に集計を行う際は、10の3乗オーダー以上のデータを扱うことが殆どかと思います。

(当たり前ですが)最小値と全く同様の傾向でした。

min()_比較

まとめ

 全てにおいて「numpy+Numbaが速い」という結論でした。
 つまり、Numba を用いないと pandas の方が速い場合があります。
 ※アルゴリズムが最善ではないという点も十分にあるかと思いますが…

 また、今回1種類のデータセットでしか検証をしていないため、データセットによっては pandas での集計と変わらなかったり、むしろpandas での集計の方が速いということはあるかもしれません。

 加えて、「事前にグループカテゴリのソートが必要」だったり、「複数カテゴリでの集計(例えば、都道府県別・年代別の年収)にはもう一手間が必要」だったりすること、Numba はサポートされていない関数やそもそも Numpy では実装が面倒な集計があるため、第一選択としては pandas でいいんじゃないかな、と思ってます。

 追加の学習コストは掛かりますが、他のデータフレームライブラリであるPolarsを使用するという選択もあります。

ソースコード

(GitHub)ソースコード

参考資料

【参考】だから僕はpandasを辞めた【NumPyだけでgroupby.mean()する3つの方法 篇】
【参考】Python を Numba で高速化するときの間違えやすいポイントまとめ
【参考】「選択アルゴリズム」と「中央値の中央値」※外部サイト

5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?