この記事に書いてあること
- Snowflake ML Pythonのライブラリを使ってLightGBMのモデルを作成
- 作成したモデルを使って推論
- 推論時の説明可能性をSHAP値から可視化
はじめに
SnowflakeのML関連機能は日々進化しています。
2025/7/8に説明可能性を可視化する機能がリリースされたので、早速使ってみました。
実行環境
ノートブック
今回は全てSnowflake Notebook上で実行します。
- ランタイム
Snowflake Warehouse Runtime 2.0 - クエリウェアハウス
XSMALL - ノートブック ウェアハウス
XSMALL
ライブラリ
- snowflake-ml-python 1.9.0
- pandas 2.2.3
- LightGBM 4.6.0
今回はsnowflake-ml-pythonのLightGBMを使用しますが、通常のLightGBMもインストールしないと使用できないので、パッケージで指定しておきます。
やってみる
データ準備
ダミーデータを用意します。
TRANSACTIONDATE | PRODUCT | SALES | LAG_1 | LAG_2 | LAG_3 | LAG_7 |
---|---|---|---|---|---|---|
SALESの日付 | 製品番号 | 販売数 | 1日前の販売数 | 2日前の販売数 | 3日前の販売数 | 7日前の販売数 |
2025-07-10 | A1 | 12 | 3 | 4 | 2 | 5 |
2025-07-11 | A2 | 5 | 3 | 1 | 3 | 10 |
今回はとある小売店の製品ごとの一日当たりの販売数を予測する機械学習モデルを作りたいと思います。
このデータを学習データと推論データに分けて、それぞれTRAIN_DATA
とPRED_DATA
というテーブルに入れておきます。
学習
今回はSHAPを計算することが目的なので、お作法的な部分はすっ飛ばして学習させます。
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
from snowflake.ml.modeling.lightgbm import LGBMRegressor
from snowflake.ml.registry import Registry
import pandas as pd
session = get_active_session()
# データの取得
train_df = (
session.table("demo.public.train_data")
.select(col('LAG_1'), col('LAG_2'), col('LAG_3'), col('LAG_7'),col('SALES'))
)# -> Snowpark DataFrame
# 特徴量列を指定
feature_cols= ['LAG_1','LAG_2','LAG_3','LAG_7']
# モデルのトレーニング
# input_col/label_colsをインスタンス作成時に指定する必要があります
model = LGBMRegressor(input_col=feature_cols, label_cols='SALES')
model.fit(train_df)
# モデルレジストリに保存
registry = Registry(session=session, database_name="DEMO", schema_name="PUBLIC")
registry.log_model(model=model, model_name="practice_model",version_name='v1')
以下のような出力が得られればモデルレジストリにモデルを格納できています。
//出力
ModelVersion(
name='PRACTICE_MODEL',
version='V2',
)
学習したモデルを確認
モデルレジストリの画面を開くと、登録したモデルが一覧表示されます。
今作成したのはv1なので、v1をクリックします。
下の方にスクロールしていくと、Functionsの欄にEXPLAIN
という記載があります。
今回使用したいのはこちらのメソッドです。
SHAPを計算して可視化
EXPLAIN
を使ってSHAPを計算します。
from snowflake.ml.monitoring import explain_visualize
from snowflake.snowpark.functions import avg
# 登録済みモデルを取得
reg = Registry(session=session, database_name='DEMO', schema_name='PUBLIC')
mv = reg.get_model('PRACTICE_MODEL').version('V1')
# 推論用データセットを取得
pred_tbl= (
session.table("demo.public.pred_data")
.orderBy(col('PRODUCT'),col('TRANSACTIONDATE'))
.select("LAG_1", "LAG_2", "LAG_3", "LAG_7")
)
# 推論を実施
pred_result = mv.run(pred_tbl, function_name='predict')
# 推論値の平均値を算出
# SALESという列に対して推論を行った場合はOUTPUT_SALESという列に推論値が吐かれる
pred_avg = (
pred_result
.select(
avg(col('output_sales'))
.alias('VALUE')
)
).collect()
base = pred_avg[0].asDict()['VALUE']
# SHAPを計算
feat = pred_tbl.to_pandas()
shap = mv.run(feat,function_name='explain')
# 描画
row_number = 15
explain_visualize.plot_force(
shap.iloc[row_number],
feat.iloc[row_number],
base_value=base,
figsize=(900, 400)
)
今回はbase_valueに推論値の平均値を指定しているので、個々の推論結果が平均値からどれだけ離れたかを測ります。
推論値を大きくする方に働いた特徴量は赤の矢印、小さくする方に働いた特徴量は青の矢印で表示されます。
row_number
を変えることで見たいレコードについて描画が可能です。
2025/7/12現在でplot_force
の他にplot_violin
とplot_influence_sensitivity
が利用可能です。用途に合わせて使い分けてみてください。
感想
モデルレジストリがリリースされたばかりの頃は正直使いにくいなと思っていたのですが、最近は「おっ結構いいじゃん」的な印象になってきました。
皆さんもSnowflakeで良いMLライフをお送りください!