こんにちは。Qiita初投稿です。
はじめに
PythonコードだけでWebアプリを作れるStreamlitと
数行のコードで機械学習の前処理から推定までできるPyCaretというものがあります。
Streamlit上でPyCaretを動かす方法を調べると、ライブラリをイジる方法しかヒットせず、少しハードルが高かったのでライブラリをイジらずに動かす方法を試していこうと思います。
環境
Python 3.7
Streamlit 1.5.1
PyCaret 2.3.6
手順
クラス分類や回帰など問題ごとにライブラリが分かれているので、使用するものをimportします。
今回は回帰で試していきます。
from pycaret.regression import *
StreamlitとDataFrameも扱うので、そちらもimportしておきます。
import pandas as pd
import streamlit as st
データの取得
今回はボストンの住宅価格データセットを使います。
※現在、ボストン住宅価格データセットは倫理的な問題が指摘されておりカリフォルニア住宅価格データセットなど他データセットの使用が推奨されています。(追記:2022/10/27)
from pycaret.datasets import get_data
data = get_data("boston")
前処理
前処理はsetup関数
を使います。
デフォルトではデータの型推定が正しいか入力を求められますが、Streamlitでは入力を返せないので無効にします(html=False, silent=True)。
pipe = setup(data, target="medv", html=False, silent=True)
モデル比較
PyCaretでは学習に使うモデルを選択することができます。
それぞれのモデルの比較はcompare_models関数
で実行できます。
実行結果をpandasのDataFrameで取り出し、表示させます。
best = compare_models() # モデル比較
best_model_results = pull() # 比較結果の取得
st.write(best_model_results) # 比較結果の表示
デフォルトでは決定係数R2のスコアが良い順にソートされて出力されます。
モデル作成
モデル作成はcreate_model関数
を使用します。
モデル比較で表示されるモデル名(name)と、関数の引数で指定するモデル名(ID)が異なります。
(いつからからモデル比較結果のIndexにモデルIDが表示されるようになりました)
model = create_model("et") # Extra Trees Regressorを使用
複数のモデルでブレンドモデルやアンサンブルモデルを作成することもできます。
et = create_model("et") # Extra Trees Regressorを使用
lightgbm = create_model("lightgbm") # Light Gradient Boosting Machine
model = blend_models(estimator_list=[et, lightgbm])
ensembled_model = ensemble_model(model)
モデル作成は特にStreamlit向けに調整する事項はありません。
モデルの可視化
モデルの可視化はplot_model関数
を使用します。
可視化の内容は学習曲線や残差などいくつか用意されていますが、
リファレンスにもある通り、すべての可視化内容がStreamlitで表示できるようにはなっていません。
(対応していても表示に時間がかかるものもありました)
https://pycaret.readthedocs.io/en/latest/api/regression.html#pycaret.regression.plot_model
plot_model(model, plot="cooks", display_format="streamlit")
予測
予測はpredict_model関数
を使います。
特にデータを指定しなければ、setup実行時にテスト用に分割してあったデータセットの一部で予測を行います。
predictions = predict_model(model)
st.write(predictions)
予測結果はLabelというカラムに出力されます。
medv(住宅価格)を予測しましたが、何も考えずにモデルを作ったわりには、それなりの精度がでています。
st.cache
Streamlitの性質上、ユーザがUIを操作するたびにPyCaretの関数が再実行されることがあります。
@st.cacheを使って対策を取っておくと良いでしょう。
@st.cache(allow_output_mutation=True)
def create_model_cache(estimator):
return create_model(estimator)
さいごに
setupとplot_modelでそれぞれオプションを指定し、また必要に応じてst.cacheを活用すれば
Streamlit上でPyCaretを動かすことができます。
サンプルプログラムをGitHubに公開しているので、よければ試してみてください。