この記事は以下のNotebookのコードをベースにした解説になっています。紹介するコードは全て以下のドキュメントに含まれています(Notebookが提供されています)。
Distributed model inference using TensorFlow Keras
概要
Apache Sparkの並列分散処理を使うと、さまざまな計算処理が高速化できます。機械学習における計算、つまり、前処理、特徴量エンジニアリング、学習、推論にもこの考え方は適用できます。この記事では、推論処理に関してSparkの並列分散処理を適用する方法を見ていきます。
例えば、Deep Learningで事前学習済みのモデルを用意して、大量のデータの推論(prediction)を実施する場面を想像してみましょう。認識系の処理には多いケースだと思います。Sparkを使用すると、従来のpredict()
のコードを関数化(UDF化)することで、Sparkの分散実行環境に乗せることができ、すぐにスケーラビリティを持たせることができます。ここでは具体例として事前学習済みのTensorflow Kerasモデルを使っていますが、全く同じ方法で、他の機械学習ライブラリにもそのまま応用できます。
ステップ
既存のコードからの変更は以下の通りです。
- 推論に充てるデータをSpark Dataframe化させておく
- モデルの推論処理(
predict()
処理部分)を関数化させる(UDF化) - SparkでDataframeに対してUDFを一括適用させ、結果をDataframeで受け取る
コードの説明
それでは、Notebookにあるコードを順に追っていきましょう。
import os
import shutil
import time
import pandas as pd
from PIL import Image
import numpy as np
import uuid
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
file_name = "image_data.parquet"
temp_path = "tmp/flowers_{uuid}".format(uuid=str(uuid.uuid1()))
dbfs_file_path = "/dbfs/{}/".format(temp_path)
local_file_path = "/{}/image_data.parquet".format(temp_path)
output_file_path = "/{}/predictions".format(temp_path)
ここでは、ResNet-50の事前学習モデルを使用するため、必要なモジュールをimport
しています。また、推論を実施するデータをSpark Dataframeで扱うため、親和性の高いParquetフォーマットでデータを保存し読み出しやすいようにしていきます(もちろんparquetよりもDelta lakeフォーマットの方がベターです)。そのためのパス設定などがされています。
# Prepare trained model and data for inference
## Load the ResNet-50 Model and broadcast the weights.
model = ResNet50()
bc_model_weights = sc.broadcast(model.get_weights())
コメントにある通り、ResNet50のモデルをロードしています。ここで、sc.broadcast()
という、普段見慣れないコードが出てきます。これは、Sparkの分散処理に関係してくるところになります。Sparkは分散処理を、その名の通り、複数のホスト(クラスター!)で並列実行します。このクラスターのホスト間でデータを共有する方法の一つが、Broadcast変数になります。名前から分かる通り、変数の値を全てのホストに配っておいて(Broadcast!)、同じデータを参照可能にできるようにするためのものです。sc.broadcast()
を使用することで、このBroadcast変数が定義できます。
推論で使用する学習モデルを処理ホストで毎回ロードしても良いのですが、今回は一つの共通のモデルを使用するので、このようにBroadcast変数によって共有する方が効率的です(速度が全く異なります)。この後説明しますが、分散処理するところで、ここでbroadcastしておいた変数からモデルをロードします。
それでは続きを見ていきましょう。
## Load the data and save the datasets to one Parquet file.
import os
local_dir = "/dbfs/databricks-datasets/flower_photos"
files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(local_dir) for f in filenames if os.path.splitext(f)[1] == '.jpg']
files = files[:2048]
len(files)
image_data = []
for file in files:
img = Image.open(file)
img = img.resize([224, 224])
data = np.asarray( img, dtype="float32" ).reshape([224*224*3])
image_data.append({"data": data})
pandas_df = pd.DataFrame(image_data, columns = ['data'])
pandas_df.to_parquet(file_name)
os.makedirs(dbfs_file_path)
shutil.copyfile(file_name, dbfs_file_path+file_name)
推論する対象データ(Databricksのサンプルデータ/dbfs/databricks-datasets/flower_photos/*.jpg
)をDataframeにしてから、Parquetファイルとして保存しています。ここでは、SparkではなくPandasを使っています。データサイズ、ファイル数が膨大な場合には、最初からSparkを使うことでスケールできます。
また、今回の対象データはJPEGファイルです。そのままバイナリファイルとしてDataframe化・Parquet化してもよいのですが、ここでは、デコードした後のバイト列としてParquetに入れています。
## Load the data into Spark.
from pyspark.sql.types import *
df = spark.read.parquet(local_file_path)
print(df.count())
# Decrease the batch size of the Arrorw reader to avoid OOM errors on smaller instance types.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
assert len(df.head()) > 0, "`df` should not be empty" # This line will fail if the vectorized reader runs out of memory
一度、保存したparquetファイルをここでSpark Dataframeとして読み込んで、その後、推論処理を適用していきます。
# Run model inference via pandas UDF
## Define the function to parse the input data.
def parse_image(image_data):
image = tf.image.convert_image_dtype(image_data, dtype=tf.float32) * (2. / 255) - 1
image = tf.reshape(image,[224,224,3])
return image
@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR_ITER) # (1)
def predict_batch_udf(image_batch_iter): # (2)
batch_size = 64
model = ResNet50(weights=None) # (3)
model.set_weights(bc_model_weights.value)
for image_batch in image_batch_iter: # (4)
images = np.vstack(image_batch)
dataset = tf.data.Dataset.from_tensor_slices(images)
dataset = dataset.map(parse_image, num_parallel_calls=8).prefetch(5000).batch(batch_size)
preds = model.predict(dataset)
yield pd.Series(list(preds)) # (5)
ここで、ようやく推論の処理def predict_batch_udf()
を関数化(UDF化)しています。この部分のコード内容は以下の通りです。
- (1) デコレータを使ったPandasUDFの宣言。この関数の戻り値として
ArrayType(FloatType())
、つまり、Spark DataFrameにおけるfloat型の配列が使われることを宣言しています。また、PandasUDFの中でもPandasUDFType.SCALAR_ITER
のタイプを使用する宣言になっています。
PandasUDFについては下記の記事が参考になります。ここで使っているUDFを簡単に説明すると、Spark DataframeのカラムデータをPandas Seriesとして出力するイテレータが関数に渡されます。
-
(2) UDF
predict_batch_udf()
の定義。この中で推論処理をパック化する。上記の説明した通り、image_batch_iter
が関数が受け取るイテレータになります。このイテレータ(の.__next__()
)はSpark DataframeのカラムデータをPandas Seriesとして出力します。つまり、このUDFの中では、SparkのDataframeは登場せず、Pandas Seriesを使った処理になります。PandasUDFはSparkデータをUDF内ではPandasデータとして受け取る、という部分がやや混同しやすいです。 -
(3) 先ほどBroadcast変数で配っておいたResNet50のモデルをロードしています。
-
(4) 受け取ったイテレータを展開して、推論を実行しています。
-
(5) イテレータとして処理をするので、推論結果も
return
ではなくyield
で返します。
これで、推論するデータと、その処理を定義したUDFが用意できました。最後にこれらを掛け合わせて、推論処理を並列分散実行します。
## Run the model inference and save the result to Parquet file.
predictions_df = df.select(predict_batch_udf(col("data")).alias("prediction"))
predictions_df.write.mode("overwrite").parquet(output_file_path)
## Load and check the prediction results.
result_df = spark.read.parquet(output_file_path)
display(result_df)
推論結果を一度Parquetで書き出してから、それを再度読み込んで確認しています。ResNet50の認識結果なので、カテゴリーごとの確率が出力になっています。
補足1
今回のコードをそのまま実行すると、環境によっては、並列分散化がうまく行われず、パフォーマンスゲインが全くない場合があります。原因は、推論に充てるデータをSpark Dataframeとして読み込んだ際に、パーティションサイズが小さく(getNumPartition()
=> 2)なってしまう点にあります。
これを回避するために、一度ParquetファイルからSpark Dataframeに読み込んだ後に、repartition
させるとよいです。
コードの変更箇所は以下の通りです。
##### [変更前]
## Load the data into Spark.
from pyspark.sql.types import *
df = spark.read.parquet(local_file_path)
print(df.count())
# Decrease the batch size of the Arrorw reader to avoid OOM errors on smaller instance types.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
##### [変更後]
## Load the data into Spark.
from pyspark.sql.types import *
df = spark.read.parquet(local_file_path)
print(df.count())
df = df.repartition(400, 'data') # <== 追加!
print( df.rdd.getNumPartitions() ) # 確認
# Decrease the batch size of the Arrorw reader to avoid OOM errors on smaller instance types.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")
補足2
PandasUDFは色々な方式があり、今回の実装もイテレーションタイプ以外でも可能です。例えば、よく使うSeries to Series
タイプのPandasUDFを使う場合は、以下のようなコードになります。(UDFの部分のみ変更、その他は同一のコードで可能)
import pandas as pd
@pandas_udf( ArrayType(FloatType()) )
def predict_batch_udf_scalar(image_batch: pd.Series) -> pd.Series:
batch_size = 64
model = ResNet50(weights=None)
model.set_weights(bc_model_weights.value)
images = np.vstack(image_batch)
dataset = tf.data.Dataset.from_tensor_slices(images)
dataset = dataset.map(parse_image, num_parallel_calls=8).prefetch(5000).batch(batch_size)
preds = model.predict(dataset)
return pd.Series(list(preds))
predictions_df = df.select(predict_batch_udf_scalar(col("data")).alias("prediction"))
predictions_df.write.mode("overwrite").parquet(output_file_path)
こちらの方が、従来のUDFに近い形なので、わかりやすいかもれません。
参考