Help us understand the problem. What is going on with this article?

Spark DataframeのSample Code集

More than 1 year has passed since last update.

はじめに:Spark Dataframeとは

Spark Ver 1.3からSpark Dataframeという機能が追加されました。特徴として以下の様な物があります。

  • Spark RDDにSchema設定を加えると、Spark DataframeのObjectを作成できる
  • Dataframeの利点は、
    • SQL風の文法で、条件に該当する行を抽出したり、Dataframe同士のJoinができる
    • filter, selectというmethodで、条件に該当する行、列を抽出できる
    • groupBy → aggというmethodで、Logの様々な集計ができる
    • UDF(User Defined Function)で独自関数で列に処理ができる
    • SQLで言うPivotもサポート (Spark v1.6からの機能)

つまり、RDDのmapfilterでシコシコ記述するよりもSimple Codeで、且つ高速に処理が行えるのがウリです。Dataの前処理はRDDでやるとして、さっさとDataframeに読み込んだ方がmajiで捗ります。Dataframeのメモが散在したので、備忘録がてらSample codeをチラ裏しておきます。

なお、

Sample Logの読み込み

Access Logを題材として使います。技術評論社さんの本で使われていたAccess Log(csv)で、csv fileへの直リンはこちらです。csvの中身は、日付、User_ID, Campaign_IDの3つの情報を持つ以下の様なLogです

click.at    user.id campaign.id
2015/4/27 20:40 144012  Campaign077
2015/4/27 0:27  24485   Campaign063
2015/4/27 0:28  24485   Campaign063
2015/4/27 0:33  24485   Campaign038

csvを読み込んでRDDにします。1行目のheaderの削除と、1列目をdatetime Objectとして読み込みます。

import json, os, datetime, collections, commands
from pyspark.sql import SQLContext, Row
from pyspark.sql.types import *

if not os.path.exists("./click_data_sample.csv"):
    print "csv file not found at master node, will download and copy to HDFS"
    commands.getoutput("wget -q http://image.gihyo.co.jp/assets/files/book/2015/978-4-7741-7631-4/download/click_data_sample.csv")
    commands.getoutput("hadoop fs -copyFromLocal -f ./click_data_sample.csv /user/hadoop/")

whole_raw_log = sc.textFile("/user/hadoop/click_data_sample.csv")
header = whole_raw_log.first()
whole_log = whole_raw_log.filter(lambda x:x !=header).map(lambda line: line.split(","))\
            .map(lambda line: [datetime.datetime.strptime(line[0].replace('"', ''), '%Y-%m-%d %H:%M:%S'), int(line[1]), line[2].replace('"', '')])

whole_log.take(3)
#[[datetime.datetime(2015, 4, 27, 20, 40, 40), 144012, u'Campaign077'],
# [datetime.datetime(2015, 4, 27, 0, 27, 55), 24485, u'Campaign063'],
# [datetime.datetime(2015, 4, 27, 0, 28, 13), 24485, u'Campaign063']]

Dataframeの作成方法

RDDから作成

Dataframeは、元となるRDDがあれば、Columnの名前とそれぞれのType(TimestampType, IntegerType, StringTypeなど)を指定して、sqlContext.createDataFrame(my_rdd, my_schema)で作成できます。Schemaの定義はここを参照

printSchema(), dtypesでSchema情報、count()で行数、show(n)で最初のn件のrecordの表示です。

fields = [StructField("access_time", TimestampType(), True), StructField("userID", IntegerType(), True), StructField("campaignID", StringType(), True)]
schema = StructType(fields)

whole_log_df = sqlContext.createDataFrame(whole_log, schema)
print whole_log_df.count()
print whole_log_df.printSchema()
print whole_log_df.dtypes
print whole_log_df.show(5)

#327430
#root
# |-- access_time: timestamp (nullable = true)
# |-- userID: integer (nullable = true)
# |-- campaignID: string (nullable = true)
#
#[('access_time', 'timestamp'), ('userID', 'int'), ('campaignID', 'string')]
#
#+--------------------+------+-----------+
#|         access_time|userID| campaignID|
#+--------------------+------+-----------+
#|2015-04-27 20:40:...|144012|Campaign077|
#|2015-04-27 00:27:...| 24485|Campaign063|
#|2015-04-27 00:28:...| 24485|Campaign063|
#|2015-04-27 00:33:...| 24485|Campaign038|
#|2015-04-27 01:00:...| 24485|Campaign063|
csv fileから直接作成

