3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

SHAP for LSTM 時系列データの特徴量貢献度を定量評価

Last updated at Posted at 2023-06-16

前置き

編集中です。
TensorFlow 2系のLSTM等時系列回帰におけるSHAP解析の例が中々ないので(公式もTensorFlow 1系、かつ言語モデル)、サンプルコードを上げておきます。
SHAP部分だけ知りたい場合は、最後のコードブロックを見ればできると思います。
キモは、TensorFlow 1系の動作に指定して(tf.compat.v1.disable_v2_behavior())
SHAP用のモデルとしてhidden stateを入力に指定するLSTMネットワークを構築する事です。

解析データ準備と前処理

https://www.data.jma.go.jp/gmd/risk/obsdl/index.php
気象庁のデータベースから、北海道網走の天候データをダウンロードしました。
日本各地の観測値が提供されていて、計測パラメータも多くおもろいんでお遊びに最適です。

csvの列名に日本語を含みますので、csv読み込みの際にはencoding='SHIFT-JIS'を指定します。

#必要なデータを、LSTM入力のデータ形式に整形していきます
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os

import tensorflow as tf
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Input, Dense, LSTM, Reshape

# Path of the weather CSV that is analysed below.
file_path = 'D:/data_for_DNN/time_series/Hokkaido_abashiri.csv'

# The column labels contain Japanese text, so the file is decoded with
# encoding='SHIFT-JIS'.
raw_data = pd.read_csv(file_path, encoding='SHIFT-JIS')

# Keep the date column (column 0) aside, then remove it from the table so
# that only the measured quantities remain.
time = raw_data.iloc[:, 0]
raw_data = raw_data.drop(columns=raw_data.columns[[0]])
print('data shape =', raw_data.shape)

# Names of the remaining columns, in file order.
feature_names = raw_data.columns.tolist()
print('feature name =', feature_names)
output
data shape = (361, 10)
feature name = ['降水量の合計(mm)', '日照時間(時間)', '平均風速(m/s)', '平均蒸気圧(hPa)', '平均湿度(%)', '平均現地気圧(hPa)', '平均海面気圧(hPa)', '最低気温(℃)', '最高気温(℃)', '平均気温(℃)']

一応データを見ておきます。

# Display summary statistics for every column as a quick sanity check of the data.
raw_data.describe()

df_describe.png

LSTMで扱えるよう、正規化とデータ整形をします。

# The target variable y is column 0 of the table (total precipitation in mm);
# its raw, un-normalised values are stored in raw_all_y.
print('target =', feature_names[0])
raw_all_y = raw_data.iloc[:, 0]

# Per-column minimum / maximum used for min-max normalisation. They are kept
# as (1, n_columns) row vectors so that they broadcast against the
# (n_rows, n_columns) value array and can be reused to undo the scaling.
state = raw_data.describe()
scalar_min = state.loc['min'].values.reshape(1, -1)
scalar_max = state.loc['max'].values.reshape(1, -1)

# Min-max normalisation of every column.
value_data = raw_data.values
value_range = scalar_max - scalar_min
norm_data = (value_data - scalar_min) / value_range

# Split the normalised table: column 0 is the target (kept as a column
# vector), every other column is an explanatory variable.
norm_all_y = norm_data[:, :1]
norm_all_x = norm_data[:, 1:]

# Number of explanatory variables.
n_feature = norm_all_x.shape[1]
print('number of explainer =', n_feature)

# Chronological train / test split: the leading `split_rate` fraction of the
# rows is used for training, the remainder for testing.
split_rate = 0.8
n_train_rows = int(split_rate*len(norm_data))

norm_train = norm_data[:n_train_rows, :]
norm_test = norm_data[n_train_rows:, :]

print('train data shape =', norm_train.shape)
print('test data shape =', norm_test.shape)

# Reshape the training rows for the LSTM: column 0 is the target, the other
# columns are the inputs.
norm_train_x = norm_train[:, 1:]
norm_train_y = norm_train[:, 0]

# Slide a window of `window` consecutive rows over the training data; each
# sample pairs the inputs of a window with the targets of the same rows.
window = 10
n_windows = norm_train_x.shape[0] - window
norm_set_x = np.array([norm_train_x[start:start+window] for start in range(n_windows)])
norm_set_y = np.array([norm_train_y[start:start+window] for start in range(n_windows)])

print('train set x shape =', norm_set_x.shape)
print('train set y shape =', norm_set_y.shape)

モデルの定義、訓練

簡単のため、単純な1層LSTM+(出力次元調整用)Denseネットワークにします。
3つモデルを定義していますが、実際訓練するのはtrain_lstm_modelです。
state_lstm_model, shap_lstm_modelには、train_lstm_modelの訓練済み重みをload_weightsさせます。なので、3者で学習パラメータがある部分の構造は共通化してください。

