Sparkのpython版DataFrameのWindow関数を使って、カラムをグルーピング&ソートしつつ、累積和を計算するための方法です。
公式のPython APIドキュメントを調べながら模索した方法なので、もっと良い方法があるかも。
使ったSparkのバージョンは1.5.2です。
サンプル・データ
PostgreSQLのテーブルにテスト用データを用意し、pysparkにDataFrameとしてロードします。
$ SPARK_CLASSPATH=postgresql-9.4-1202.jdbc41.jar PYSPARK_DRIVER_PYTHON=ipython pyspark
(..snip..)
In [1]: df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql://localhost:5432/postgres?user=postgres', dbtable='public.foo').load()
(..snip..)
In [2]: df.printSchema()
root
|-- a: integer (nullable = true)
|-- b: timestamp (nullable = true)
|-- c: integer (nullable = true)
In [4]: df.show()
+---+--------------------+---+
| a| b| c|
+---+--------------------+---+
| 1|2015-11-22 10:00:...| 1|
| 1|2015-11-22 10:10:...| 2|
| 1|2015-11-22 10:20:...| 3|
| 1|2015-11-22 10:30:...| 4|
| 1|2015-11-22 10:40:...| 5|
| 1|2015-11-22 10:50:...| 6|
| 1|2015-11-22 11:00:...| 7|
| 1|2015-11-22 11:10:...| 8|
| 1|2015-11-22 11:20:...| 9|
| 1|2015-11-22 11:30:...| 10|
| 1|2015-11-22 11:40:...| 11|
| 1|2015-11-22 11:50:...| 12|
| 1|2015-11-22 12:00:...| 13|
| 2|2015-11-22 10:00:...| 1|
| 2|2015-11-22 10:10:...| 2|
| 2|2015-11-22 10:20:...| 3|
| 2|2015-11-22 10:30:...| 4|
| 2|2015-11-22 10:40:...| 5|
| 2|2015-11-22 10:50:...| 6|
| 2|2015-11-22 11:00:...| 7|
+---+--------------------+---+
only showing top 20 rows
カラムaがグルーピング用、カラムbがソート用、カラムcが計算対象です。
カラムグループごとの累積和
カラムaでグループ分けしつつ、カラムbでソートし、カラムcの累積和を取ります。
まずはWindowの定義
In [6]: from pyspark.sql.Window import Window
In [7]: from pyspark.sql import functions as func
In [8]: window = Window.partitionpartitionBy(df.a).orderBy(df.b).rangeBetween(-sys.maxsize,0)
In [9]: window
Out[9]: <pyspark.sql.window.WindowSpec at 0x18368d0>
このウィンドウ上でpyspark.sql.functions.sum()を計算したColumnを作成
In [10]: cum_c = func.sum(df.c).over(window)
In [11]: cum_c
Out[11]: Column<'sum(c) WindowSpecDefinition UnspecifiedFrame>
このColumnを元のDataFrameにくっつけた新しいDataFrameを作成
In [12]: mod_df = df.withColumn("cum_c", cum_c)
In [13]: mod_df
Out[13]: DataFrame[a: int, b: timestamp, c: int, cum_c: bigint]
In [14]: mod_df.printSchema()
root
|-- a: integer (nullable = true)
|-- b: timestamp (nullable = true)
|-- c: integer (nullable = true)
|-- cum_c: long (nullable = true)
In [15]: mod_df.show()
+---+--------------------+---+-----+
| a| b| c|cum_c|
+---+--------------------+---+-----+
| 1|2015-11-22 10:00:...| 1| 1|
| 1|2015-11-22 10:10:...| 2| 3|
| 1|2015-11-22 10:20:...| 3| 6|
| 1|2015-11-22 10:30:...| 4| 10|
| 1|2015-11-22 10:40:...| 5| 15|
| 1|2015-11-22 10:50:...| 6| 21|
| 1|2015-11-22 11:00:...| 7| 28|
| 1|2015-11-22 11:10:...| 8| 36|
| 1|2015-11-22 11:20:...| 9| 45|
| 1|2015-11-22 11:30:...| 10| 55|
| 1|2015-11-22 11:40:...| 11| 66|
| 1|2015-11-22 11:50:...| 12| 78|
| 1|2015-11-22 12:00:...| 13| 91|
| 2|2015-11-22 10:00:...| 1| 1|
| 2|2015-11-22 10:10:...| 2| 3|
| 2|2015-11-22 10:20:...| 3| 6|
| 2|2015-11-22 10:30:...| 4| 10|
| 2|2015-11-22 10:40:...| 5| 15|
| 2|2015-11-22 10:50:...| 6| 21|
| 2|2015-11-22 11:00:...| 7| 28|
+---+--------------------+---+-----+
only showing top 20 rows
計算できていますね。
カラムグループごとの総和
今度は、カラムaのグループごとに、カラムcの総和を計算します。
DataFrameをgroupBy()でpyspark.sql.GroupedDataにして、pyspark.sql.GroupedData.sum()を使います。
さっきのsum()とややこしいけど、こちらはColumnオプジョクトを引数に持たすとエラーが出るので注意します。
In [25]: sum_c_df = df.groupBy('a').sum('c')
また、先ほどと違ってこれはWindow関数ではないので、返ってくる結果はDataFrameです。
しかも、総和を格納したカラム名は勝手に決まります。
In [26]: sum_c_df
Out[26]: DataFrame[a: int, sum(c): bigint]
うーん、ややこしい。
とりあえず、元のDataFrameにカラムとしてくっつけます。
In [27]: mod_df3 = mod_df2.join('a'sum_c_df, 'a'()
In [28]: mod_df3.printSchema()
root
|-- a: integer (nullable = true)
|-- b: timestamp (nullable = true)
|-- c: integer (nullable = true)
|-- cum_c: long (nullable = true)
|-- sum(c): long (nullable = true)
In [29]: mod_df3.show()
(..snip..)
+---+--------------------+---+-------+------+
| a| b| c| cum_c|sum(c)|
+---+--------------------+---+-------+------+
| 1|2015-11-22 10:00:...| 1| 1| 91|
| 1|2015-11-22 10:10:...| 2| 3| 91|
| 1|2015-11-22 10:20:...| 3| 6| 91|
| 1|2015-11-22 10:30:...| 4| 10| 91|
| 1|2015-11-22 10:40:...| 5| 15| 91|
| 1|2015-11-22 10:50:...| 6| 21| 91|
| 1|2015-11-22 11:00:...| 7| 28| 91|
| 1|2015-11-22 11:10:...| 8| 36| 91|
| 1|2015-11-22 11:20:...| 9| 45| 91|
| 1|2015-11-22 11:30:...| 10| 55| 91|
| 1|2015-11-22 11:40:...| 11| 66| 91|
| 1|2015-11-22 11:50:...| 12| 78| 91|
| 1|2015-11-22 12:00:...| 13| 91| 91|
| 2|2015-11-22 10:00:...| 1| 1| 91|
| 2|2015-11-22 10:10:...| 2| 3| 91|
| 2|2015-11-22 10:20:...| 3| 6| 91|
| 2|2015-11-22 10:30:...| 4| 10| 91|
| 2|2015-11-22 10:40:...| 5| 15| 91|
| 2|2015-11-22 10:50:...| 6| 21| 91|
| 2|2015-11-22 11:00:...| 7| 28| 91|
+---+--------------------+---+-------+------+
only showing top 20 rows
うまくグループごとの総和が計算できていますね。
カラムグループごとの(総和 - 累積和)
では、カラムcについて総和までの残り値を計算しましょう。つまり、総和 - 累積和です。
In [30]: diff_sum_c = mod_df3[('sum(c)'] - mod_df3['cum_c']
In [31]: mod_df4 = mod_df3.withColumn("diff_sum_c", diff_sum_c)
In [34]: mod_df4.show()
(..snip..)
+---+--------------------+---+-------+------+----------+
| a| b| c|cum_c_2|sum(c)|diff_sum_c|
+---+--------------------+---+-------+------+----------+
| 1|2015-11-22 10:00:...| 1| 1| 91| 90|
| 1|2015-11-22 10:10:...| 2| 3| 91| 88|
| 1|2015-11-22 10:20:...| 3| 6| 91| 85|
| 1|2015-11-22 10:30:...| 4| 10| 91| 81|
| 1|2015-11-22 10:40:...| 5| 15| 91| 76|
| 1|2015-11-22 10:50:...| 6| 21| 91| 70|
| 1|2015-11-22 11:00:...| 7| 28| 91| 63|
| 1|2015-11-22 11:10:...| 8| 36| 91| 55|
| 1|2015-11-22 11:20:...| 9| 45| 91| 46|
| 1|2015-11-22 11:30:...| 10| 55| 91| 36|
| 1|2015-11-22 11:40:...| 11| 66| 91| 25|
| 1|2015-11-22 11:50:...| 12| 78| 91| 13|
| 1|2015-11-22 12:00:...| 13| 91| 91| 0|
| 2|2015-11-22 10:00:...| 1| 1| 91| 90|
| 2|2015-11-22 10:10:...| 2| 3| 91| 88|
| 2|2015-11-22 10:20:...| 3| 6| 91| 85|
| 2|2015-11-22 10:30:...| 4| 10| 91| 81|
| 2|2015-11-22 10:40:...| 5| 15| 91| 76|
| 2|2015-11-22 10:50:...| 6| 21| 91| 70|
| 2|2015-11-22 11:00:...| 7| 28| 91| 63|
+---+--------------------+---+-------+------+----------+
only showing top 20 rows
補足
今回気付きましたが、SPARK_CLASSPATHを使うのはSpark 1.0以上では推奨されていないようです。
pyspark起動時に以下のようなメッセージが出ました。
15/11/22 12:32:44 WARN spark.SparkConf:
SPARK_CLASSPATH was detected (set to 'postgresql-9.4-1202.jdbc41.jar').
This is deprecated in Spark 1.0+.
Please instead use:
- ./spark-submit with --driver-class-path to augment the driver classpath
- spark.executor.extraClassPath to augment the executor classpath
どうも、クラスタを利用する場合には異なるサーバでこの環境変数が正しく伝わらないため、別のパラメータを使うことが推奨されているようです。
うむむ。
こういうローカルと分散環境の違い、きちんと把握していかないとなぁ。