csvから読み込んだdataをそのままDataframeにするには、Spark Packageの1つであるspark-csvを使うと楽です。特に指定しないと全てstringとして読み込みますが、inferSchemaを指定してあげると良い感じに類推してくれます。

whole_log_df_2 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").load("/user/hadoop/click_data_sample.csv")
print whole_log_df_2.printSchema()
print whole_log_df_2.show(5)

#root
# |-- click.at: string (nullable = true)
# |-- user.id: string (nullable = true)
# |-- campaign.id: string (nullable = true)
#
#+-------------------+-------+-----------+
#|           click.at|user.id|campaign.id|
#+-------------------+-------+-----------+
#|2015-04-27 20:40:40| 144012|Campaign077|
#|2015-04-27 00:27:55|  24485|Campaign063|
#|2015-04-27 00:28:13|  24485|Campaign063|
#|2015-04-27 00:33:42|  24485|Campaign038|
#|2015-04-27 01:00:04|  24485|Campaign063|

whole_log_df_3 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true").load("/user/hadoop/click_data_sample.csv")
print whole_log_df_3.printSchema()

#root
# |-- click.at: timestamp (nullable = true)
# |-- user.id: integer (nullable = true)
# |-- campaign.id: string (nullable = true)

ちなみに、column名に.が入って居ると色々面倒なので、withColumnRenamedでrename可能です(renameした別のDataframeを作成可能です)。

whole_log_df_4 = whole_log_df_3.withColumnRenamed("click.at", "access_time")\
                 .withColumnRenamed("user.id", "userID")\
                 .withColumnRenamed("campaign.id", "campaignID")
print whole_log_df_4.printSchema()

#root
# |-- access_time: timestamp (nullable = true)
# |-- userID: integer (nullable = true)
# |-- campaignID: string (nullable = true)
jsonから直接作成

json fileから読み込んだdataをそのままDataframeにするにはsqlContext.read.jsonを使います。fileの各行を1 json objectとして扱います、存在しないKeyがある場合には、nullが入ります。

# test_json.json contains following 3 lines, last line doesn't have "campaignID" key
#
#{"access_time": "2015-04-27 20:40:40", "userID": "24485", "campaignID": "Campaign063"}
#{"access_time": "2015-04-27 00:27:55", "userID": "24485", "campaignID": "Campaign038"}
#{"access_time": "2015-04-27 00:27:55", "userID": "24485"}

df_json = sqlContext.read.json("/user/hadoop/test_json.json")
df_json.printSchema()
df_json.show(5)

#root
# |-- access_time: string (nullable = true)
# |-- campaignID: string (nullable = true)
# |-- userID: string (nullable = true)
#
#+-------------------+-----------+------+
#|        access_time| campaignID|userID|
#+-------------------+-----------+------+
#|2015-04-27 20:40:40|Campaign063| 24485|
#|2015-04-27 00:27:55|Campaign038| 24485|
#|2015-04-27 00:27:55|       null| 24485|
#+-------------------+-----------+------+
parquetから直接作成

parquet fileから読み込んだdataをそのままDataframeにするにはsqlContext.read.parquetを使います。parquet fileが置いてあるFolderを指定すると、そのFolder以下のparquet fileを一括で読み込んでくれます。

sqlContext.read.parquet("/user/hadoop/parquet_folder/")

SQL文でQuery

Dataframeに対して、SQLの文でQueryを掛けるSampleです。registerTempTableでDataframeにSQL Table nameを付与すると、SQLのTable名として参照できます。sqlContext.sql(SQL文)の戻り値もDataframeです。

なお、Sub Queryを記載する事も可能なのですが、Sub Query側にAliasを付与しないと、何故かSyntax errorが起きるので注意です。

#単純なSQL query

whole_log_df.registerTempTable("whole_log_table")

print sqlContext.sql(" SELECT * FROM whole_log_table where campaignID == 'Campaign047' ").count()
#18081
print sqlContext.sql(" SELECT * FROM whole_log_table where campaignID == 'Campaign047' ").show(5)
#+--------------------+------+-----------+
#|         access_time|userID| campaignID|
#+--------------------+------+-----------+
#|2015-04-27 05:26:...| 14151|Campaign047|
#|2015-04-27 05:26:...| 14151|Campaign047|
#|2015-04-27 05:26:...| 14151|Campaign047|
#|2015-04-27 05:27:...| 14151|Campaign047|
#|2015-04-27 05:28:...| 14151|Campaign047|
#+--------------------+------+-----------+


