LoginSignup
6
7

More than 5 years have passed since last update.

SparkでDataFrameを使ってグループごとの累積和・総和を求める方法【Python版】

Last updated at Posted at 2015-11-22

Sparkのpython版DataFrameのWindow関数を使って、カラムをグルーピング&ソートしつつ、累積和を計算するための方法です。

公式のPython APIドキュメントを調べながら模索した方法なので、もっと良い方法があるかも。
使ったSparkのバージョンは1.5.2です。

サンプル・データ

PostgreSQLのテーブルにテスト用データを用意し、pysparkにDataFrameとしてロードします。

$ SPARK_CLASSPATH=postgresql-9.4-1202.jdbc41.jar PYSPARK_DRIVER_PYTHON=ipython pyspark
(..snip..)
In [1]: df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql://localhost:5432/postgres?user=postgres', dbtable='public.foo').load()
(..snip..)
In [2]: df.printSchema()
root
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
In [4]: df.show()
+---+--------------------+---+
|  a|                   b|  c|
+---+--------------------+---+
|  1|2015-11-22 10:00:...|  1|
|  1|2015-11-22 10:10:...|  2|
|  1|2015-11-22 10:20:...|  3|
|  1|2015-11-22 10:30:...|  4|
|  1|2015-11-22 10:40:...|  5|
|  1|2015-11-22 10:50:...|  6|
|  1|2015-11-22 11:00:...|  7|
|  1|2015-11-22 11:10:...|  8|
|  1|2015-11-22 11:20:...|  9|
|  1|2015-11-22 11:30:...| 10|
|  1|2015-11-22 11:40:...| 11|
|  1|2015-11-22 11:50:...| 12|
|  1|2015-11-22 12:00:...| 13|
|  2|2015-11-22 10:00:...|  1|
|  2|2015-11-22 10:10:...|  2|
|  2|2015-11-22 10:20:...|  3|
|  2|2015-11-22 10:30:...|  4|
|  2|2015-11-22 10:40:...|  5|
|  2|2015-11-22 10:50:...|  6|
|  2|2015-11-22 11:00:...|  7|
+---+--------------------+---+
only showing top 20 rows

カラムaがグルーピング用、カラムbがソート用、カラムcが計算対象です。

カラムグループごとの累積和

カラムaでグループ分けしつつ、カラムbでソートし、カラムcの累積和を取ります。

まずはWindowの定義

In [6]: from pyspark.sql.Window import Window

In [7]: from pyspark.sql import functions as func

In [8]: window = Window.partitionpartitionBy(df.a).orderBy(df.b).rangeBetween(-sys.maxsize,0)

In [9]: window
Out[9]: <pyspark.sql.window.WindowSpec at 0x18368d0>

このウィンドウ上でpyspark.sql.functions.sum()を計算したColumnを作成

In [10]: cum_c = func.sum(df.c).over(window)

In [11]: cum_c
Out[11]: Column<'sum(c) WindowSpecDefinition UnspecifiedFrame>

このColumnを元のDataFrameにくっつけた新しいDataFrameを作成

In [12]: mod_df = df.withColumn("cum_c", cum_c)

In [13]: mod_df
Out[13]: DataFrame[a: int, b: timestamp, c: int, cum_c: bigint]

In [14]: mod_df.printSchema()
root
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
 |-- cum_c: long (nullable = true)


In [15]: mod_df.show()
+---+--------------------+---+-----+
|  a|                   b|  c|cum_c|
+---+--------------------+---+-----+
|  1|2015-11-22 10:00:...|  1|    1|
|  1|2015-11-22 10:10:...|  2|    3|
|  1|2015-11-22 10:20:...|  3|    6|
|  1|2015-11-22 10:30:...|  4|   10|
|  1|2015-11-22 10:40:...|  5|   15|
|  1|2015-11-22 10:50:...|  6|   21|
|  1|2015-11-22 11:00:...|  7|   28|
|  1|2015-11-22 11:10:...|  8|   36|
|  1|2015-11-22 11:20:...|  9|   45|
|  1|2015-11-22 11:30:...| 10|   55|
|  1|2015-11-22 11:40:...| 11|   66|
|  1|2015-11-22 11:50:...| 12|   78|
|  1|2015-11-22 12:00:...| 13|   91|
|  2|2015-11-22 10:00:...|  1|    1|
|  2|2015-11-22 10:10:...|  2|    3|
|  2|2015-11-22 10:20:...|  3|    6|
|  2|2015-11-22 10:30:...|  4|   10|
|  2|2015-11-22 10:40:...|  5|   15|
|  2|2015-11-22 10:50:...|  6|   21|
|  2|2015-11-22 11:00:...|  7|   28|
+---+--------------------+---+-----+
only showing top 20 rows

計算できていますね。

カラムグループごとの総和

今度は、カラムaのグループごとに、カラムcの総和を計算します。
DataFrameをgroupBy()でpyspark.sql.GroupedDataにして、pyspark.sql.GroupedData.sum()を使います。
さっきのsum()とややこしいけど、こちらはColumnオプジョクトを引数に持たすとエラーが出るので注意します。

In [25]: sum_c_df = df.groupBy('a').sum('c')

