Spark DataframeのSample Code集

  • 43
    いいね
  • 0
    コメント

はじめに:Spark Dataframeとは

Spark Ver 1.3からSpark Dataframeという機能が追加されました。特徴として以下の様な物があります。

  • Spark RDDにSchema設定を加えると、Spark DataframeのObjectを作成できる
  • Dataframeの利点は、
    • SQL風の文法で、条件に該当する行を抽出したり、Dataframe同士のJoinができる
    • filter, selectというmethodで、条件に該当する行、列を抽出できる
    • groupBy → aggというmethodで、Logの様々な集計ができる
    • UDF(User Defined Function)で独自関数で列に処理ができる
    • SQLで言うPivotもサポート (Spark v1.6からの機能)

つまり、RDDのmapfilterでシコシコ記述するよりもSimple Codeで、且つ高速に処理が行えるのがウリです。Dataの前処理はRDDでやるとして、さっさとDataframeに読み込んだ方がmajiで捗ります。Dataframeのメモが散在したので、備忘録がてらSample codeをチラ裏しておきます。

なお、

Sample Logの読み込み

Access Logを題材として使います。技術評論社さんの本で使われていたAccess Log(csv)で、csv fileへの直リンはこちらです。csvの中身は、日付、User_ID, Campaign_IDの3つの情報を持つ以下の様なLogです

click.at    user.id campaign.id
2015/4/27 20:40 144012  Campaign077
2015/4/27 0:27  24485   Campaign063
2015/4/27 0:28  24485   Campaign063
2015/4/27 0:33  24485   Campaign038

csvを読み込んでRDDにします。1行目のheaderの削除と、1列目をdatetime Objectとして読み込みます。

import json, os, datetime, collections, commands
from pyspark.sql import SQLContext, Row
from pyspark.sql.types import *

if not os.path.exists("./click_data_sample.csv"):
    print "csv file not found at master node, will download and copy to HDFS"
    commands.getoutput("wget -q http://image.gihyo.co.jp/assets/files/book/2015/978-4-7741-7631-4/download/click_data_sample.csv")
    commands.getoutput("hadoop fs -copyFromLocal -f ./click_data_sample.csv /user/hadoop/")

whole_raw_log = sc.textFile("/user/hadoop/click_data_sample.csv")
header = whole_raw_log.first()
whole_log = whole_raw_log.filter(lambda x:x !=header).map(lambda line: line.split(","))\
            .map(lambda line: [datetime.datetime.strptime(line[0].replace('"', ''), '%Y-%m-%d %H:%M:%S'), int(line[1]), line[2].replace('"', '')])

whole_log.take(3)
#[[datetime.datetime(2015, 4, 27, 20, 40, 40), 144012, u'Campaign077'],
# [datetime.datetime(2015, 4, 27, 0, 27, 55), 24485, u'Campaign063'],
# [datetime.datetime(2015, 4, 27, 0, 28, 13), 24485, u'Campaign063']]

Dataframeの作成方法

RDDから作成

Dataframeは、元となるRDDがあれば、Columnの名前とそれぞれのType(TimestampType, IntegerType, StringTypeなど)を指定して、sqlContext.createDataFrame(my_rdd, my_schema)で作成できます。Schemaの定義はここを参照

printSchema(), dtypesでSchema情報、count()で行数、show(n)で最初のn件のrecordの表示です。

fields = [StructField("access_time", TimestampType(), True), StructField("userID", IntegerType(), True), StructField("campaignID", StringType(), True)]
schema = StructType(fields)

whole_log_df = sqlContext.createDataFrame(whole_log, schema)
print whole_log_df.count()
print whole_log_df.printSchema()
print whole_log_df.dtypes
print whole_log_df.show(5)

#327430
#root
# |-- access_time: timestamp (nullable = true)
# |-- userID: integer (nullable = true)
# |-- campaignID: string (nullable = true)
#
#[('access_time', 'timestamp'), ('userID', 'int'), ('campaignID', 'string')]
#
#+--------------------+------+-----------+
#|         access_time|userID| campaignID|
#+--------------------+------+-----------+
#|2015-04-27 20:40:...|144012|Campaign077|
#|2015-04-27 00:27:...| 24485|Campaign063|
#|2015-04-27 00:28:...| 24485|Campaign063|
#|2015-04-27 00:33:...| 24485|Campaign038|
#|2015-04-27 01:00:...| 24485|Campaign063|
csv fileから直接作成