#SQL文の中に変数を入れる場合
for count in range(1, 3):
    print "Campaign00" + str(count)
    print sqlContext.sql("SELECT count(*) as access_num FROM whole_log_table where campaignID == 'Campaign00" + str(count) + "'").show()

#Campaign001
#+----------+
#|access_num|
#+----------+
#|      2407|
#+----------+
#
#Campaign002
#+----------+
#|access_num|
#+----------+
#|      1674|
#+----------+

#Sub Queryの場合:
print sqlContext.sql("SELECT count(*) as first_count FROM (SELECT userID, min(access_time) as first_access_date FROM whole_log_table GROUP BY userID) subquery_alias WHERE first_access_date < '2015-04-28'").show(5)
#+------------+
#|first_count |
#+------------+
#|       20480|
#+------------+

filter, selectで条件付き検索

Dataframeに対しての簡易的な検索機能です。上記にあるSQL文でQueryと機能は似ていますが、filter, selectは簡易的な検索機能という位置づけです。filterは条件に該当する行の抽出、selectは列を抽出します。RDDのfilterとちょっと文法が違うのに注意です。

#Sample for filter
print whole_log_df.filter(whole_log_df["access_time"] < "2015-04-28").count()
#41434
print whole_log_df.filter(whole_log_df["access_time"] > "2015-05-01").show(3)
#+--------------------+------+-----------+
#|         access_time|userID| campaignID|
#+--------------------+------+-----------+
#|2015-05-01 22:11:...|114157|Campaign002|
#|2015-05-01 23:36:...| 93708|Campaign055|
#|2015-05-01 22:51:...| 57798|Campaign046|
#+--------------------+------+-----------+

#Sample for select
print whole_log_df.select("access_time", "userID").show(3)
#+--------------------+------+
#|         access_time|userID|
#+--------------------+------+
#|2015-04-27 20:40:...|144012|
#|2015-04-27 00:27:...| 24485|
#|2015-04-27 00:28:...| 24485|
#+--------------------+------+

groupByで集計

groupByは、RDDのreduceByKeyに似た機能を提供しますが、groupByここにあるmethodをその後ろでCallする事で、様々な集計機能を実現できます。代表的なのはaggcountです。

groupBycountで集計

campaignIDをKeyにしてgroupByを実行し、そのRecord数をcount()で集計してくれます。groupByに複数のKeyを列挙すれば、その組み合わせをkeyとしてgroupByしてくれます。

print whole_log_df.groupBy("campaignID").count().sort("count", ascending=False).show(5)
#+-----------+-----+
#| campaignID|count|
#+-----------+-----+
#|Campaign116|22193|
#|Campaign027|19206|
#|Campaign047|18081|
#|Campaign107|13295|
#|Campaign131| 9068|
#+-----------+-----+

print whole_log_df.groupBy("campaignID", "userID").count().sort("count", ascending=False).show(5)
#+-----------+------+-----+
#| campaignID|userID|count|
#+-----------+------+-----+
#|Campaign047| 30292|  633|
#|Campaign086|107624|  623|
#|Campaign047|121150|  517|
#|Campaign086| 22975|  491|
#|Campaign122| 90714|  431|
#+-----------+------+-----+
groupByaggで集計

userIDをKeyにしてGroupByを実行し、その集計結果の平均や最大/最小を計算が可能です。agg({key:value})で、keyの列に対してvalueの関数(min,sum, ave etc)を実行した結果を返します。戻り値はDataframeなので、.filter()で更に行を絞る事も可能です。

print whole_log_df.groupBy("userID").agg({"access_time": "min"}).show(3)
#+------+--------------------+
#|userID|    min(access_time)|
#+------+--------------------+
#|  4831|2015-04-27 22:49:...|
#| 48631|2015-04-27 22:15:...|
#|143031|2015-04-27 21:52:...|
#+------+--------------------+

print whole_log_df.groupBy("userID").agg({"access_time": "min"}).filter("min(access_time) < '2015-04-28'").count()
#20480
groupBypivotで縦横変換

PivotはSpark v1.6からの新機能でSQLのPivotと似た機能を提供します。Sample codeのPivotの場合、以下の様に縦横が変化します。

  • pivot前(agged_df)
    • 行数が(UserID数 (=75,545) x campainID数 (=133) )
    • 列が3列
  • pivot後(pivot_df)
    • 行数がUserID数 (=75,545)
    • 列が UserID + CampainID数 = 1 + 133 = 134