#訓練するLSTMモデル, 後々に用いるreturn state入りLSTMモデル, SHAP用LSTMモデルを定義する

def train_lstm_model(units=10, input_dim=10):
    """Build and compile the LSTM regression model that is actually trained.

    The network is Input(None, input_dim) -> LSTM(units) returning the full
    output sequence -> Dense(1), so one value is predicted per time step.

    Args:
        units: Number of LSTM units.
        input_dim: Number of input features per time step.

    Returns:
        A Sequential model compiled with MSE loss and the Adam optimizer.
    """
    stack = [
        Input(shape=(None, input_dim)),
        LSTM(units, return_sequences=True),
        Dense(1),
    ]
    model = Sequential(stack)
    model.compile(loss='mse', optimizer='adam')
    return model

def state_lstm_model(units=10, input_dim=10):
    """Build an LSTM model that also exposes its final hidden and cell states.

    The layers that carry weights (LSTM followed by Dense(1)) mirror those of
    train_lstm_model; the LSTM here returns only its last output together
    with its states.

    Args:
        units: Number of LSTM units.
        input_dim: Number of input features per time step.

    Returns:
        An uncompiled Model whose outputs are [prediction, h, c].
    """
    sequence_in = Input(shape=(None, input_dim))
    lstm_results = LSTM(units, return_state=True)(sequence_in)
    last_output, hidden_state, cell_state = lstm_results
    prediction = Dense(1)(last_output)

    return Model(
        inputs=sequence_in,
        outputs=[prediction, hidden_state, cell_state],
    )

def shap_lstm_model(units=10, input_dim=10):
    """Build the LSTM model used for SHAP, taking features and states as one flat input.

    Each input row is a vector of length ``input_dim + 2 * units`` laid out as
    ``[features | h | c]``. The row is split along axis 1, the features are
    reshaped into a one-time-step sequence, and the LSTM is run with
    ``[h, c]`` as its initial state, followed by Dense(1).

    Args:
        units: Number of LSTM units (also the length of h and of c).
        input_dim: Number of input features.

    Returns:
        An uncompiled Model mapping the flat input to a single prediction.
    """
    shap_input_dim = input_dim + (units*2)

    # `shape` must be a tuple; `(shap_input_dim)` without the trailing comma
    # is just an int.
    x_in = Input(shape=(shap_input_dim,))
    x, h, c = tf.split(x_in, [input_dim, units, units], axis=1)
    x = Reshape((1, input_dim))(x)
    x = LSTM(units)(x, initial_state=[h, c])
    output = Dense(1)(x)

    shap_model = Model(inputs=x_in, outputs=output)

    return shap_model

train model を学習させます。
callbacksはtf.keras.callbacks.LearningRateSchedulerのみです。
実際のタスクではロスの収束具合をみるためにtf.keras.callbacks.TensorBoardも入れておいた方が良いと思います。

# Training hyperparameters.
batch_size = 4
epochs = 1000


def step_decay(epoch):
    """Return the learning rate to use for the given epoch index.

    The rate is a constant 0.001 for the first 300 epochs and is then scaled
    by ``1 - epoch / epochs``, reaching zero when ``epoch == epochs``.
    """
    if epoch < 300:
        return 0.001
    return 0.001 * (1-(epoch/epochs))
# Learning-rate schedule callback driven by step_decay.
lr_decay = tf.keras.callbacks.LearningRateScheduler(step_decay)

# Build the model that is trained. `units` is kept as a module-level value
# because the models that later load these weights must be built with the
# same LSTM size.
units = 10
train_model = train_lstm_model(units=units, input_dim=n_feature)

weight_path = 'D:/data_for_DNN/time_series/Hokaido_LSTM_weight.h5'

# Set train_flag to False to skip training and reuse previously saved weights.
train_flag = True
if train_flag:
    hist = train_model.fit(
        norm_set_x, norm_set_y,
        validation_split=0.1,
        batch_size=batch_size,
        epochs=epochs,
        verbose=2,
        callbacks=[
            lr_decay
        ]
    )
    # Save both the full model and the weights alone; the weights file is the
    # one loaded by the other model variants.
    train_model.save('D:/data_for_DNN/time_series/Hokaido_LSTM.h5')
    train_model.save_weights(weight_path)
else:
    train_model.load_weights(weight_path)

予測値の確認

# Feed the whole normalised series to the trained model as a single sequence
# of shape (1, n_rows, n_feature).
pred_input = norm_all_x.reshape(1, norm_all_x.shape[0], n_feature)
print('prediction input data shape =', pred_input.shape)

# Per-time-step predictions for that single sequence: shape (n_rows, 1).
pred_output = train_model.predict(pred_input)[0]

