0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Snowflake ML でSHAPを可視化しよう

Posted at

この記事に書いてあること

  • Snowflake ML Pythonのライブラリを使ってLightGBMのモデルを作成
  • 作成したモデルを使って推論
  • 推論時の説明可能性をSHAP値から可視化

はじめに

SnowflakeのML関連機能は日々進化しています。
2025/7/8に説明可能性を可視化する機能がリリースされたので、早速使ってみました。

実行環境

ノートブック

今回は全てSnowflake Notebook上で実行します。

  • ランタイム
    Snowflake Warehouse Runtime 2.0
  • クエリウェアハウス
    XSMALL
  • ノートブック ウェアハウス
    XSMALL

ライブラリ

  • snowflake-ml-python 1.9.0
  • pandas 2.2.3
  • LightGBM 4.6.0

今回はsnowflake-ml-pythonのLightGBMを使用しますが、通常のLightGBMもインストールしないと使用できないので、パッケージで指定しておきます。
package

やってみる

データ準備

ダミーデータを用意します。

TRANSACTIONDATE PRODUCT SALES LAG_1 LAG_2 LAG_3 LAG_7
SALESの日付 製品番号 販売数 1日前の販売数 2日前の販売数 3日前の販売数 7日前の販売数
2025-07-10 A1 12 3 4 2 5
2025-07-11 A2 5 3 1 3 10

今回はとある小売店の製品ごとの一日当たりの販売数を予測する機械学習モデルを作りたいと思います。
このデータを学習データと推論データに分けて、それぞれTRAIN_DATAPRED_DATAというテーブルに入れておきます。

学習

今回はSHAPを計算することが目的なので、お作法的な部分はすっ飛ばして学習させます。

from snowflake.snowpark.context import get_active_session
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col
from snowflake.ml.modeling.lightgbm import LGBMRegressor
from snowflake.ml.registry import Registry
import pandas as pd

session = get_active_session()

# データの取得
train_df = (
    session.table("demo.public.train_data")
    .select(col('LAG_1'), col('LAG_2'), col('LAG_3'), col('LAG_7'),col('SALES'))
)# -> Snowpark DataFrame

# 特徴量列を指定
feature_cols= ['LAG_1','LAG_2','LAG_3','LAG_7']

# モデルのトレーニング
# input_col/label_colsをインスタンス作成時に指定する必要があります
model = LGBMRegressor(input_col=feature_cols, label_cols='SALES')
model.fit(train_df)

# モデルレジストリに保存
registry = Registry(session=session, database_name="DEMO", schema_name="PUBLIC")
registry.log_model(model=model, model_name="practice_model",version_name='v1')

以下のような出力が得られればモデルレジストリにモデルを格納できています。

//出力
ModelVersion(
  name='PRACTICE_MODEL',
  version='V2',
)

学習したモデルを確認

モデルレジストリの画面を開くと、登録したモデルが一覧表示されます。
model registory
今作成したのはv1なので、v1をクリックします。
下の方にスクロールしていくと、Functionsの欄にEXPLAINという記載があります。
今回使用したいのはこちらのメソッドです。

mv_function

SHAPを計算して可視化

EXPLAINを使ってSHAPを計算します。

from snowflake.ml.monitoring import explain_visualize
from snowflake.snowpark.functions import avg

# 登録済みモデルを取得
reg = Registry(session=session, database_name='DEMO', schema_name='PUBLIC')
mv = reg.get_model('PRACTICE_MODEL').version('V1')

# 推論用データセットを取得
pred_tbl= (
    session.table("demo.public.pred_data")
    .orderBy(col('PRODUCT'),col('TRANSACTIONDATE'))
    .select("LAG_1", "LAG_2", "LAG_3", "LAG_7")
)
# 推論を実施
pred_result = mv.run(pred_tbl, function_name='predict')

# 推論値の平均値を算出
# SALESという列に対して推論を行った場合はOUTPUT_SALESという列に推論値が吐かれる
pred_avg = (
    pred_result
    .select(
        avg(col('output_sales'))
        .alias('VALUE')
        )
    ).collect()
base = pred_avg[0].asDict()['VALUE']

# SHAPを計算
feat = pred_tbl.to_pandas()
shap = mv.run(feat,function_name='explain')

# 描画
row_number = 15
explain_visualize.plot_force(
    shap.iloc[row_number], 
    feat.iloc[row_number],
    base_value=base,
    figsize=(900, 400)
)

実行すると以下のような図が出力されます。
plot_force

今回はbase_valueに推論値の平均値を指定しているので、個々の推論結果が平均値からどれだけ離れたかを測ります。
推論値を大きくする方に働いた特徴量は赤の矢印、小さくする方に働いた特徴量は青の矢印で表示されます。
row_numberを変えることで見たいレコードについて描画が可能です。

2025/7/12現在でplot_forceの他にplot_violinplot_influence_sensitivityが利用可能です。用途に合わせて使い分けてみてください。

公式ドキュメント

感想

モデルレジストリがリリースされたばかりの頃は正直使いにくいなと思っていたのですが、最近は「おっ結構いいじゃん」的な印象になってきました。
皆さんもSnowflakeで良いMLライフをお送りください!

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?