必ず、groupBy("縦のままの列").pivot("縦から横へ変換したい列").sum("集計値の列")と3つのmethodをchainで呼ぶ必要があります。

agged_df = whole_log_df.groupBy("userID", "campaignID").count()
print agged_df.show(3)

#+------+-----------+-----+
#|userID| campaignID|count|
#+------+-----------+-----+
#|155812|Campaign107|    4|
#|103339|Campaign027|    1|
#|169114|Campaign112|    1|
#+------+-----------+-----+

#値が無いCellは、nullが入る
pivot_df = agged_df.groupBy("userID").pivot("campaignID").sum("count")
print pivot_df.printSchema()

#root
# |-- userID: integer (nullable = true)
# |-- Campaign001: long (nullable = true)
# |-- Campaign002: long (nullable = true)
# ..
# |-- Campaign133: long (nullable = true)

#値が無いCellを0で埋めたい場合
pivot_df2 = agged_df.groupBy("userID").pivot("campaignID").sum("count").fillna(0)

UDFで列の追加

Spark DataframeではUDFが使えます、主な用途は、列の追加になるかと思います。Dataframeは基本Immutable(不変)なので、列の中身の変更はできず、列を追加した別のDataframeを作成する事になります。

from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.types import DoubleType

def add_day_column(access_time):
    return int(access_time.strftime("%Y%m%d"))

my_udf = UserDefinedFunction(add_day_column, IntegerType())
print whole_log_df.withColumn("access_day", my_udf("access_time")).show(5)

#+--------------------+------+-----------+----------+
#|         access_time|userID| campaignID|access_day|
#+--------------------+------+-----------+----------+
#|2015-04-27 20:40:...|144012|Campaign077|  20150427|
#|2015-04-27 00:27:...| 24485|Campaign063|  20150427|
#|2015-04-27 00:28:...| 24485|Campaign063|  20150427|
#|2015-04-27 00:33:...| 24485|Campaign038|  20150427|
#|2015-04-27 01:00:...| 24485|Campaign063|  20150427|
#+--------------------+------+-----------+----------+

UDFの表記は、lambda関数を使って書くことも可能です。

my_udf2 = UserDefinedFunction(lambda x: x + 5, IntegerType())
print whole_log_df.withColumn("userID_2", my_udf2("userID")).show(5)

#+--------------------+------+-----------+--------+
#|         access_time|userID| campaignID|userID_2|
#+--------------------+------+-----------+--------+
#|2015-04-27 20:40:...|144012|Campaign077|  144017|
#|2015-04-27 00:27:...| 24485|Campaign063|   24490|
#|2015-04-27 00:28:...| 24485|Campaign063|   24490|
#|2015-04-27 00:33:...| 24485|Campaign038|   24490|
#|2015-04-27 01:00:...| 24485|Campaign063|   24490|
#+--------------------+------+-----------+--------+

逆に、特定の列を削除したいDataframeを作るにはdf.drop()を使います。

print whole_log_df.drop("userID").show(3)

#+--------------------+-----------+
#|         access_time| campaignID|
#+--------------------+-----------+
#|2015-04-27 20:40:...|Campaign077|
#|2015-04-27 00:27:...|Campaign063|
#|2015-04-27 00:28:...|Campaign063|
#+--------------------+-----------+

Joinで2つのDataframeを結合させる

2つのDataframeをJoinさせる事も可能です。ここでは、Heavy User(Access数が100回以上あるUser)のLogのみを全体のLogから抽出するケースを考えてみます。

まず、Access数が100回以上あるUserのUser IDとそのAccess数を、.groupBy("userID").count()で集計し、filterで100回以上に絞り込みます。

heavy_user_df1 = whole_log_df.groupBy("userID").count()
heavy_user_df2 = heavy_user_df1.filter(heavy_user_df1 ["count"] >= 100)

print heavy_user_df2 .printSchema()
print heavy_user_df2 .show(3)
print heavy_user_df2 .count()

#root
# |-- userID: integer (nullable = true)
# |-- count: long (nullable = false)
#
#+------+-----+
#|userID|count|
#+------+-----+
#| 84231|  134|
#| 13431|  128|
#|144432|  113|
#+------+-----+
#
#177

元のDataframe(こちらがLeftになる)でjoin methodを呼び、joinの相手(Rightになる)とjoinの条件を書くと、SQLのjoinの様にDataframeの結合が可能です。

