0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Sparkによる線形回帰モデルのトレーニング

Last updated at Posted at 2024-03-27

2024/4/12に翔泳社よりApache Spark徹底入門を出版します!

書籍のサンプルノートブックをウォークスルーしていきます。Python/Chapter10/10-2 Linear Regressionとなります。

翻訳ノートブックのリポジトリはこちら。

ノートブックはこちら

回帰問題: レンタル価格の予測

このノートブックでは、サンフランシスコのAirbnbのレンタル価格を予測するために以前のlabでクレンジングしたデータセットを使用します。

注意
実際には以下のパスにクレンジング済みのデータが格納されています。

filePath = "/databricks-datasets/learning-spark-v2/sf-airbnb/sf-airbnb-clean.parquet"
airbnbDF = spark.read.parquet(filePath)
display(airbnbDF)

Screenshot 2024-03-27 at 17.48.40.png

トレーニング/テストの分割

MLモデルを構築する際、テストデータを参照するべきではありません(なぜでしょうか?)。

トレーニングデータセットに80%をキープし、20%をテストデータセットとして取っておきます。randomSplitメソッドPython/Scalaを活用します。

問題: なぜ、シードを設定する必要があるのでしょうか?

trainDF, testDF = airbnbDF.randomSplit([.8, .2], seed=42)
print(f"There are {trainDF.cache().count()} rows in the training set, and {testDF.cache().count()} in the test set")
There are 5780 rows in the training set, and 1366 in the test set

問題: クラスター設定を変更するとどうなるのでしょうか?

これを試すには、1台のみのワーカーを持つクラスターと2台のワーカーを持つ別のクラスターを起動します。

注意
このデータは非常に小さいもの(1パーティション)であり、違いを確認するにはdatabricks-datasets/learning-spark-v2/sf-airbnb/sf-airbnb-clean-100p.parquetのように、大規模なデータセット(2+のパーティションなど)でテストする必要があるかもしれません。しかし、以下のコードでは、異なるクラスター設定で、どのように異なってパーティショニングされるのかをシミュレートするために、シンプルにrepartitionを行い、我々のトレーニングセットで同じ数のデータポイントを取得できるかどうかを確認しています。

(trainRepartitionDF, testRepartitionDF) = (airbnbDF
                                           .repartition(24)
                                           .randomSplit([.8, .2], seed=42))

print(trainRepartitionDF.count())
5738

80/20でtrain/testを分割する際、これは80/20分割の「近似」となります。正確な80/20の分割ではなく、我々のデータのパーティショニングが変化すると、train/testで異なる数のデータポイントを取得するだけではなく、データポイント自体も異なるものになります。

おすすめは、再現性の問題に遭遇しないように、一度データを分割したらそれぞれのtrain/testフォルダに書き出すというものです。

bedroomsの数を指定したらpriceを予測する非常にシンプルな線形回帰モデルを構築します。

問題: 線形回帰モデルにおける仮定にはどのようなものがありますか?

display(trainDF.select("price", "bedrooms").summary())

Screenshot 2024-03-27 at 17.50.33.png

価格についてはデータセットでいくつかの外れ値があります(一晩$10,000??)。モデルを構築する際にはこのことを念頭に置いてください :)

Vector Assembler

線形回帰では、入力としてVector型のカラムを期待します。

VectorAssembler Python/Scalaを用いて、簡単にbedroomsの値を単一のベクトルに変換できます。VectorAssemblerはtransformerの一例です。トランスフォーマーはデータフレームを受け取り、1つ以上のカラムが追加された新規のデータフレームを返却します。これらはデータから学習は行いませんが、ルールベースの変換処理を適用します。

from pyspark.ml.feature import VectorAssembler

vecAssembler = VectorAssembler(inputCols=["bedrooms"], outputCol="features")

vecTrainDF = vecAssembler.transform(trainDF)

vecTrainDF.select("bedrooms", "features", "price").show(10)
+--------+--------+-----+
|bedrooms|features|price|
+--------+--------+-----+
|     1.0|   [1.0]|200.0|
|     1.0|   [1.0]|130.0|
|     1.0|   [1.0]| 95.0|
|     1.0|   [1.0]|250.0|
|     3.0|   [3.0]|250.0|
|     1.0|   [1.0]|115.0|
|     1.0|   [1.0]|105.0|
|     1.0|   [1.0]| 86.0|
|     1.0|   [1.0]|100.0|
|     2.0|   [2.0]|220.0|
+--------+--------+-----+
only showing top 10 rows

線形回帰

データの準備ができたので、最初のモデルを構築するためにLinearRegressionエスティメーター Python/Scala を活用できます。エスティメーターは、入力としてデータフレームを受け取ってモデルを返却し、モデルは.fit()メソッドを持ちます。

from pyspark.ml.regression import LinearRegression

lr = LinearRegression(featuresCol="features", labelCol="price")
lrModel = lr.fit(vecTrainDF)

モデルの調査

m = round(lrModel.coefficients[0], 2)
b = round(lrModel.intercept, 2)

print(f"The formula for the linear regression line is price = {m}*bedrooms + {b}")
The formula for the linear regression line is price = 123.68*bedrooms + 47.51

パイプライン

from pyspark.ml import Pipeline

pipeline = Pipeline(stages=[vecAssembler, lr])
pipelineModel = pipeline.fit(trainDF)

テストセットへの適用

predDF = pipelineModel.transform(testDF)

predDF.select("bedrooms", "features", "price", "prediction").show(10)
+--------+--------+------+------------------+
|bedrooms|features| price|        prediction|
+--------+--------+------+------------------+
|     1.0|   [1.0]|  85.0|171.18598011578285|
|     1.0|   [1.0]|  45.0|171.18598011578285|
|     1.0|   [1.0]|  70.0|171.18598011578285|
|     1.0|   [1.0]| 128.0|171.18598011578285|
|     1.0|   [1.0]| 159.0|171.18598011578285|
|     2.0|   [2.0]| 250.0|294.86172649777757|
|     1.0|   [1.0]|  99.0|171.18598011578285|
|     1.0|   [1.0]|  95.0|171.18598011578285|
|     1.0|   [1.0]| 100.0|171.18598011578285|
|     1.0|   [1.0]|2010.0|171.18598011578285|
+--------+--------+------+------------------+
only showing top 10 rows

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?