LoginSignup
51
74

ベイズ最適化(実験点提案)アプリをStreamlitで構築するぜ!

Last updated at Posted at 2024-02-15

はじめに

Streamlitって気づいたら神アプデしてますよね。もっと大々的に宣伝してほしいものです(←自分で情報取りに行け)。

さて、化学メーカーに勤めている"自称"データサイエンティストとしてはやはりベイズ最適化したい衝動に駆られます。Notebook上では実装できていたのですが、もっと簡単に使いたいし、周囲に広めるためにもアプリの方が便利だなぁと思いました。

世の中にはそんなアプリがあるけど、ダウンロードが必要だったり(社内申請めんどくさい)、お金がかかったり・・・そうだ!自作しよう!

ということで、Streamlitでベイズ最適化による実験点提案アプリを自作しました。
ひとまず必要最低限の機能だけ実装したので、今後アップデートしていきます。

メインライブラリ
streamlit==1.30.0
scikit-learn==1.4.0

ガウス過程回帰、ベイズ最適化
・scikit-learn の GaussianProcessRegressorを使用しました。
・カーネル関数は11種類をクロスバリデーションで最適化しています。
・獲得関数はEIです(内部的にはPIやMIも選べるようにしていますが、ユーザーはそんなこと気にしないのでひとまずEI固定です)

実装した機能
・データの読み込み、書き出し
・動的散布図
・実験点提案
・ガウス過程回帰によるy-yプロット
・クロスバリデーションによるy-yプロット
・SHAP解析
・テストデータの読み込み・予測

ベイズ最適化とは

ベイズ最適化ってなんかとっつきにくいですよね。ガウス過程回帰だとか獲得関数だとか。
詳しい説明は各所で記事があるのでそちらに譲るとして、私は超簡単に図示してみました。
(ベイズ最適化だけで一つの記事になりそうだから気が向いたら・・・)

image.png

・「ベイズ最適化」は適応的実験計画法(実験計画と実験を繰り返すこと)の一つ
・ガウス過程回帰はベイズ最適化を行うための手段
・ガウス過程回帰は目的変数の推定値だけでなく、分散も計算できる
・カーネルトリックにより非線形モデルへ拡張(カーネル関数の使用)
・目的変数の推定値を最大化(最小化)したいが既存のサンプルに近いものが選ばれやすい(局所解に陥りやすい)
⇒ベイズ最適化では獲得関数が最大になる説明変数の値を選択
⇒獲得関数の計算にガウス過程回帰で得られた推定値の分散を利用

正直、中身は理解していなくても使えますが、ざっくり抑えとくと良いですね。
テキトーに実験するよりも効率的に最適解に辿り着きやすいよっていうことです。

アプリ画面

アプリを起動すると、タイトルとデータ読み込みウィジェットが現れます。
データはcsvに慣れていない人向けにExcelファイルで読み込むことにしました。

image.png

データ(適当に作った仮想データです)を読み込むと、Sheet1に記載された既存の実験データ、Sheet2に記載された探索範囲が表示されます。
また、機能ごとにタブで分けました。streamlitは普通に作成していくと、縦長になってしまい見栄えが悪いのでタブで分けるかページで分けるかするとスッキリしたUIになります。

(タブは最近化学メーカーからAI/DX関連の組織に転職したstreamlitマスターの後輩が教えてくれました、ありがとう(゚∀゚))

image.png

データの下部には散布図を作成できるスペースを作りました。
3次元+色で4次元プロットになります。ここでなんとなくデータ解析ができます。
探索されていない空間も把握できますね。

image.png

実験点の提案タブ

実験点の提案タブに移動すると、チェックボックスが出てきます。わざわざチェックボックスを置いたのは、データ読み込みと同時にこちらのタブで計算が始まらないように、ワンクッション置きたかったからです。
また、ボタンではなくチェックボックスにすることで、ウィジェット操作による影響で関数がリロード(streamlitの仕様)されてもチェック状態が保持されるので、session_stateに保持されているグラフやデータを読み込めば画面を維持することができるためです。
session_stateに関しては後述します。

image.png

チェックボックスをチェックすると、計算が開始されます(ガウス過程回帰モデルの構築、カーネル関数の最適化、実験点の探索)。上部にstreamlitのデフォルトで「RUNNING...」が表示されますが、気づかない人用にチェックボックス真下にスピナーを設置しました。