また、先ほどと違ってこれはWindow関数ではないので、返ってくる結果はDataFrameです。
しかも、総和を格納したカラム名は勝手に決まります。

In [26]: sum_c_df
Out[26]: DataFrame[a: int, sum(c): bigint]

うーん、ややこしい。

とりあえず、元のDataFrameにカラムとしてくっつけます。

In [27]: mod_df3 = mod_df2.join('a'sum_c_df, 'a'()

In [28]: mod_df3.printSchema()
root
 |-- a: integer (nullable = true)
 |-- b: timestamp (nullable = true)
 |-- c: integer (nullable = true)
 |-- cum_c: long (nullable = true)
 |-- sum(c): long (nullable = true)


In [29]: mod_df3.show()
(..snip..)
+---+--------------------+---+-------+------+
|  a|                   b|  c|  cum_c|sum(c)|
+---+--------------------+---+-------+------+
|  1|2015-11-22 10:00:...|  1|      1|    91|
|  1|2015-11-22 10:10:...|  2|      3|    91|
|  1|2015-11-22 10:20:...|  3|      6|    91|
|  1|2015-11-22 10:30:...|  4|     10|    91|
|  1|2015-11-22 10:40:...|  5|     15|    91|
|  1|2015-11-22 10:50:...|  6|     21|    91|
|  1|2015-11-22 11:00:...|  7|     28|    91|
|  1|2015-11-22 11:10:...|  8|     36|    91|
|  1|2015-11-22 11:20:...|  9|     45|    91|
|  1|2015-11-22 11:30:...| 10|     55|    91|
|  1|2015-11-22 11:40:...| 11|     66|    91|
|  1|2015-11-22 11:50:...| 12|     78|    91|
|  1|2015-11-22 12:00:...| 13|     91|    91|
|  2|2015-11-22 10:00:...|  1|      1|    91|
|  2|2015-11-22 10:10:...|  2|      3|    91|
|  2|2015-11-22 10:20:...|  3|      6|    91|
|  2|2015-11-22 10:30:...|  4|     10|    91|
|  2|2015-11-22 10:40:...|  5|     15|    91|
|  2|2015-11-22 10:50:...|  6|     21|    91|
|  2|2015-11-22 11:00:...|  7|     28|    91|
+---+--------------------+---+-------+------+
only showing top 20 rows

うまくグループごとの総和が計算できていますね。

カラムグループごとの(総和 - 累積和)

では、カラムcについて総和までの残り値を計算しましょう。つまり、総和 - 累積和です。

In [30]: diff_sum_c = mod_df3[('sum(c)'] - mod_df3['cum_c']

In [31]: mod_df4 = mod_df3.withColumn("diff_sum_c", diff_sum_c)

In [34]: mod_df4.show()
(..snip..)
+---+--------------------+---+-------+------+----------+
|  a|                   b|  c|cum_c_2|sum(c)|diff_sum_c|
+---+--------------------+---+-------+------+----------+
|  1|2015-11-22 10:00:...|  1|      1|    91|        90|
|  1|2015-11-22 10:10:...|  2|      3|    91|        88|
|  1|2015-11-22 10:20:...|  3|      6|    91|        85|
|  1|2015-11-22 10:30:...|  4|     10|    91|        81|
|  1|2015-11-22 10:40:...|  5|     15|    91|        76|
|  1|2015-11-22 10:50:...|  6|     21|    91|        70|
|  1|2015-11-22 11:00:...|  7|     28|    91|        63|
|  1|2015-11-22 11:10:...|  8|     36|    91|        55|
|  1|2015-11-22 11:20:...|  9|     45|    91|        46|
|  1|2015-11-22 11:30:...| 10|     55|    91|        36|
|  1|2015-11-22 11:40:...| 11|     66|    91|        25|
|  1|2015-11-22 11:50:...| 12|     78|    91|        13|
|  1|2015-11-22 12:00:...| 13|     91|    91|         0|
|  2|2015-11-22 10:00:...|  1|      1|    91|        90|
|  2|2015-11-22 10:10:...|  2|      3|    91|        88|
|  2|2015-11-22 10:20:...|  3|      6|    91|        85|
|  2|2015-11-22 10:30:...|  4|     10|    91|        81|
|  2|2015-11-22 10:40:...|  5|     15|    91|        76|
|  2|2015-11-22 10:50:...|  6|     21|    91|        70|
|  2|2015-11-22 11:00:...|  7|     28|    91|        63|
+---+--------------------+---+-------+------+----------+
only showing top 20 rows

補足

今回気付きましたが、SPARK_CLASSPATHを使うのはSpark 1.0以上では推奨されていないようです。
pyspark起動時に以下のようなメッセージが出ました。

15/11/22 12:32:44 WARN spark.SparkConf: 
SPARK_CLASSPATH was detected (set to 'postgresql-9.4-1202.jdbc41.jar').
This is deprecated in Spark 1.0+.

Please instead use:
 - ./spark-submit with --driver-class-path to augment the driver classpath
 - spark.executor.extraClassPath to augment the executor classpath

どうも、クラスタを利用する場合には異なるサーバでこの環境変数が正しく伝わらないため、別のパラメータを使うことが推奨されているようです。

うむむ。
こういうローカルと分散環境の違い、きちんと把握していかないとなぁ。

6
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
7