
Running LightGBM on Apache Spark on Databricks


Come to think of it, I had probably never actually run this before. This walkthrough is based on a colleague's article and the documentation.

Installing the library

Follow the steps described here.

(Screenshot: installing the library on the cluster)
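
On Databricks I install SynapseML as a cluster Maven library following those steps. If you want to try the same notebook outside Databricks, the SynapseML documentation attaches the package through Spark configuration instead; below is a minimal sketch of that approach, where the version string is my assumption and should be replaced with the current release listed in the docs.

from pyspark.sql import SparkSession

# Minimal local sketch (not needed on Databricks, where the library is attached to the cluster).
# The SynapseML version below is an assumption; check the SynapseML docs for the current release.
spark = (
    SparkSession.builder.appName("synapseml-local")
    .config("spark.jars.packages", "com.microsoft.azure:synapseml_2.12:1.0.4")
    .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
    .getOrCreate()
)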

Walking through the sample

I will walk through this notebook.

What is LightGBM?

LightGBM is an open-source, distributed, high-performance gradient boosting (GBDT, GBRT, GBM, or MART) framework. It specializes in creating high-quality, GPU-enabled decision tree algorithms for ranking, classification, and many other machine learning tasks. LightGBM is part of Microsoft's DMTK project.

Advantages of LightGBM

  • Composability: LightGBM models can be incorporated into existing SparkML pipelines and used for batch, streaming, and serving workloads.
  • Performance: LightGBM on Spark is 10-30% faster than SparkML on the Higgs dataset and achieves a 15% increase in AUC. Parallel experiments have verified that training on multiple machines achieves a linear speed-up in certain settings.
  • Functionality: LightGBM offers a wide array of tunable parameters for customizing the decision tree system. LightGBM on Spark also supports new types of problems such as quantile regression.
  • Cross platform: LightGBM on Spark is available on Spark, PySpark, and SparklyR.

How to use LightGBM

  • LightGBMClassifier: used for building classification models. For example, to predict whether a company will go bankrupt, you can build a binary classification model with LightGBMClassifier.
  • LightGBMRegressor: used for building regression models. For example, to predict housing prices, you can build a regression model with LightGBMRegressor.
  • LightGBMRanker: used for building ranking models. For example, to predict the relevance of website search results, you can build a ranking model with LightGBMRanker. All three follow the same SparkML estimator API; see the sketch right after this list.
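
Regardless of the task, the usage pattern is the standard SparkML Estimator/Transformer flow: construct the estimator with column names and hyperparameters, call fit() to get a model, then call transform() to score a DataFrame. Here is a minimal self-contained sketch of that pattern on synthetic data (the toy data and column names are my own, not from the notebook):

from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler
from synapse.ml.lightgbm import LightGBMClassifier

# Synthetic toy data: 100 rows, two numeric features, label = 1 when f1 > 0.5
toy = (
    spark.range(100)
    .withColumn("f1", F.rand(seed=1))
    .withColumn("f2", F.rand(seed=2))
    .withColumn("label", (F.col("f1") > 0.5).cast("int"))
)
assembled = VectorAssembler(inputCols=["f1", "f2"], outputCol="features").transform(toy)

clf = LightGBMClassifier(featuresCol="features", labelCol="label")
clf_model = clf.fit(assembled)           # Estimator.fit() returns a trained model
scored = clf_model.transform(assembled)  # transform() appends prediction/probability columns
scored.select("label", "prediction", "probability").show(5)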

Training a classification model with LightGBMClassifier

In this example, we use LightGBM to build a classification model that predicts bankruptcy.

Loading the dataset

from synapse.ml.core.platform import *
df = (
    spark.read.format("csv")
    .option("header", True)
    .option("inferSchema", True)
    .load(
        "wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv"
    )
)
# Print the size of the dataset
print("records read: " + str(df.count()))
print("Schema: ")
df.printSchema()
records read: 6819
Schema: 
root
 |-- Bankrupt?: integer (nullable = true)
 |--  ROA(C) before interest and depreciation before interest: double (nullable = true)
 |--  ROA(A) before interest and % after tax: double (nullable = true)
 |--  ROA(B) before interest and depreciation after tax: double (nullable = true)
 |--  Operating Gross Margin: double (nullable = true)
 |--  Realized Sales Gross Margin: double (nullable = true)
 |--  Operating Profit Rate: double (nullable = true)
 |--  Pre-tax net Interest Rate: double (nullable = true)
 |--  After-tax net Interest Rate: double (nullable = true)
 |--  Non-industry income and expenditure/revenue: double (nullable = true)
 |--  Continuous interest rate (after tax): double (nullable = true)
 |--  Operating Expense Rate: double (nullable = true)
 |--  Research and development expense rate: double (nullable = true)
 |--  Cash flow rate: double (nullable = true)
 |--  Interest-bearing debt interest rate: double (nullable = true)
 |--  Tax rate (A): double (nullable = true)
 |--  Net Value Per Share (B): double (nullable = true)
 |--  Net Value Per Share (A): double (nullable = true)
 |--  Net Value Per Share (C): double (nullable = true)
 |--  Persistent EPS in the Last Four Seasons: double (nullable = true)
 |--  Cash Flow Per Share: double (nullable = true)
 |--  Revenue Per Share (Yuan ??): double (nullable = true)
 |--  Operating Profit Per Share (Yuan ??): double (nullable = true)
 |--  Per Share Net profit before tax (Yuan ??): double (nullable = true)
 |--  Realized Sales Gross Profit Growth Rate: double (nullable = true)
 |--  Operating Profit Growth Rate: double (nullable = true)
 |--  After-tax Net Profit Growth Rate: double (nullable = true)
 |--  Regular Net Profit Growth Rate: double (nullable = true)
 |--  Continuous Net Profit Growth Rate: double (nullable = true)
 |--  Total Asset Growth Rate: double (nullable = true)
 |--  Net Value Growth Rate: double (nullable = true)
 |--  Total Asset Return Growth Rate Ratio: double (nullable = true)
 |--  Cash Reinvestment %: double (nullable = true)
 |--  Current Ratio: double (nullable = true)
 |--  Quick Ratio: double (nullable = true)
 |--  Interest Expense Ratio: double (nullable = true)
 |--  Total debt/Total net worth: double (nullable = true)
 |--  Debt ratio %: double (nullable = true)
 |--  Net worth/Assets: double (nullable = true)
 |--  Long-term fund suitability ratio (A): double (nullable = true)
 |--  Borrowing dependency: double (nullable = true)
 |--  Contingent liabilities/Net worth: double (nullable = true)
 |--  Operating profit/Paid-in capital: double (nullable = true)
 |--  Net profit before tax/Paid-in capital: double (nullable = true)
 |--  Inventory and accounts receivable/Net value: double (nullable = true)
 |--  Total Asset Turnover: double (nullable = true)
 |--  Accounts Receivable Turnover: double (nullable = true)
 |--  Average Collection Days: double (nullable = true)
 |--  Inventory Turnover Rate (times): double (nullable = true)
 |--  Fixed Assets Turnover Frequency: double (nullable = true)
 |--  Net Worth Turnover Rate (times): double (nullable = true)
 |--  Revenue per person: double (nullable = true)
 |--  Operating profit per person: double (nullable = true)
 |--  Allocation rate per person: double (nullable = true)
 |--  Working Capital to Total Assets: double (nullable = true)
 |--  Quick Assets/Total Assets: double (nullable = true)
 |--  Current Assets/Total Assets: double (nullable = true)
 |--  Cash/Total Assets: double (nullable = true)
 |--  Quick Assets/Current Liability: double (nullable = true)
 |--  Cash/Current Liability: double (nullable = true)
 |--  Current Liability to Assets: double (nullable = true)
 |--  Operating Funds to Liability: double (nullable = true)
 |--  Inventory/Working Capital: double (nullable = true)
 |--  Inventory/Current Liability: double (nullable = true)
 |--  Current Liabilities/Liability: double (nullable = true)
 |--  Working Capital/Equity: double (nullable = true)
 |--  Current Liabilities/Equity: double (nullable = true)
 |--  Long-term Liability to Current Assets: double (nullable = true)
 |--  Retained Earnings to Total Assets: double (nullable = true)
 |--  Total income/Total expense: double (nullable = true)
 |--  Total expense/Assets: double (nullable = true)
 |--  Current Asset Turnover Rate: double (nullable = true)
 |--  Quick Asset Turnover Rate: double (nullable = true)
 |--  Working capitcal Turnover Rate: double (nullable = true)
 |--  Cash Turnover Rate: double (nullable = true)
 |--  Cash Flow to Sales: double (nullable = true)
 |--  Fixed Assets to Assets: double (nullable = true)
 |--  Current Liability to Liability: double (nullable = true)
 |--  Current Liability to Equity: double (nullable = true)
 |--  Equity to Long-term Liability: double (nullable = true)
 |--  Cash Flow to Total Assets: double (nullable = true)
 |--  Cash Flow to Liability: double (nullable = true)
 |--  CFO to Assets: double (nullable = true)
 |--  Cash Flow to Equity: double (nullable = true)
 |--  Current Liability to Current Assets: double (nullable = true)
 |--  Liability-Assets Flag: double (nullable = true)
 |--  Net Income to Total Assets: double (nullable = true)
 |--  Total assets to GNP price: double (nullable = true)
 |--  No-credit Interval: double (nullable = true)
 |--  Gross Profit to Sales: double (nullable = true)
 |--  Net Income to Stockholder's Equity: double (nullable = true)
 |--  Liability to Equity: double (nullable = true)
 |--  Degree of Financial Leverage (DFL): double (nullable = true)
 |--  Interest Coverage Ratio (Interest expense to EBIT): double (nullable = true)
 |--  Net Income Flag: double (nullable = true)
 |--  Equity to Liability: double (nullable = true)
display(df)

(Screenshot: display(df) output)

Splitting the dataset into training and test sets

train, test = df.randomSplit([0.85, 0.15], seed=1)

Adding a featurizer to convert the features into a vector

from pyspark.ml.feature import VectorAssembler

feature_cols = df.columns[1:]
featurizer = VectorAssembler(inputCols=feature_cols, outputCol="features")
train_data = featurizer.transform(train)["Bankrupt?", "features"]
test_data = featurizer.transform(test)["Bankrupt?", "features"]

Checking whether the data is imbalanced

We can see that it is imbalanced.

display(train_data.groupBy("Bankrupt?").count())

(Screenshot: row counts by Bankrupt?, showing the class imbalance)
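
For a quick numeric check (my own small sketch, not part of the original notebook), you can also print each class's share of the training data directly; this imbalance is why isUnbalance=True is passed to the classifier below.

# Count rows per class and show each class's share of the training data
rows = train_data.groupBy("Bankrupt?").count().collect()
total = sum(r["count"] for r in rows)
for r in rows:
    print(f"Bankrupt?={r['Bankrupt?']}: {r['count']} rows ({r['count'] / total:.1%})")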

Training the model

from synapse.ml.lightgbm import LightGBMClassifier

model = LightGBMClassifier(
    objective="binary", featuresCol="features", labelCol="Bankrupt?", isUnbalance=True
)
model = model.fit(train_data)

"saveNativeModel"は、Sparkでトレーニングした後に高速な展開のために、lightGBMモデルを抽出することができます。

from synapse.ml.lightgbm import LightGBMClassificationModel

if running_on_synapse():
    model.saveNativeModel("/models/lgbmclassifier.model")
    model = LightGBMClassificationModel.loadNativeModelFromFile(
        "/models/lgbmclassifier.model"
    )
elif running_on_synapse_internal():
    model.saveNativeModel("Files/models/lgbmclassifier.model")
    model = LightGBMClassificationModel.loadNativeModelFromFile(
        "Files/models/lgbmclassifier.model"
    )
else:
    model.saveNativeModel("/tmp/lgbmclassifier.model")
    model = LightGBMClassificationModel.loadNativeModelFromFile(
        "/tmp/lgbmclassifier.model"
    )

On Databricks, the model is saved to DBFS.

%sh
ls /dbfs/tmp/lgbmclassifier.model
_committed_1540161204692494717
_committed_4581164859921369446
_committed_921647600985059100
_committed_vacuum977166471950681237
_started_4581164859921369446
part-00000-tid-4581164859921369446-fdc9085b-c3fb-423f-9d7f-b21967e22321-23-1-c000.txt
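
The part-...txt file above is the model in LightGBM's native text format, so (as a hypothetical sketch, assuming the lightgbm Python package is available on the driver) it can also be loaded for single-node scoring through the /dbfs FUSE mount:

import glob

import lightgbm as lgb

# Hypothetical: locate the saved native-model text file on the DBFS FUSE mount and load it as a Booster
local_path = glob.glob("/dbfs/tmp/lgbmclassifier.model/part-*.txt")[0]
booster = lgb.Booster(model_file=local_path)
print(booster.num_trees())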

Visualizing feature importances

import pandas as pd
import matplotlib.pyplot as plt

# Get the feature importances
feature_importances = model.getFeatureImportances()
fi = pd.Series(feature_importances, index=feature_cols)
fi = fi.sort_values(ascending=True)
f_index = fi.index
f_values = fi.values

# Print the feature importances
print("f_index:", f_index)
print("f_values:", f_values)

# Plot
x_index = list(range(len(fi)))
x_index = [x / len(fi) for x in x_index]
plt.rcParams["figure.figsize"] = (20, 20)
plt.barh(
    x_index, f_values, height=0.028, align="center", color="tan", tick_label=f_index
)
plt.xlabel("Importance")
plt.ylabel("Feature")
plt.show()

(Plot: feature importance bar chart)
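
If you just want the most important features as a table rather than a plot, a small sketch reusing the fi Series from the cell above:

# Top 10 features by importance, largest first
print(fi.sort_values(ascending=False).head(10))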

Generating predictions with the model

predictions = model.transform(test_data)
predictions.limit(10).toPandas()

(Screenshot: first 10 prediction rows)

The following code ended up throwing an error for me. I wonder whether it is related to this.

from synapse.ml.train import ComputeModelStatistics

metrics = ComputeModelStatistics(
    evaluationMetric="classification",
    labelCol="Bankrupt?",
    scoredLabelsCol="prediction",
).transform(predictions)
display(metrics)
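
As a workaround (my own sketch, not part of the original notebook), the usual classification metrics can also be computed with the standard Spark ML evaluators on the same predictions DataFrame:

from pyspark.ml.evaluation import (
    BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator,
)

# Area under ROC from the raw scores, plus simple accuracy from the hard predictions
auc = BinaryClassificationEvaluator(
    labelCol="Bankrupt?", rawPredictionCol="rawPrediction", metricName="areaUnderROC"
).evaluate(predictions)
accuracy = MulticlassClassificationEvaluator(
    labelCol="Bankrupt?", predictionCol="prediction", metricName="accuracy"
).evaluate(predictions)
print(f"AUC: {auc:.4f}, accuracy: {accuracy:.4f}")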

Training a quantile regression model with LightGBMRegressor

This example shows how to build a regression model with LightGBM.

Loading the dataset

triazines = spark.read.format("libsvm").load(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/triazines.scale.svmlight"
)
# Print basic info
print("records read: " + str(triazines.count()))
print("Schema: ")
triazines.printSchema()
display(triazines.limit(10))
records read: 105
Schema: 
root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)

(Screenshot: first 10 rows of the triazines dataset)

Splitting the dataset into training and test sets

train, test = triazines.randomSplit([0.85, 0.15], seed=1)

Training the model with LightGBMRegressor

from synapse.ml.lightgbm import LightGBMRegressor

model = LightGBMRegressor(
    objective="quantile", alpha=0.2, learningRate=0.3, numLeaves=31
).fit(train)
print(model.getFeatureImportances())
[30.0, 9.0, 0.0, 0.0, 26.0, 4.0, 0.0, 8.0, 2.0, 0.0, 11.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 24.0, 19.0, 0.0, 20.0, 5.0, 0.0, 14.0, 0.0, 24.0, 9.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 10.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 11.0]
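
Because objective="quantile" with alpha=0.2 means this model estimates the 20th percentile of the label rather than its mean, a common pattern (my own sketch, not part of the original notebook) is to train one model per quantile and combine the predictions into a rough interval:

from synapse.ml.lightgbm import LightGBMRegressor

# Train one quantile model per alpha, each writing to its own prediction column
quantile_models = [
    LightGBMRegressor(
        objective="quantile",
        alpha=alpha,
        learningRate=0.3,
        numLeaves=31,
        predictionCol=f"pred_q{int(alpha * 100)}",
    ).fit(train)
    for alpha in (0.1, 0.5, 0.9)
]

scored = test
for m in quantile_models:
    scored = m.transform(scored)

# Each row now carries approximate 10th/50th/90th percentile predictions
display(scored.select("label", "pred_q10", "pred_q50", "pred_q90"))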

Generating predictions with the model

scoredData = model.transform(test)
display(scoredData)

(Screenshot: scored test data)

from synapse.ml.train import ComputeModelStatistics

metrics = ComputeModelStatistics(
    evaluationMetric="regression", labelCol="label", scoresCol="prediction"
).transform(scoredData)
display(metrics)

(Screenshot: regression metrics)

Training a ranking model with LightGBMRanker

Loading the dataset

df = spark.read.format("parquet").load(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/lightGBMRanker_train.parquet"
)
# Print basic info
print("records read: " + str(df.count()))
print("Schema: ")
df.printSchema()
display(df.limit(10))
records read: 3005
Schema: 
root
 |-- query: long (nullable = true)
 |-- labels: double (nullable = true)
 |-- features: vector (nullable = true)

(Screenshot: first 10 rows of the ranking training data)

Training the ranking model with LightGBMRanker

from synapse.ml.lightgbm import LightGBMRanker

features_col = "features"
query_col = "query"
label_col = "labels"
lgbm_ranker = LightGBMRanker(
    labelCol=label_col,
    featuresCol=features_col,
    groupCol=query_col,
    predictionCol="preds",
    leafPredictionCol="leafPreds",
    featuresShapCol="importances",
    repartitionByGroupingColumn=True,
    numLeaves=32,
    numIterations=200,
    evalAt=[1, 3, 5],
    metric="ndcg",
)
lgbm_ranker_model = lgbm_ranker.fit(df)

Generating predictions with the model

dt = spark.read.format("parquet").load(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/lightGBMRanker_test.parquet"
)
predictions = lgbm_ranker_model.transform(dt)
predictions.limit(10).toPandas()

(Screenshot: first 10 ranker prediction rows)
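
To see how the predicted scores order the items within each query group (a small sketch, not part of the original notebook), sort by the preds column per query and compare against the relevance labels:

from pyspark.sql import functions as F

# Highest-scored items first within each query
display(
    predictions.select("query", "labels", "preds")
    .orderBy("query", F.desc("preds"))
    .limit(30)
)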

はじめてのDatabricks (Getting Started with Databricks)

Databricks無料トライアル (Databricks Free Trial)