image.png

計算が終了するとまずはy-yプロットが表示されます。左側がトレーニングデータのy-yプロットになります。ガウス過程回帰は回帰の親玉みたいなもんなんで、訓練データはビッチリ対角線上に乗ります。
そのためモデルの精度としては、クロスバリデーションで正確に評価するのが良いです。右側にクロスバリデーションによるモデルの評価結果を表示させています。割といいモデルですね。

image.png

y-yプロットの下部に次の実験点を10個提案します。提案の仕方は、初めに1つ提案させて、その予測値を実測値としてデータに追加→再度モデルを構築しベイズ最適化の流れで行っています。
予測値の横に95%信頼区間の値を表示するようにしました。これを見ることで、どのあたりの空間を探索しているのかわかります。

image.png

また、streamlit==1.28.0からデータフレームのダウンロードボタンがデフォルトで備わるようになりました。これまではst.download_buttonで自作していましたが、便利になりましたね。

image.png

解析タブ

解析タブではSHAP解析を実装しています。例えば決定木ベースのモデルではジニ不純度で特徴量の重要度を算出しますが、ガウス過程回帰モデルでそのような手法は使えないのでSHAP解析により、特徴量の寄与度を算出して表示させています。
SHAP理論だけで一つの記事になりそうなので、ここではざっくり、ゲーム理論に基づいたShapley値という概念を用いて、個々の特徴量がモデルの予測にどのように貢献しているかを定量的に評価していると理解していただければオーケーです。

image.png

チェックボックスをチェックすると、計算が始まり、結果が表示されます。
「?」マークはカーソルを合わせると「モデルを構築してから」という注意書きが表示されます。
今後、その他可視化を追加してく予定です。

image.png

予測タブ

予測タブでは「僕が考えた最強の条件」の予測ができます。
streamlit==1.29.0からはデータアップロードによる画面のリロードが発生しなくなったので使いやすくなりました。

image.png

予測したいデータを読み込むと、「実験点の提案」で構築したガウス過程回帰モデルで回帰分析を行い、予想結果を表示します。
全然最強の条件じゃなかったですね(;'∀')

image.png

とりあえず実装した機能は以上です。次はコードに関して少し解説します。

コード

ベイズ最適化のコードは金子弘昌先生が公開しているgithubを参考に作成しています。その他のコード部分に関して以下で解説します。

def main():
    st.set_page_config(
        page_title='実験点提案@ベイズ最適化',
        layout='wide',
        page_icon = 'random'
        )
    st.set_option('deprecation.showPyplotGlobalUse', False)
    st.title('実験点提案@ベイズ最適化')

    uploaded_file = st.file_uploader('データを読み込んで下さい。', type=['xlsx'], key = 'train_data')
    if uploaded_file:
        data = pd.read_excel(uploaded_file, sheet_name=0)
        limit_data = pd.read_excel(uploaded_file, sheet_name=1, index_col=0)
        
        tabs_set(data, limit_data)

メイン関数部分です。ページの初期画面の設定をしています
page_icon = 'random'とすることで、ファビコンが毎回切り替わる遊び心を入れています。
st.file_uploaderkey設定しておかないと、後のテストデータのst.file_uploaderと競合してエラーが出ます
・データがアップロードされたら、タブに移動します

def tabs_set(data, limit_data):
    tab_titles = ["データの確認", "実験点の提案", "解析", "予測"]
    tabs = st.tabs(tab_titles)

    ss = st.session_state
    
    with tabs[0]:
        st.write('実験データの確認')
        st.write(data)
        st.write('limitデータの確認')
        st.write(limit_data)

        st.subheader('散布図')
        x_col = st.selectbox("X軸の列を選択", data.columns)
        y_col = st.selectbox("Y軸の列を選択", data.columns)
        z_col = st.selectbox("Z軸の列を選択", data.columns)
        color_col = st.selectbox("色の列を選択", data.columns)

        plot_scatter(data, x_col, y_col, z_col, color_col)

    with tabs[1]:
        if st.checkbox("実験点の提案"):
            if 'next_samples' not in ss:
                generation_sample = sample_generation(limit_data)
                BO_rsults = BO(data, generation_sample)
                next_samples = BO_rsults[0]
                BO_model = BO_rsults[1]
                autoscaled_x = BO_rsults[2]
                x = BO_rsults[3]
                