joinの形式は、inner, outer, left_outer, rignt_outerなどが選べるはずなのですが、inner以外は意図した挙動で動いてくれない為(しかもouterとして処理される)、とりあえずinnerで外部結合して後にdropで要らないColumnを削除するようにしています。詳細Option等は公式Pageを参照して下さい。

以下のjoin処理により、Access数が100回以上あるUser(177名)に該当する38,729行のLogを取り出す事ができました(全体のLogは約32万行)。

joinded_df = whole_log_df.join(heavy_user_df2, whole_log_df["userID"] == heavy_user_df2["userID"], "inner").drop(heavy_user_df2["userID"]).drop("count")
print joinded_df.printSchema()
print joinded_df.show(3)
print joinded_df.count()

#root
# |-- access_time: timestamp (nullable = true)
# |-- campaignID: string (nullable = true)
# |-- userID: integer (nullable = true)

#None
#+--------------------+-----------+------+
#|         access_time| campaignID|userID|
#+--------------------+-----------+------+
#|2015-04-27 02:07:...|Campaign086| 13431|
#|2015-04-28 00:07:...|Campaign086| 13431|
#|2015-04-29 06:01:...|Campaign047| 13431|
#+--------------------+-----------+------+
#
#38729

Dataframeから列を取り出す

  • 列のLabelを取り出すには、df.columnsで列のLabelのList(not Dataframe)が取り出せる
  • 特定の列を取り出すには、df.select("userID").map(lambda x: x[0]).collect()で、"userID"列のList(not RDD/Dataframe)が取り出せる
  • 特定の列で重複が無いListを取り出すには.distinct()をDataframeの最後に追加すればOK
print whole_log_df.columns
#['access_time', 'userID', 'campaignID']

print whole_log_df.select("userID").map(lambda x: x[0]).collect()[:5]
#[144012, 24485, 24485, 24485, 24485]

print whole_log_df.select("userID").distinct().map(lambda x:x[0]).collect()[:5]
#[4831, 48631, 143031, 39631, 80831]

DataframeからRDD/Listに戻す

DataframeをRDDに戻すには、大きく2つの方法があります

  • .mapを呼ぶ
    • DataframeのSchema情報は破棄され、Dataframeの各行がそれぞれlistになったRDDに変換されます
  • .rddを呼ぶ
    • Dataframeの各行がそれぞれRow OjbectなRDDに変換されます。Row ObjectはSpark SQLで一行分のデータを保持する為のObjectです
    • my_rdd.rdd.map(lambda x:x.asDict())と、Row objectに対して.asDict()を呼んであげると、Key-ValueなRDDに変換可能です
#convert to rdd by ".map"
print whole_log_df.groupBy("campaignID").count().map(lambda x: [x[0], x[1]]).take(5)
#[[u'Campaign033', 786], [u'Campaign034', 3867], [u'Campaign035', 963], [u'Campaign036', 1267], [u'Campaign037', 1010]]

# rdd -> normal list can be done with "collect".
print whole_log_df.groupBy("campaignID").count().map(lambda x: [x[0], x[1]]).collect()[:5]
#[[u'Campaign033', 786], [u'Campaign034', 3867], [u'Campaign035', 963], [u'Campaign036', 1267], [u'Campaign037', 1010]]

#convert to rdd by ".rdd" will return "Row" object
print whole_log_df.groupBy("campaignID").rdd.take(3)
#[Row(campaignID=u'Campaign033', count=786), Row(campaignID=u'Campaign034', count=3867), Row(campaignID=u'Campaign035', count=963)]

#`.asDict()` will convert to Key-Value RDD from Row object
print whole_log_df.groupBy("campaignID").rdd.map(lambda x:x.asDict()).take(3)
#[{'count': 786, 'campaignID': u'Campaign033'}, {'count': 3867, 'campaignID': u'Campaign034'}, {'count': 963, 'campaignID': u'Campaign035'}]

DataframeからParquet fileに書き出す

DataframeをParquet形式でfileに書き出せば、schema情報を保持したままfileにExportが可能です。なお、ExportするS3 bucketのdirectoryが既に存在する場合には書き込みが失敗します、まだ存在していないDirectory名を指定して下さい。

#write to parquet filed
whole_log_df.select("access_time", "userID").write.parquet("s3n://my_S3_bucket/parquet_export") 

#reload from parquet filed
reload_df = sqlContext.read.parquet("s3n://my_S3_bucket/parquet_export") 
print reload_df.printSchema()
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away