Scala
Spark
Python3

scalaで書いたUDFをpysparkから呼び出したい

動機

sparkのプログラムを書いていて計算速度が思ったよりでないときがあった。調べるとpythonで書かれたuser defined function (UDF) は速度が遅いらしい。というのはspark自体がJavaで書かれていて、UDFをpythonで書いてしまうとJavaとpythonの間を変換する作業が走るから遅いとのこと。もうだいぶpythonで書き進めてしまい、今更scalaやjavaで書き直すのはめんどくさい。そこでscalaで書いたUDFをpythonから呼び出すことにした。

参考サイト

ここが一番わかりやすい
http://blog.einext.com/apache-spark/scala-udf-in-pyspark

http://blog.einext.com/apache-spark/scala-udf-in-pyspark
https://stackoverflow.com/questions/33233737/spark-how-to-map-python-with-scala-or-java-user-defined-functions
https://qiita.com/suin/items/b8a7af13b00cfdecfd1e
https://qiita.com/ytanak/items/97ecc67786ed7c5557bb
https://www.slideshare.net/SparkSummit/getting-the-best-performance-with-pyspark

動作環境

ここを参考につくってください。
https://qiita.com/neppysan/items/0fe706f04b001c082d38

scalaでUDFを書く

scalaビルド用ディレクトリ作成

以下のディレクトリを作成

├── project
│   └── target
│       └── config-classes
├── src
│   ├── main
│   │   ├── resources
│   │   └── scala
│   └── test
│       ├── resources
│       └── scala
└── target

ビルド設定ファイルの作成

./build.sbtに以下を書き込む(よくわかってない)

name := "SparkUDFs"
version := "0.1"
scalaVersion := "2.11.8"

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql"       % "2.1.0"
)

scalaのUDFを作成

./src/main/scala/scalaudf.scalaに以下を書き込む

package com.example.spark.udfs

import org.apache.spark.sql.api.java.UDF1

class addOne extends UDF1[Integer, Integer] {
  def call(x: Integer) = x + 1
} 

.jarの作成

次のコマンドを叩く。15分ぐらいかかる。

root@spark:~/scala_udf# sbt clean package

./target/scala-2.11/scala_udf-assembly-0.1-SNAPSHOT.jarができたことを確認。

サンプルデータの作成

./sample.jsonを作成

{"name":"a","value":1}
{"name":"b","value":2}
{"name":"c","value":3}

pythonファイルの作成

./test.pyを作成

from pyspark import SparkContext
from pyspark.sql import SparkSession, SQLContext

spark = SparkSession.builder.appName("test").getOrCreate()
sc = spark.sparkContext
sqlContext = SQLContext(sc,sparkSession = spark)

sqlContext.registerJavaFunction("add_one", "com.example.spark.udfs.addOne")
readfile = "/root/scala_hello_world/sample.json"
sample = spark.read.json(readfile)
sample.createOrReplaceTempView("sample")
data = spark.sql("SELECT name,value, add_one(cast(value as int)) from sample")
data.show()

追記
spark2.3からはこれでOK

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("test").getOrCreate()
spark.udf.registerJavaFunction("get100meshid","com.example.spark.udfs.addOne","integer")
readfile = "/root/scala_hello_world/sample.json"
sample = spark.read.json(readfile)
sample.createOrReplaceTempView("sample")
data = spark.sql("SELECT name,value, add_one(cast(value as int)) from sample")
data.show()

コマンド実行

root@spark:~/scala_udf# spark-submit --jars ./target/scala-2.11/scala_udf-assembly-0.1-SNAPSHOT.jar test.py
+----+-----+-----------------------+                                            
|name|value|UDF(cast(value as int))|
+----+-----+-----------------------+
|   a|    1|                      2|
|   b|    2|                      3|
|   c|    3|                      4|
+----+-----+-----------------------+

おわり。