--------------------省略----------------------------------

                ss['next_samples'] = next_samples
                ss['BO_model'] = BO_model
                ss['autoscaled_x'] = autoscaled_x
                ss['x'] = x

--------------------省略----------------------------------

                col1, col2 = st.columns(2)
                
                with col1:
                        st.subheader('トレーニングデータの予測結果')
                        st.plotly_chart(y_estimated_y_plot)
                        st.write('r^2 for training data :', y_estimated_y_r2)
                        st.write('RMSE for training data :', y_estimated_y_mae)
                        st.write('MAE for training data :', y_estimated_y_rmse)
                
                with col2:
                        st.subheader('クロスバリデーションによる予測結果')
                        st.plotly_chart(y_estimated_y_in_cv_plot)
                        st.write('r^2 in cross-validation :', y_estimated_y_in_cv_r2)
                        st.write('RMSE in cross-validation :', y_estimated_y_in_cv_mae)
                        st.write('MAE in cross-validation :', y_estimated_y__in_cv_rmse)
                
                st.subheader('提案された実験点')
                st.write(next_samples)
            else:
                col1, col2 = st.columns(2)
                
                with col1:
                        st.subheader('トレーニングデータの予測結果')
                        st.plotly_chart(ss['y_estimated_y_plot'])

                with col2:
                        st.subheader('クロスバリデーションによる予測結果')
                        st.plotly_chart(ss['y_estimated_y_in_cv_plot'])

                st.subheader('提案された実験点')
                st.write(ss['next_samples'])
                
    with tabs[2]:
        if st.checkbox('SHAP解析', help='「実験点の提案」タブでモデルを構築してから'):
            if 'SHAP_plot' not in ss:

--------------------省略----------------------------------
                

tabの設定です。

  • st.tabs()でタブを作成します
  • st.session_stateをいちいち書くのがめんどくさいので、ssという変数に格納しておきます
  • それぞれのタブはwith tabs[]:で設定できます
  • tabs[0]ではアップロードされたデータの表示と散布図の作成をしています
  • tabs[1]ではガウス過程回帰モデルの構築とベイズ最適化を行います
    ①session_stateにデータが保持されていない場合(if 'next_samples' not in ss:):モデルを構築する関数を動かして、データを表示。データはsession_stateに保持しておく。
    ②session_stateにデータが保持されている場合(else:):保持されているデータを表示させます。つまり、後工程でウィジェット操作によるtabs_set()関数のリロードが発生してもモデルの計算が再び走ることなく、session_stateに保持していたデータを表示させるだけで画面の状態を維持します。
  • tabs[2]ではSHAP解析を行います。st.checkboxの引数にhelpを指定することで、チェックボックスの横にヒントを出すことができます
def SHAP_explain(BO_model, autoscaled_x, x):
    with st.spinner():
        explainer = shap.KernelExplainer(BO_model.predict, autoscaled_x)
        shap_values = explainer.shap_values(autoscaled_x)

        plt.figure()
        shap.summary_plot(shap_values, autoscaled_x, feature_names=x.columns)
        plt.savefig('shap_summary_plot.png')
    
    return 'shap_summary_plot.png'

SHAP解析の関数です。

  • st.spinner()を設置して計算してるよ感を出してます
  • SHAP解析自体はimport shapをして、st.spinner()以下二行で終わりです
  • 結果をst.pyplot(plt)で表示しようとしたらMatplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.というエラーを吐いたので、.pngで保存して、st.image()でstreamlit上に表示させています

今後の追加機能

  • 提案する実験点の数をユーザーが選べるようにする
  • 可視化バリエーション追加
  • 欠損値補完、エンコーディング
  • 化学構造対応
  • etc...

おわりに

ノートブック上では実装できていたので、アプリ化は思ったよりすんなりできました(本業の片手間で1週間くらい?)。
問題は周囲への共有方法と、実際に使ってくれるのか・・・?ということ。これはコミュ力とプレゼン力が試されますね(´Д`)

streamlitもさらに使いやすくなってきているので、ガシガシアプリ作成していきたいです。

この記事が少しでもみなさんのお役に立てれば幸いです。

それでは次の記事でお会いしましょう(=゚ω゚)ノ

51
74
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
51
74