# Undo the min-max normalisation using the target column (column 0) only.
# Multiplying by the full (n_columns,) min/max vectors would broadcast the
# predictions against every column's scale.
target_min = scalar_min[:, :1]
target_max = scalar_max[:, :1]
pred_output = pred_output*(target_max - target_min) + target_min
print('prediction output data shape =', pred_output.shape)

pred_output = pred_output[:, 0].reshape(-1,)
print(pred_output.shape)

# Integer sample positions 0 .. n_rows-1 for the x axis.
index = np.arange(len(pred_output))
print(index.shape)

# Plot the predictions against the true target values.
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(1, 1, 1)
ax.plot(index, raw_all_y, label='true')
ax.plot(index, pred_output, label='prediction', alpha=0.7)

ax.legend(loc='upper right')
plt.show()
plt.clf()
plt.close()

pred.png

SHAP入力用データを作成

# Build the state-returning model (defined above as state_lstm_model) and
# load the trained weights into it. `units` must be the LSTM size that was
# used for training.
state_model = state_lstm_model(units=units, input_dim=n_feature)
state_model.load_weights('D:/data_for_DNN/time_series/Hokaido_LSTM_weight.h5')

print(norm_all_x.shape)

# Treat every row as its own one-time-step sequence and evaluate all of them
# in a single batched predict call instead of one call per row.
# NOTE(review): each row is evaluated independently, so h and c are the
# states after a single step from the default initial state, not states
# carried over from the preceding rows - confirm this is the intended input.
state_pred_input = norm_all_x.reshape(-1, 1, n_feature)
xl, hl, cl = state_model.predict(state_pred_input, verbose=0)
print(xl.shape, hl.shape, cl.shape)

xll = xl.reshape(-1, 1)
hll = hl.reshape(-1, units)
cll = cl.reshape(-1, units)

# One row per time step, laid out as [explanatory variables | h | c].
results = np.concatenate((norm_all_x, hll, cll), axis=1)

# Column names in the same layout: the explanatory variables (the target,
# feature_names[0], is excluded) followed by h0.. and c0..
input_feature_names = feature_names[1:]
input_feature_names += ['h' + str(i) for i in range(hll.shape[1])]
input_feature_names += ['c' + str(i) for i in range(cll.shape[1])]
print('feature names include hidden state =', input_feature_names)
print('results shape =', results.shape, len(input_feature_names))

# Persist the table so that the SHAP step can read it back from disk.
pd.DataFrame(results, columns=input_feature_names).to_csv(
    'D:/data_for_DNN/time_series/Hokkaido_abashiri_results.csv',
    index=False,
    encoding='SHIFT-JIS'
)

SHAP適用

import shap
import numpy as np
import random
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import Dense, LSTM, Input, Reshape

# Switch TensorFlow to 1.x behaviour before the SHAP model is built.
# NOTE(review): TensorFlow documents this call as something to do at program
# start, before any tensors or models exist - run this section in a fresh
# process rather than after the training code above.
tf.compat.v1.disable_v2_behavior()

# Table of [explanatory variables | h | c] rows written by the previous step.
state_shap_input = pd.read_csv('D:/data_for_DNN/time_series/Hokkaido_abashiri_results.csv', encoding='SHIFT-JIS')

state_feature_names = state_shap_input.columns.values.tolist()
# h and c each contribute `units` columns, so subtract 2 * units to get the
# number of explanatory variables.
n_feature = len(state_feature_names) - (units*2)
print('number of features =', n_feature)

row_shap_input = len(state_shap_input)
print('data length (time direction) =', row_shap_input)

# Random subset of rows used as the SHAP background data. The size is capped
# at the number of available rows, because sampling without replacement
# cannot return more rows than exist. It is converted to a NumPy array so
# that both arrays handed to SHAP have the same type.
BG_shap_input_size = 100
state_BG_shap_input = state_shap_input.sample(
    n=min(BG_shap_input_size, row_shap_input), axis=0, replace=False
).values

state_shap_input = state_shap_input.values

# SHAP values are computed for every row.
state_SH_shap_input = state_shap_input[:, :]

print('shap back ground input data shape =', state_BG_shap_input.shape)
print('shap SHAP value calc. input data shape =', state_SH_shap_input.shape)

# Build the flat-input SHAP model (defined above as shap_lstm_model) and load
# the trained weights into it.
with_state_shap_model = shap_lstm_model(units=units, input_dim=n_feature)
with_state_shap_model.summary()
with_state_shap_model.load_weights('D:/data_for_DNN/time_series/Hokaido_LSTM_weight.h5')

explainer = shap.DeepExplainer(
    with_state_shap_model, state_BG_shap_input
)

with_state_shap_values = explainer.shap_values(state_SH_shap_input)

# Summary plot for the first entry of the returned SHAP values.
shap.summary_plot(
    with_state_shap_values[0],
    features=state_SH_shap_input,
    feature_names=state_feature_names,
    max_display=50,
    show=True,
    plot_type=None
)

shap_sample.png

3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?