Help us understand the problem. What is going on with this article?

Spark MLパイプラインの使い方

More than 3 years have passed since last update.

Spark MLにはデータフレームの連続した変換操作を一つにまとめるPipelineという仕組みがあります。これを使うとコードがスッキリ書けるようになるほか、Spark内部でのメモリの利用効率も上がるらしいです。

大まかな流れは次の通り

  • ステージの用意
  • パイプライン構築
  • モデル生成
  • 実行

前回投稿したSpark MLで主成分分析をPipelineを使って書き換えてみます。

ステージとパイプライン

パイプライン中の段階をステージと呼びます。主成分分析の例では次の3段階がありました。

  • ベクトルの作成
  • 標準化
  • PCA

この3つを使ってパイプラインを宣言します。dfは入力データが格納されたデータフレームです。dfの詳細は前回の記事を参照してください。

from pyspark.ml.pipeline import Pipeline

# Pipelineの各ステージ
assembler = VectorAssembler(inputCols=df.columns[1:], outputCol="変量")
scaler = StandardScaler(inputCol="変量", outputCol="標準化変量", withStd=True, withMean=True)
pca = PCA(k=3, inputCol="標準化変量", outputCol="主成分得点")

# Pipelineの宣言
pipeline = Pipeline(stages=[
    assembler,
    scaler,
    pca
])

モデルの生成

構築したパイプラインにデータを入力し、モデルを作ります。

model = pipeline.fit(df)

モデルの実行

result = model.transform(df)
result.select("主成分得点").show(truncate=False)

実行結果

前回の記事で個別に実行した時と同じ結果が得られました。

+---------------------------------------------------------------+
|主成分得点                                                          |
+---------------------------------------------------------------+
|[-2.2620712255691466,0.4021126641946994,0.35861418406317674]   |
|[1.3672950172090064,-0.516574975843834,0.8240383763102186]     |
|[-0.35784774304549694,1.0654633785914394,-0.7670998522924913]  |
|[0.3930334607140129,-1.220525792393691,-0.05437714111925901]   |
|[0.9712806670593661,1.7644947192188811,-0.2783291638335238]    |
|[0.8556397135650156,-0.9097726336587761,-1.0627843972001996]   |
|[1.0076787432724863,0.1504509197015279,1.2009982469039933]     |
|[-1.8977055313059759,-0.9270196509736093,-0.005660728153863093]|
|[0.4960234396284956,-0.24274673811341405,-0.6858245266064249]  |
|[-0.5733265415277634,0.43411810927677885,0.47042500192836967]  |
+---------------------------------------------------------------+

ステージの参照方法

ステージのオブジェクトはmodel.stages[]で参照することができます。第3ステージのPCAモデルを参照してみます。

print("==== 固有ベクトル ====")
print(model.stages[2].pc)

print("==== 寄与率 ====")
print(model.stages[2].explainedVariance)

まとめ

Pipelineを使うことで、中間の変数がなくなり、コードがスッキリ書けました。
各ステージの個々のモデルの参照も可能なので、Pipelineを使わない理由はなさそうです。

全ソース

# -*- coding: utf-8 -*-
from pyspark.sql import SparkSession
from pyspark.ml.pipeline import Pipeline
from pyspark.ml.feature import PCA, VectorAssembler, StandardScaler

# Initialize SparkSession
spark = (SparkSession
         .builder
         .appName("news")
         .enableHiveSupport()
         .getOrCreate())

# Read raw data
df = spark.read.csv('news.csv', header=True, inferSchema=True, mode="DROPMALFORMED", encoding='UTF-8')

print("==== 生データ ====")
df.show(truncate=False)

# pipelineの部品を用意する
assembler = VectorAssembler(inputCols=df.columns[1:], outputCol="変量")
scaler = StandardScaler(inputCol="変量", outputCol="標準化変量", withStd=True, withMean=True)
pca = PCA(k=3, inputCol="標準化変量", outputCol="主成分得点")

pipeline = Pipeline(stages=[
    assembler,
    scaler,
    pca
])

# パイプラインを実行して入力データからモデルを作る
model = pipeline.fit(df)

# モデルを実行する
result = model.transform(df)
result.show(truncate=False)

# Pipelineのステージは.stagesで参照できる
print("==== 固有ベクトル ====")
print(model.stages[2].pc)

print("==== 寄与率 ====")
print(model.stages[2].explainedVariance)
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away