こちらのソリューションアクセラレータForecast demand at the part level for streamlined manufacturingをウォークスルーします。
ノートブックはこちら。
前編ではノートブック01_Introduction_And_Setup
と02_Fine_Grained_Demand_Forecasting
をカバーします。前回の記事同様、Unity Catalog対応しています。
外部ロケーションの作成
こちらのソリューションアクセラレータでは、テーブルを外部ロケーションに保存しています。外部ロケーションがない場合には事前に設定してください。
_resources/00-setup
import os
import re
import mlflow
spark.conf.set("spark.databricks.cloudFiles.schemaInference.sampleSize.numFiles", "10")
db_prefix = "demand_planning"
catalogName = "takaakiyayoi_catalog"
# Get dbName and cloud_storage_path, reset and create database
current_user = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().apply('user')
if current_user.rfind('@') > 0:
current_user_no_at = current_user[:current_user.rfind('@')]
else:
current_user_no_at = current_user
current_user_no_at = re.sub(r'\W+', '_', current_user_no_at)
dbName = db_prefix+"_"+current_user_no_at
#cloud_storage_path = f"/Users/{current_user}/field_demos/{db_prefix}"
cloud_storage_path = f"s3://taka-external-location-bucket/external-tables/{db_prefix}"
reset_all = dbutils.widgets.get("reset_all_data") == "true"
spark.sql(f"""USE CATALOG {catalogName}""")
if reset_all:
spark.sql(f"DROP DATABASE IF EXISTS {dbName} CASCADE")
dbutils.fs.rm(cloud_storage_path, True)
spark.sql(f"""create database if not exists {dbName} MANAGED LOCATION '{cloud_storage_path}/tables' """)
spark.sql(f"""USE {dbName}""")
_resources/01-data-generator
# Unity Catalog対応
catalogName = "takaakiyayoi_catalog"
spark.sql(f"""USE CATALOG {catalogName}""")
イントロダクション
01_Introduction_And_Setup
を実行します。
製造業者にとって需要予測は重要なビジネスプロセスとなります。製造業者は以下を行うために正確な予測を必要とします:
- 製造オペレーション拡大の計画
- 十分な在庫の確保
- 顧客充当の保証
部品レベルの需要予測は、製造業者が自身のサプライチェーンに依存している個々の製造業において重要なものとなっています。近年では、履歴データに基づく数量ベースの予測、統計テクニックや機械学習テクニックで強化された予測に大きな投資を行っています。
需要予測はパンデミック前の数年においてはとても大きな成功を収めていました。製品の需要曲線は比較的変動が少なく、素材が欠品する可能性は比較的小さいものでした。そのため、製造業者は出荷商品数をシンプルに「真の」需要と解釈し、将来に対して外挿を行うために高度に洗練された統計モデルを使用していました。これまでは、これによって以下を提供していました:
- 販売計画の改善
- 回転率を最大化することを可能にする高度に最適化された安全在庫により、非常に優れたサービスデリバリーのパフォーマンスの提供
- 部品表(BoM: Bill of materials)を用いて製品アウトプットから原材料レベルまで追跡することによる、最適化された製造計画
しかし、パンデミック以降、需要の変動は激しいものとなりました。当初は需要が大きく落ち込みますが、計画不十分によるV字の回復曲線が続きます。結果として生じる、下流の製造業者への注文の増加は実際に供給者における災害の最初のフェーズを引き起こしました。基本的に、もはや製造アウトプットは実際の需要にマッチせず、いかなる変造性の増加は多くの場合、安全在庫を増加させるような妥当性のない推奨値につながることになりました。製造および販売の計画は、実際の需要ではなく原材料の可用性によって立てられることを余儀なくされました。標準的な需要計画のアプローチは大きな限界に近づいています。
完璧な例は、チップの災害に見ることができます。注文を最初に削減し、後で増加させた後、車のメーカーやサプライヤーはリモートワークによる半導体メーカーで増加する需要と競争しなくてはなりませんでした。さらに状況を悪化させたのは、変動性をさらに増加させたいくつかの重大なイベントです。中国とアメリカ間の貿易戦争は、中国最大の半導体メーカーへの制限を突きつけることになりました。2021年のテキサスのアイスストームは、いくつかのコンピューターチップ設備の閉鎖に繋がった電力災害を引き起こしました。テキサスはアメリカにおける半導体製造の中心地です。対話では、さらなる供給不足を引き起こした深刻的な半導体の枯渇を体験することになりました。日本の2つのプラントが火事となり、そのうちの一つは地震によるものでした。
参考: Boom & Bust Cycles in Chips (https://www.economist.com/business/2022/01/29/when-will-the-semiconductor-cycle-peak)
統計的な需要予測は上述の「不可抗力」のイベントを予測できたのでしょうか?間違いなくNOです!しかし、我々は、製造業者がこれらの困難を通じて操業できるようにするための大規模予測ソリューションを構築するための素晴らしいプラットフォームをDatabricksが提供すると考えています。
- (Python、R、SQL、Scalaによる)コラボレーティブなノートブックによる、ビジネス知識やドメイン専門性を適用しながらの複数ソースからのデータの探索、補強、可視化
- 個々のアイテム(製品、SKU、部品など)ごとのモデルを並列化し、数千のアイテムにスケール
- MLflowを用いた実験追跡によって、再現性、追跡可能なパフォーマンス指標、容易な再利用の実現
このソリューションアクセラレータでは、シミュレーションしたデータセットを用いてDatabricksを活用するメリットをご紹介します。ここでは、高度な運転アシストシステムを製造するTier 1の自動車メーカーのロールを想定しています。そして、3つのステップで進めていきます:
セットアップ
%run ./_resources/00-setup $reset_all_data=true
データの理解
demand_df = spark.read.table(f"{dbName}.part_level_demand")
display(demand_df.select("Product").dropDuplicates())
高精細な需要予測
02_Fine_Grained_Demand_Forecasting
を実行します。
前提条件: このノートブックを実行する前に 01_Introduction_And_Setup を実行してください。
このノートブックでは、適切な時系列モデルの特定からスタートし、優れたスピードとコスト効率性を持つ並列処理で複数モデルをトレーニングするために非常に類似したアプローチを適用します。
このノートブックのハイライト:
- 適切な時系列モデルを特定するために、Databricksのコラボレーティブでインタラクティブなノートブックを活用
- シングルノードのデータサイエンスコードを、さまざまなキー(SKUなど)で分散させるためのPandas UDF(ユーザー定義関数)
- Pandas UDF内でハイパーパラメータチューニングを実行することもできるHyperopt
%run ./_resources/00-setup $reset_all_data=false
print(cloud_storage_path)
print(dbName)
s3://taka-external-location-bucket/external-tables/demand_planning
demand_planning_takaaki_yayoi
import os
import datetime as dt
from random import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.dates as md
from statsmodels.tsa.api import ExponentialSmoothing, SimpleExpSmoothing, Holt
from statsmodels.tsa.statespace.sarimax import SARIMAX
import mlflow
import hyperopt
from hyperopt import hp, fmin, tpe, SparkTrials, STATUS_OK, space_eval
from hyperopt.pyll.base import scope
mlflow.autolog(disable=True)
from statsmodels.tsa.api import Holt
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_absolute_percentage_error
import pyspark.sql.functions as f
from pyspark.sql.types import *
モデルの構築
Databricksのコラボレーティブ、インタラクティブな環境を活用
データの読み込み
demand_df = spark.read.table(f"{dbName}.part_level_demand")
demand_df = demand_df.cache() # このサンプルノートブックためだけのものです
サンプルの検証: 単一の時系列を抽出してpandasデータフレームに変換
example_sku = demand_df.select("SKU").orderBy("SKU").limit(1).collect()[0].SKU
print("example_sku:", example_sku)
pdf = demand_df.filter(f.col("SKU") == example_sku).toPandas()
# 単一の時系列の作成
series_df = pd.Series(pdf['Demand'].values, index=pdf['Date'])
series_df = series_df.asfreq(freq='W-MON')
display(pdf)
example_sku: CAM_0X6CLF
予測期間の決定
forecast_horizon = 40
is_history = [True] * (len(series_df) - forecast_horizon) + [False] * forecast_horizon
train = series_df.iloc[is_history]
score = series_df.iloc[~np.array(is_history)]
外因性変数の導出
covid_breakpoint = dt.date(year=2020, month=3, day=1)
exo_df = pdf.assign(Week = pd.DatetimeIndex(pdf["Date"]).isocalendar().week.tolist())
exo_df = exo_df \
.assign(covid = np.where(pdf["Date"] >= np.datetime64(covid_breakpoint), 1, 0).tolist()) \
.assign(christmas = np.where((exo_df["Week"] >= 51) & (exo_df["Week"] <= 52) , 1, 0).tolist()) \
.assign(new_year = np.where((exo_df["Week"] >= 1) & (exo_df["Week"] <= 4) , 1, 0).tolist()) \
.set_index('Date')
exo_df = exo_df[["covid", "christmas", "new_year" ]]
exo_df = exo_df.asfreq(freq='W-MON')
print(exo_df)
train_exo = exo_df.iloc[is_history]
score_exo = exo_df.iloc[~np.array(is_history)]
covid christmas new_year
Date
2018-07-23 0 0 0
2018-07-30 0 0 0
2018-08-06 0 0 0
2018-08-13 0 0 0
2018-08-20 0 0 0
... ... ... ...
2021-06-21 1 0 0
2021-06-28 1 0 0
2021-07-05 1 0 0
2021-07-12 1 0 0
2021-07-19 1 0 0
[157 rows x 3 columns]
Holt’s Winters季節性手法をトライ
Holt's Winters季節性手法は季節性とトレンドのコンポーネントをモデリングします。最初にこれのさまざまなバージョンをトライすることは適切と言えます。時系列とそのイレギュラーなコンポーネントをトレーニング期間に適用させる際にはうまくフィットしましたが、予測期間ではそうではないことを観測しました。このモデルはクリスマス効果やパンデミック開始からの回復期間の予測ではうまく動作しませんでした。
fit1 = ExponentialSmoothing(
train,
seasonal_periods=3,
trend="add",
seasonal="add",
use_boxcox=True,
initialization_method="estimated",
).fit(method="ls")
fcast1 = fit1.forecast(forecast_horizon).rename("Additive trend and additive seasonal")
fit2 = ExponentialSmoothing(
train,
seasonal_periods=4,
trend="add",
seasonal="mul",
use_boxcox=True,
initialization_method="estimated",
).fit(method="ls")
fcast2 = fit2.forecast(forecast_horizon).rename("Additive trend and multiplicative seasonal")
fit3 = ExponentialSmoothing(
train,
seasonal_periods=4,
trend="add",
seasonal="add",
damped_trend=True,
use_boxcox=True,
initialization_method="estimated",
).fit(method="ls")
fcast3 = fit3.forecast(forecast_horizon).rename("Additive damped trend and additive seasonal")
fit4 = ExponentialSmoothing(
train,
seasonal_periods=4,
trend="add",
seasonal="mul",
damped_trend=True,
use_boxcox=True,
initialization_method="estimated",
).fit(method="ls")
fcast4 = fit4.forecast(forecast_horizon).rename("Additive damped trend and multiplicative seasonal")
plt.figure(figsize=(12, 8))
(line0,) = plt.plot(series_df, marker="o", color="black")
plt.plot(fit1.fittedvalues, color="blue")
(line1,) = plt.plot(fcast1, marker="o", color="blue")
plt.plot(fit2.fittedvalues, color="red")
(line2,) = plt.plot(fcast2, marker="o", color="red")
plt.plot(fit3.fittedvalues, color="green")
(line3,) = plt.plot(fcast3, marker="o", color="green")
plt.plot(fit4.fittedvalues, color="orange")
(line4,) = plt.plot(fcast4, marker="o", color="orange")
plt.axvline(x = min(score.index.values), color = 'red', label = 'axvline - full height')
plt.legend([line0, line1, line2, line3, line4], ["Actuals", fcast1.name, fcast2.name, fcast3.name, fcast4.name])
plt.xlabel("Time")
plt.ylabel("Demand")
plt.title("Holts Winters Seasonal Method")
Text(0.5, 1.0, 'Holts Winters Seasonal Method')
SARIMAX手法をトライ
SARIMAXによって、説明変数を組み込むことができます。ビジネス観点では、これは需要を引き起こすイベントに関するビジネス知識を組み込む助けとなります。これは、クリスマス効果だけではなく、プロモーションのアクションにすることも可能です。ビジネス知識を活用しない際には、このモデルは大したパフォーマンスを示していませんでしたが、外因性変数を組み込んだ場合、クリスマス効果やパンデミック後のトレンドは予測期間でよくフィットしています。
最初のモデル
fit1 = SARIMAX(train, order=(1, 2, 1), seasonal_order=(0, 0, 0, 0), initialization_method="estimated").fit(warn_convergence = False)
fcast1 = fit1.predict(start = min(train.index), end = max(score_exo.index)).rename("Without exogenous variables")
fit2 = SARIMAX(train, exog=train_exo, order=(1, 2, 1), seasonal_order=(0, 0, 0, 0), initialization_method="estimated").fit(warn_convergence = False)
fcast2 = fit2.predict(start = min(train.index), end = max(score_exo.index), exog = score_exo).rename("With exogenous variables")
RUNNING THE L-BFGS-B CODE
* * *
Machine precision = 2.220D-16
N = 3 M = 10
At X0 0 variables are exactly at the bounds
At iterate 0 f= 7.65053D+00 |proj g|= 1.51071D-01
At iterate 5 f= 7.62496D+00 |proj g|= 3.68205D-02
At iterate 10 f= 7.61924D+00 |proj g|= 5.37559D-03
At iterate 15 f= 7.61069D+00 |proj g|= 1.37879D-01
At iterate 20 f= 7.59161D+00 |proj g|= 1.86274D-03
* * *
Tit = total number of iterations
Tnf = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip = number of BFGS updates skipped
Nact = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F = final function value
* * *
N Tit Tnf Tnint Skip Nact Projg F
3 21 25 1 0 0 4.156D-06 7.592D+00
F = 7.5916086358081500
CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL
RUNNING THE L-BFGS-B CODE
* * *
Machine precision = 2.220D-16
N = 6 M = 10
At X0 0 variables are exactly at the bounds
At iterate 0 f= 6.74830D+00 |proj g|= 1.34739D-01
At iterate 5 f= 6.65996D+00 |proj g|= 4.02106D-03
At iterate 10 f= 6.60770D+00 |proj g|= 3.35486D-03
At iterate 15 f= 6.60425D+00 |proj g|= 1.57733D-03
At iterate 20 f= 6.59134D+00 |proj g|= 2.66757D-02
At iterate 25 f= 6.58669D+00 |proj g|= 9.06320D-05
At iterate 30 f= 6.58666D+00 |proj g|= 2.52569D-03
At iterate 35 f= 6.58364D+00 |proj g|= 2.76137D-02
At iterate 40 f= 6.57553D+00 |proj g|= 3.60288D-04
At iterate 45 f= 6.57473D+00 |proj g|= 1.49793D-02
At iterate 50 f= 6.57373D+00 |proj g|= 4.58809D-05
* * *
Tit = total number of iterations
Tnf = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip = number of BFGS updates skipped
Nact = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F = final function value
* * *
N Tit Tnf Tnint Skip Nact Projg F
6 50 57 1 0 0 4.588D-05 6.574D+00
F = 6.5737302019351569
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT
This problem is unconstrained.
This problem is unconstrained.
plt.figure(figsize=(18, 6))
plt.plot(series_df, marker="o", color="black")
plt.plot(fcast1[10:], color="blue")
(line1,) = plt.plot(fcast1[10:], marker="o", color="blue")
plt.plot(fcast2[10:], color="green")
(line2,) = plt.plot(fcast2[10:], marker="o", color="green")
plt.axvline(x = min(score.index.values), color = 'red', label = 'axvline - full height')
plt.legend([line0, line1, line2], ["Actuals", fcast1.name, fcast2.name])
plt.xlabel("Time")
plt.ylabel("Demand")
plt.title("SARIMAX")
Text(0.5, 1.0, 'SARIMAX')
SARIMAXモデルの最適なパラメータを特定するためにMLflowとHyperoptを活用
最初のモデルでは、好適なパラメータを特定するために手動のトライアンドエラーの手法を適用しました。自動で最適なパラメータを特定するために、MLflowとHyperoptを活用することができます。
はじめに評価関数を定義します。これは、与えられたパラメータでSARIMAXモデルをトレーニングし、平均二乗誤差を計算することで評価を行います。
def evaluate_model(hyperopt_params):
# モデルパラメータの設定
params = hyperopt_params
assert "p" in params and "d" in params and "q" in params, "Please provide p, d, and q"
if 'p' in params: params['p']=int(params['p']) # hyperoptは値をfloatとして供給しますがモデルはintを必要とします
if 'd' in params: params['d']=int(params['d']) # hyperoptは値をfloatとして供給しますがモデルはintを必要とします
if 'q' in params: params['q']=int(params['q']) # hyperoptは値をfloatとして供給しますがモデルはintを必要とします
order_parameters = (params['p'],params['d'],params['q'])
# このサンプルではシンプルさのために季節性を考慮しません
model1 = SARIMAX(train, exog=train_exo, order=order_parameters, seasonal_order=(0, 0, 0, 0))
fit1 = model1.fit(disp=False)
fcast1 = fit1.predict(start = min(score_exo.index), end = max(score_exo.index), exog = score_exo )
return {'status': hyperopt.STATUS_OK, 'loss': np.power(score.to_numpy() - fcast1.to_numpy(), 2).mean()}
次に、モデルを評価するためのパラメータの探索空間を定義します。
space = {
'p': scope.int(hyperopt.hp.quniform('p', 0, 4, 1)),
'd': scope.int(hyperopt.hp.quniform('d', 0, 2, 1)),
'q': scope.int(hyperopt.hp.quniform('q', 0, 4, 1))
}
これで、自動で最適なパラメータを特定するための探索空間と評価関数を活用することができます。これらのモデルはエクスペリメントで自動で追跡されることに注意してください。
rstate = np.random.default_rng(123)
with mlflow.start_run(run_name='mkh_test_sa'):
argmin = fmin(
fn=evaluate_model,
space=space,
algo=tpe.suggest, # hyperoptが探索空間をどのように移動するのかを制御するアルゴリズム
max_evals=10,
trials=SparkTrials(parallelism=1),
rstate=rstate,
verbose=False
)
この方法によって、最適なパラメータセットを特定することができます。
displayHTML(f"The optimal parameters for the selected series with SKU '{pdf.SKU.iloc[0]}' are: d = '{argmin.get('d')}', p = '{argmin.get('p')}' and q = '{argmin.get('q')}'")
The optimal parameters for the selected series with SKU 'CAM_0X6CLF' are: d = '2.0', p = '4.0' and q = '3.0'
次のステップでは、この手順を数千のモデルにスケールさせます。
いつでも大規模に数千のモデルをトレーニング
あなたの好きなライブラリやアプローチを使い続けることができます
FORECAST_HORIZON = 40
def add_exo_variables(pdf: pd.DataFrame) -> pd.DataFrame:
midnight = dt.datetime.min.time()
timestamp = pdf["Date"].apply(lambda x: dt.datetime.combine(x, midnight))
calendar_week = timestamp.dt.isocalendar().week
# 外因性変数に対する柔軟かつカスタムのロジックを定義
covid_breakpoint = dt.datetime(year=2020, month=3, day=1)
enriched_df = (
pdf
.assign(covid = (timestamp >= covid_breakpoint).astype(float))
.assign(christmas = ((calendar_week >= 51) & (calendar_week <= 52)).astype(float))
.assign(new_year = ((calendar_week >= 1) & (calendar_week <= 4)).astype(float))
)
return enriched_df[["Date", "Product", "SKU", "Demand", "covid", "christmas", "new_year"]]
enriched_schema = StructType(
[
StructField('Date', DateType()),
StructField('Product', StringType()),
StructField('SKU', StringType()),
StructField('Demand', FloatType()),
StructField('covid', FloatType()),
StructField('christmas', FloatType()),
StructField('new_year', FloatType()),
]
)
def split_train_score_data(data, forecast_horizon=FORECAST_HORIZON):
"""
- すでにデータがdate/timeでソートされていることを前提
- forecast_horizonは週単位
"""
is_history = [True] * (len(data) - forecast_horizon) + [False] * forecast_horizon
train = data.iloc[is_history]
score = data.iloc[~np.array(is_history)]
return train, score
今回は外因性変数を持つコアのデータセットを見てみましょう。今回は、すべてのSKUに対するデータはSparkデータフレームで論理的に統合されており、大規模分散処理を可能にしています。
enriched_df = (
demand_df
.groupBy("Product")
.applyInPandas(add_exo_variables, enriched_schema)
)
display(enriched_df)
ソリューション: ハイレベルの概要
メリット:
- ピュアなPython & Pandas: 開発やテストが簡単
- あなたの好きなライブラリを使い続けることができます
- 単一のSKUに対するPandasデータフレームの操作をシンプルに仮定できます
Pandas UDFによるそれぞれのSKUごとのモデルの構築、チューニング、スコアリング
def build_tune_and_score_model(sku_pdf: pd.DataFrame) -> pd.DataFrame:
"""
この関数はそれぞれのSKUごとのモデルを構築、チューニング、スコアリングし、Pandas UDFとして分散することができます
"""
# 常に日付で適切な順序、インデックスであることを保証
sku_pdf.sort_values("Date", inplace=True)
complete_ts = sku_pdf.set_index("Date").asfreq(freq="W-MON")
print(complete_ts)
# (Product, SKU)で大規模なSparkデータフレームをグルーピング
PRODUCT = sku_pdf["Product"].iloc[0]
SKU = sku_pdf["SKU"].iloc[0]
train_data, validation_data = split_train_score_data(complete_ts)
exo_fields = ["covid", "christmas", "new_year"]
# トレーニングデータセットに対するモデルの評価
def evaluate_model(hyperopt_params):
# SARIMAXではPython integerのタプルが必要
order_hparams = tuple([int(hyperopt_params[k]) for k in ("p", "d", "q")])
# トレーニング
model = SARIMAX(
train_data["Demand"],
exog=train_data[exo_fields],
order=order_hparams,
seasonal_order=(0, 0, 0, 0), # 我々のサンプルでは季節性は考慮せず
initialization_method="estimated",
enforce_stationarity = False,
enforce_invertibility = False
)
fitted_model = model.fit(disp=False, method='nm')
# 検証
fcast = fitted_model.predict(
start=validation_data.index.min(),
end=validation_data.index.max(),
exog=validation_data[exo_fields]
)
return {'status': hyperopt.STATUS_OK, 'loss': np.power(validation_data.Demand.to_numpy() - fcast.to_numpy(), 2).mean()}
search_space = {
'p': scope.int(hyperopt.hp.quniform('p', 0, 4, 1)),
'd': scope.int(hyperopt.hp.quniform('d', 0, 2, 1)),
'q': scope.int(hyperopt.hp.quniform('q', 0, 4, 1))
}
rstate = np.random.default_rng(123) # このノートブックの再現性確保のためのもの
best_hparams = fmin(evaluate_model, search_space, algo=tpe.suggest, max_evals=10)
# Training
model_final = SARIMAX(
train_data["Demand"],
exog=train_data[exo_fields],
order=tuple(best_hparams.values()),
seasonal_order=(0, 0, 0, 0), # 我々のサンプルでは季節性は考慮せず
initialization_method="estimated",
enforce_stationarity = False,
enforce_invertibility = False
)
fitted_model_final = model_final.fit(disp=False, method='nm')
# 評価
fcast = fitted_model_final.predict(
start=complete_ts.index.min(),
end=complete_ts.index.max(),
exog=validation_data[exo_fields]
)
forecast_series = complete_ts[['Product', 'SKU' , 'Demand']].assign(Date = complete_ts.index.values).assign(Demand_Fitted = fcast)
forecast_series = forecast_series[['Product', 'SKU' , 'Date', 'Demand', 'Demand_Fitted']]
return forecast_series
tuning_schema = StructType(
[
StructField('Product', StringType()),
StructField('SKU', StringType()),
StructField('Date', DateType()),
StructField('Demand', FloatType()),
StructField('Demand_Fitted', FloatType())
]
)
分散処理の実行: groupBy("SKU")
+ applyInPandas(...)
# 並列度を最大化するために、それぞれの("Product", SKU")グループに自身のSparkタスクを割り当てることができます。
# 以下によってこれを実現することができます:
# - このグループのみのAdaptive Query Execution (AQE)を無効化
# - 以下のように入力Sparkデータフレームをパーティショニング:
spark.conf.set("spark.databricks.optimizer.adaptive.enabled", "false")
n_tasks = enriched_df.select("Product", "SKU").distinct().count()
forecast_df = (
enriched_df
.repartition(n_tasks, "Product", "SKU")
.groupBy("Product", "SKU")
.applyInPandas(build_tune_and_score_model, schema=tuning_schema)
)
display(forecast_df)
Deltaに保存
forecast_df_delta_path = os.path.join(cloud_storage_path, 'forecast_df_delta')
# データの書き出し
forecast_df.write \
.mode("overwrite") \
.format("delta") \
.save(forecast_df_delta_path)
spark.sql(f"DROP TABLE IF EXISTS {dbName}.part_level_demand_with_forecasts")
spark.sql(f"CREATE TABLE {dbName}.part_level_demand_with_forecasts USING DELTA LOCATION '{forecast_df_delta_path}'")
display(spark.sql(f"SELECT * FROM {dbName}.part_level_demand_with_forecasts"))
後編に続きます。