csvから読み込んだdataをそのままDataframeにするには、Spark Packageの1つであるspark-csvを使うと楽です。特に指定しないと全てstringとして読み込みますが、inferSchemaを指定してあげると良い感じに類推してくれます。

whole_log_df_2 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").load("/user/hadoop/click_data_sample.csv")
print whole_log_df_2.printSchema()
print whole_log_df_2.show(5)

#root
# |-- click.at: string (nullable = true)
# |-- user.id: string (nullable = true)
# |-- campaign.id: string (nullable = true)
#
#+-------------------+-------+-----------+
#|           click.at|user.id|campaign.id|
#+-------------------+-------+-----------+
#|2015-04-27 20:40:40| 144012|Campaign077|
#|2015-04-27 00:27:55|  24485|Campaign063|
#|2015-04-27 00:28:13|  24485|Campaign063|
#|2015-04-27 00:33:42|  24485|Campaign038|
#|2015-04-27 01:00:04|  24485|Campaign063|

whole_log_df_3 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load("/user/hadoop/click_data_sample.csv")
print whole_log_df_3.printSchema()

#root
# |-- click.at: timestamp (nullable = true)
# |-- user.id: integer (nullable = true)
# |-- campaign.id: string (nullable = true)

ちなみに、column名に.が入って居ると色々面倒なので、withColumnRenamedでrename可能です(renameした別のDataframeを作成可能です)。

whole_log_df_4 = whole_log_df_3.withColumnRenamed("click.at", "access_time")\
                 .withColumnRenamed("user.id", "userID")\
                 .withColumnRenamed("campaign.id", "campaignID")
print whole_log_df_4.printSchema()

#root
# |-- access_time: timestamp (nullable = true)
# |-- userID: integer (nullable = true)
# |-- campaignID: string (nullable = true)
jsonから直接作成

json fileから読み込んだdataをそのままDataframeにするにはsqlContext.read.jsonを使います。fileの各行を1 json objectとして扱います、存在しないKeyがある場合には、nullが入ります。

# test_json.json contains following 3 lines, last line doesn't have "campaignID" key
#
#{"access_time": "2015-04-27 20:40:40", "userID": "24485", "campaignID": "Campaign063"}
#{"access_time": "2015-04-27 00:27:55", "userID": "24485", "campaignID": "Campaign038"}
#{"access_time": "2015-04-27 00:27:55", "userID": "24485"}

df_json = sqlContext.read.json("/user/hadoop/test_json.json")
df_json.printSchema()
df_json.show(5)

#root
# |-- access_time: string (nullable = true)
# |-- campaignID: string (nullable = true)
# |-- userID: string (nullable = true)
#
#+-------------------+-----------+------+
#|        access_time| campaignID|userID|
#+-------------------+-----------+------+
#|2015-04-27 20:40:40|Campaign063| 24485|
#|2015-04-27 00:27:55|Campaign038| 24485|
#|2015-04-27 00:27:55|       null| 24485|
#+-------------------+-----------+------+
parquetから直接作成

parquet fileから読み込んだdataをそのままDataframeにするにはsqlContext.read.parquetを使います。parquet fileが置いてあるFolderを指定すると、そのFolder以下のparquet fileを一括で読み込んでくれます。

sqlContext.read.parquet("/user/hadoop/parquet_folder/")

SQL文でQuery

Dataframeに対して、SQLの文でQueryを掛けるSampleです。registerTempTableでDataframeにSQL Table nameを付与すると、SQLのTable名として参照できます。sqlContext.sql(SQL文)の戻り値もDataframeです。

なお、Sub Queryを記載する事も可能なのですが、Sub Query側にAliasを付与しないと、何故かSyntax errorが起きるので注意です。

#単純なSQL query

whole_log_df.registerTempTable("whole_log_table")

print sqlContext.sql(" SELECT * FROM whole_log_table where campaignID == 'Campaign047' ").count()
#18081
print sqlContext.sql(" SELECT * FROM whole_log_table where campaignID == 'Campaign047' ").show(5)
#+--------------------+------+-----------+
#|         access_time|userID| campaignID|
#+--------------------+------+-----------+
#|2015-04-27 05:26:...| 14151|Campaign047|
#|2015-04-27 05:26:...| 14151|Campaign047|
#|2015-04-27 05:26:...| 14151|Campaign047|
#|2015-04-27 05:27:...| 14151|Campaign047|
#|2015-04-27 05:28:...| 14151|Campaign047|
#+--------------------+------+-----------+


