はじめに
Sparkの設定は結構めんどくさいので、google colaboratoryでSparkを触れないかと検索してみたら案の定できたので備忘録としてまとめてみた。
(参考リンクのほぼコピペ)
セットアップ用のコード
当然だが、こちらに該当するSparkのバージョンがないとエラーになるので、そこだけ注意が必要。
(参考リンクは旧バージョンを参照しているので、そのままコピペするとエラーになる。)
# 各ライブラリインストール
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q http://apache.osuosl.org/spark/spark-2.4.0/spark-2.4.0-bin-hadoop2.7.tgz
!tar xf spark-2.4.0-bin-hadoop2.7.tgz
!pip install -q findspark
# 環境変数設定
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-2.4.0-bin-hadoop2.7"
# findsparkで環境設定
import findspark
findspark.init()
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
以下のコードを実行して、DataFrameで読み込み出来れば、正常にSparkが使えている。
spark.read.csv('sample_data/california_housing_train.csv', header=True).show(5)
+-----------+---------+------------------+-----------+--------------+-----------+----------+-------------+------------------+
| longitude| latitude|housing_median_age|total_rooms|total_bedrooms| population|households|median_income|median_house_value|
+-----------+---------+------------------+-----------+--------------+-----------+----------+-------------+------------------+
|-114.310000|34.190000| 15.000000|5612.000000| 1283.000000|1015.000000|472.000000| 1.493600| 66900.000000|
|-114.470000|34.400000| 19.000000|7650.000000| 1901.000000|1129.000000|463.000000| 1.820000| 80100.000000|
|-114.560000|33.690000| 17.000000| 720.000000| 174.000000| 333.000000|117.000000| 1.650900| 85700.000000|
|-114.570000|33.640000| 14.000000|1501.000000| 337.000000| 515.000000|226.000000| 3.191700| 73400.000000|
|-114.570000|33.570000| 20.000000|1454.000000| 326.000000| 624.000000|262.000000| 1.925000| 65500.000000|
+-----------+---------+------------------+-----------+--------------+-----------+----------+-------------+------------------+
only showing top 5 rows
(追記) Apache Arrowを併用するには
SparkはArrowと組み合わせると、DataFrameのtoPandasメソッドが劇的に高速化される。
(Speeding up PySpark with Apache Arrow)
グラフによる可視化をする際に、toPandasは多用するので手軽に高速化できるのは結構便利。
# 各ライブラリインストール
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q http://apache.osuosl.org/spark/spark-2.4.0/spark-2.4.0-bin-hadoop2.7.tgz
!tar xf spark-2.4.0-bin-hadoop2.7.tgz
!pip install -q findspark
!pip install -q pyarrow
# 環境変数設定
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-2.4.0-bin-hadoop2.7"
# findsparkで環境設定
import findspark
findspark.init()
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
# Arrowの設定
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
以下のコードを実行すると、高速化が確認できる。
%%timeit
from pyspark.sql.functions import rand
df = spark.range(1 << 22).toDF("id").withColumn("x", rand())
_ = df.toPandas()
導入前
1 loop, best of 3: 18.1 s per loop
導入後
1 loop, best of 3: 1.29 s per loop