#SQL文の中に変数を入れる場合
for count in range(1, 3):
    print "Campaign00" + str(count)
    print sqlContext.sql("SELECT count(*) as access_num FROM whole_log_table where campaignID == 'Campaign00" + str(count) + "'").show()

#Campaign001
#+----------+
#|access_num|
#+----------+
#|      2407|
#+----------+
#
#Campaign002
#+----------+
#|access_num|
#+----------+
#|      1674|
#+----------+

#Sub Queryの場合:
print sqlContext.sql("SELECT count(*) as first_count FROM (SELECT userID, min(access_time) as first_access_date FROM whole_log_table GROUP BY userID) subquery_alias WHERE first_access_date < '2015-04-28'").show(5)
#+------------+
#|first_count |
#+------------+
#|       20480|
#+------------+

filter, selectで条件付き検索

Dataframeに対しての簡易的な検索機能です。上記にあるSQL文でQueryと機能は似ていますが、filter, selectは簡易的な検索機能という位置づけです。filterは条件に該当する行の抽出、selectは列を抽出します。RDDのfilterとちょっと文法が違うのに注意です。

#Sample for filter
print whole_log_df.filter(whole_log_df["access_time"] < "2015-04-28").count()
#41434
print whole_log_df.filter(whole_log_df["access_time"] > "2015-05-01").show(3)
#+--------------------+------+-----------+
#|         access_time|userID| campaignID|
#+--------------------+------+-----------+
#|2015-05-01 22:11:...|114157|Campaign002|
#|2015-05-01 23:36:...| 93708|Campaign055|
#|2015-05-01 22:51:...| 57798|Campaign046|
#+--------------------+------+-----------+

#Sample for select
print whole_log_df.select("access_time", "userID").show(3)
#+--------------------+------+
#|         access_time|userID|
#+--------------------+------+
#|2015-04-27 20:40:...|144012|
#|2015-04-27 00:27:...| 24485|
#|2015-04-27 00:28:...| 24485|
#+--------------------+------+

groupByで集計

groupByは、RDDのreduceByKeyに似た機能を提供しますが、groupByここにあるmethodをその後ろでCallする事で、様々な集計機能を実現できます。代表的なのはaggcountです。

groupBycountで集計

campaignIDをKeyにしてgroupByを実行し、そのRecord数をcount()で集計してくれます。groupByに複数のKeyを列挙すれば、その組み合わせをkeyとしてgroupByしてくれます。

print whole_log_df.groupBy("campaignID").count().sort("count", ascending=False).show(5)
#+-----------+-----+
#| campaignID|count|
#+-----------+-----+
#|Campaign116|22193|
#|Campaign027|19206|
#|Campaign047|18081|
#|Campaign107|13295|
#|Campaign131| 9068|
#+-----------+-----+

print whole_log_df.groupBy("campaignID", "userID").count().sort("count", ascending=False).show(5)
#+-----------+------+-----+
#| campaignID|userID|count|
#+-----------+------+-----+
#|Campaign047| 30292|  633|
#|Campaign086|107624|  623|
#|Campaign047|121150|  517|
#|Campaign086| 22975|  491|
#|Campaign122| 90714|  431|
#+-----------+------+-----+
groupByaggで集計

userIDをKeyにしてGroupByを実行し、その集計結果の平均や最大/最小を計算が可能です。agg({key:value})で、keyの列に対してvalueの関数(min,sum, ave etc)を実行した結果を返します。戻り値はDataframeなので、.filter()で更に行を絞る事も可能です。

print whole_log_df.groupBy("userID").agg({"access_time": "min"}).show(3)
#+------+--------------------+
#|userID|    min(access_time)|
#+------+--------------------+
#|  4831|2015-04-27 22:49:...|
#| 48631|2015-04-27 22:15:...|
#|143031|2015-04-27 21:52:...|
#+------+--------------------+

print whole_log_df.groupBy("userID").agg({"access_time": "min"}).filter("min(access_time) < '2015-04-28'").count()
#20480
groupBypivotで縦横変換

PivotはSpark v1.6からの新機能でSQLのPivotと似た機能を提供します。Sample codeのPivotの場合、以下の様に縦横が変化します。

  • pivot前(agged_df)
    • 行数が(UserID数 (=75,545) x campainID数 (=133) )
    • 列が3列
  • pivot後(pivot_df)
    • 行数がUserID数 (=75,545)
    • 列が UserID + CampainID数 = 1 + 133 = 134

必ず、groupBy("縦のままの列").pivot("縦から横へ変換したい列").sum("集計値の列")と3つのmethodをchainで呼ぶ必要があります。

agged_df = whole_log_df.groupBy("userID", "campaignID").count()
print agged_df.show(3)

#+------+-----------+-----+
#|userID| campaignID|count|
#+------+-----------+-----+
#|155812|Campaign107|    4|
#|103339|Campaign027|    1|
#|169114|Campaign112|    1|
#+------+-----------+-----+

#値が無いCellは、nullが入る
pivot_df = agged_df.groupBy("userID").pivot("campaignID").sum("count")
print pivot_df.printSchema()

#root
# |-- userID: integer (nullable = true)
# |-- Campaign001: long (nullable = true)
# |-- Campaign002: long (nullable = true)
# ..
# |-- Campaign133: long (nullable = true)

#値が無いCellを0で埋めたい場合
pivot_df2 = agged_df.groupBy("userID").pivot("campaignID").sum("count").fillna(0)

UDFで列の追加

Spark DataframeではUDFが使えます、主な用途は、列の追加になるかと思います。Dataframeは基本Immutable(不変)なので、列の中身の変更はできず、列を追加した別のDataframeを作成する事になります。

from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import DoubleType

def add_day_column(access_time):
    return int(access_time.strftime("%Y%m%d"))

my_udf = UserDefinedFunction(add_day_column, IntegerType())
print whole_log_df.withColumn("access_day", my_udf("access_time")).show(5)

#+--------------------+------+-----------+----------+
#|         access_time|userID| campaignID|access_day|
#+--------------------+------+-----------+----------+
#|2015-04-27 20:40:...|144012|Campaign077|  20150427|
#|2015-04-27 00:27:...| 24485|Campaign063|  20150427|
#|2015-04-27 00:28:...| 24485|Campaign063|  20150427|
#|2015-04-27 00:33:...| 24485|Campaign038|  20150427|
#|2015-04-27 01:00:...| 24485|Campaign063|  20150427|
#+--------------------+------+-----------+----------+

UDFの表記は、lambda関数を使って書くことも可能です。

my_udf2 = UserDefinedFunction(lambda x: x + 5, IntegerType())
print whole_log_df.withColumn("userID_2", my_udf2("userID")).show(5)

#+--------------------+------+-----------+--------+
#|         access_time|userID| campaignID|userID_2|
#+--------------------+------+-----------+--------+
#|2015-04-27 20:40:...|144012|Campaign077|  144017|
#|2015-04-27 00:27:...| 24485|Campaign063|   24490|
#|2015-04-27 00:28:...| 24485|Campaign063|   24490|
#|2015-04-27 00:33:...| 24485|Campaign038|   24490|
#|2015-04-27 01:00:...| 24485|Campaign063|   24490|
#+--------------------+------+-----------+--------+

逆に、特定の列を削除したいDataframeを作るにはdf.drop()を使います。

print whole_log_df.drop("userID").show(3)

#+--------------------+-----------+
#|         access_time| campaignID|
#+--------------------+-----------+
#|2015-04-27 20:40:...|Campaign077|
#|2015-04-27 00:27:...|Campaign063|
#|2015-04-27 00:28:...|Campaign063|
#+--------------------+-----------+

Joinで2つのDataframeを結合させる

2つのDataframeをJoinさせる事も可能です。ここでは、Heavy User(Access数が100回以上あるUser)のLogのみを全体のLogから抽出するケースを考えてみます。

まず、Access数が100回以上あるUserのUser IDとそのAccess数を、.groupBy("userID").count()で集計し、filterで100回以上に絞り込みます。

heavy_user_df1 = whole_log_df.groupBy("userID").count()
heavy_user_df2 = heavy_user_df1.filter(heavy_user_df1 ["count"] >= 100)

print heavy_user_df2 .printSchema()
print heavy_user_df2 .show(3)
print heavy_user_df2 .count()

#root
# |-- userID: integer (nullable = true)
# |-- count: long (nullable = false)
#
#+------+-----+
#|userID|count|
#+------+-----+
#| 84231|  134|
#| 13431|  128|
#|144432|  113|
#+------+-----+
#
#177

元のDataframe(こちらがLeftになる)でjoin methodを呼び、joinの相手(Rightになる)とjoinの条件を書くと、SQLのjoinの様にDataframeの結合が可能です。

joinの形式は、inner, outer, left_outer, rignt_outerなどが選べるはずなのですが、inner以外は意図した挙動で動いてくれない為(しかもouterとして処理される)、とりあえずinnerで外部結合して後にdropで要らないColumnを削除するようにしています。詳細Option等は公式Pageを参照して下さい。

以下のjoin処理により、Access数が100回以上あるUser(177名)に該当する38,729行のLogを取り出す事ができました(全体のLogは約32万行)。

joinded_df = whole_log_df.join(heavy_user_df2, whole_log_df["userID"] == heavy_user_df2["userID"], "inner").drop(heavy_user_df2["userID"]).drop("count")
print joinded_df.printSchema()
print joinded_df.show(3)
print joinded_df.count()

#root
# |-- access_time: timestamp (nullable = true)
# |-- campaignID: string (nullable = true)
# |-- userID: integer (nullable = true)

#None
#+--------------------+-----------+------+
#|         access_time| campaignID|userID|
#+--------------------+-----------+------+
#|2015-04-27 02:07:...|Campaign086| 13431|
#|2015-04-28 00:07:...|Campaign086| 13431|
#|2015-04-29 06:01:...|Campaign047| 13431|
#+--------------------+-----------+------+
#
#38729

Dataframeから列を取り出す

  • 列のLabelを取り出すには、df.columnsで列のLabelのList(not Dataframe)が取り出せる
  • 特定の列を取り出すには、df.select("userID").map(lambda x: x[0]).collect()で、"userID"列のList(not RDD/Dataframe)が取り出せる
  • 特定の列で重複が無いListを取り出すには.distinct()をDataframeの最後に追加すればOK
print whole_log_df.columns
#['access_time', 'userID', 'campaignID']

print whole_log_df.select("userID").map(lambda x: x[0]).collect()[:5]
#[144012, 24485, 24485, 24485, 24485]

print whole_log_df.select("userID").distinct().map(lambda x:x[0]).collect()[:5]
#[4831, 48631, 143031, 39631, 80831]

DataframeからRDD/Listに戻す

DataframeはをRDDに戻すには、大きく2つの方法があります

  • .mapを呼ぶ
    • DataframeのSchema情報は破棄され、Dataframeの各行がそれぞれlistになったRDDに変換されます
  • .rddを呼ぶ
    • Dataframeの各行がそれぞれRow OjbectなRDDに変換されます。Row ObjectはSpark SQLで一行分のデータを保持する為のObjectです
    • my_rdd.rdd.map(lambda x:x.asDict())と、Row objectに対して.asDict()を呼んであげると、Key-ValueなRDDに変換可能です
#convert to rdd by ".map"
print whole_log_df.groupBy("campaignID").count().map(lambda x: [x[0], x[1]]).take(5)
#[[u'Campaign033', 786], [u'Campaign034', 3867], [u'Campaign035', 963], [u'Campaign036', 1267], [u'Campaign037', 1010]]

# rdd -> normal list can be done with "collect".
print whole_log_df.groupBy("campaignID").count().map(lambda x: [x[0], x[1]]).collect()[:5]
#[[u'Campaign033', 786], [u'Campaign034', 3867], [u'Campaign035', 963], [u'Campaign036', 1267], [u'Campaign037', 1010]]

#convert to rdd by ".rdd" will return "Row" object
print whole_log_df.groupBy("campaignID").rdd.take(3)
#[Row(campaignID=u'Campaign033', count=786), Row(campaignID=u'Campaign034', count=3867), Row(campaignID=u'Campaign035', count=963)]

#`.asDict()` will convert to Key-Value RDD from Row object
print whole_log_df.groupBy("campaignID").rdd.map(lambda x:x.asDict()).take(3)
#[{'count': 786, 'campaignID': u'Campaign033'}, {'count': 3867, 'campaignID': u'Campaign034'}, {'count': 963, 'campaignID': u'Campaign035'}]

DataframeからParquet fileに書き出す

DataframeをParquet形式でfileに書き出せば、schema情報を保持したままfileにExportが可能です。なお、ExportするS3 bucketのdirectoryが既に存在する場合には書き込みが失敗します、まだ存在していないDirectory名を指定して下さい。

#write to parquet filed
whole_log_df.select("access_time", "userID").write.parquet("s3n://my_S3_bucket/parquet_export") 

#reload from parquet filed
reload_df = sqlContext.read.parquet("s3n://my_S3_bucket/parquet_export") 
print reload_df.printSchema()