LoginSignup
15
7

More than 1 year has passed since last update.

【streamlit】cacheとsession_stateを使ってウィジェットを管理しよう

Last updated at Posted at 2021-12-07

前回記事に引き続きstreamlit関連で、ウィジェットの状態管理についてです。

ウィジェットの値変更の度に再実行が入るstreamlit。直感的で非常に分かりやすいのですが私は変数の値の管理で何回もつまづいています。

そんな時、cacheとsession_stateを使ってだいたい解決出来るケースが多いと思いましたので、実際に自分が遭遇したユースケースのひとつを記事にします。

またまた限定的な使い道になってしまいそうですが、同じような所で妙な拘りが出てしまう同志も居るかもしれませんのであしからず。。。

まずはcacheの話

stereamlitと言えばcacheというぐらい色んな所で注目機能として扱われている気がしますが、軽く触れておきますとこのcacheは読んで字の如く関数にキャッシュ機能を持たせるものです。
(※cache:貯蔵庫、隠し場所)
 

下記のようにデコレータとしてst.cacheを付けた関数は、1回目の呼び出しでの戻り値が内部的に記憶され、2回目以降同様の関数を呼び出す際には関数実行せずに、記憶された値をそのまま戻り値として返すという処理が行われます。

cache_test.py
import streamlit as st
import pandas as pd

@st.cache
def read_file(file):
    df=pd.read_csv(file)
    return df

st.sidebar.title("cache_test")
upload_file=st.sidebar.file_uploader("upload")

if upload_file != None:
    upload_data=read_file(upload_file)
    st.dataframe(upload_data)

試しに下記のように読み込んだデータのアプリ内表示をcache関数内で行ってみます。

初めてファイルをアップロードした時はデータフレームがアプリ内に表示されると思いますが、テスト用に設置したreloadボタンを押すとデータフレームの表示が消えるかと思います(代わりにcacheに関する警告文が出てきます)。

cache_test.py
import streamlit as st
import pandas as pd

@st.cache
def read_file(file):
    df=pd.read_csv(file)
    st.dataframe(df)
    return df

st.sidebar.title("cache_test")
upload_file=st.sidebar.file_uploader("upload")

if upload_file != None:
    upload_data=read_file(upload_file)

st.button("reload")

1回目の関数実行で保存されたcache値にヒットした結果、2回目は関数実行がすっ飛ばされ、値の受け渡しだけが行われたというのが分かると思います。

このcacheを適切に用いることで無駄な読み込みを防ぎ、アプリの動作軽減などが出来るという寸法です。
 
関数内部も一部分だけ見てみましょう。

caching.py
try:
    return_value = _read_from_cache(
                         mem_cache=mem_cache,
                         key=value_key,
                         persist=persist,
                         allow_output_mutation=allow_output_mutation,
                         func_or_code=func,
                         hash_funcs=hash_funcs,
    )
    _LOGGER.debug("Cache hit: %s", func)

except CacheKeyNotFoundError:
    _LOGGER.debug("Cache miss: %s", func)

    with _calling_cached_function(func):
        if suppress_st_warning:
            with suppress_cached_st_function_warning():
                return_value = func(*args, **kwargs)
        else:
            return_value = func(*args, **kwargs)

全部理解は出来ていないのですが、キャッシュの読み込みを行ってみて、キャッシュミスが起きなければキャッシュ内の値を関数の戻り値として格納、キャッシュミスが起きればそのまま関数実行という流れになっているのが見て取れます。

 
実際のソースコードではこのあとに_write_to_cacheなる関数でキャッシュ値やキーの保存を行っているようです。

本題

ここから先ほどのcacheの説明を踏まえた実際のユースケースです。

前述の通り、streamlitでは何らかのウィジェットの値が変更される度にアプリが再実行されるため、例えばボタンウィジェットの押下を契機にセレクトボックスが出現するような動作を入れようとすると、セレクトボックスの値を変更する度にボックスが消えてしまう問題は巷でよく聞かれると思います。

↓こんなやつです

cache.py
import streamlit as st

st.sidebar.title("cache_test")

button_return = st.button("this is button")
if button_return:
    st.selectbox("this is select", ["test1","test2","test3"])

これに対して、例えばこちらのstackoverflowを参考にcacheを用いて押下状態を保持するボタンを作成したとします。

確かにボタンの押下状態は保持されるようになりました。

これで万事解決めでたしめでたしと行きたいのですが、自分の場合は下記のようにボタン押下前に何かを設定するセレクトボックスを設置したいという欲にかられる場面が出てきました。

画面上にも記載しましたが、ボタンを一度押下したら押下状態を保持するようにはしたいけど、特定のタイミングでリセットしたいという状況も稀にあるかと思います。

今回はそんな状況のための記事です。

◆解決法1 - st.sesssion_stateを用いる

前回記事で記載したst.session_stateです。
さっさくコードを下記します。

cache.py
import streamlit as st

@st.cache(allow_output_mutation=True)
def button_states():
    return {"pressed": None}

def session_change():
    if "is_pressed" in st.session_state:
        st.session_state["is_pressed"].update({"pressed": None})

select_button = st.selectbox("解析設定",["設定1","設定2","設定3"],
                             on_change=session_change)
st.markdown("##### ↑これの値が変わったら、再度ボタンはOFFにしたくありませんか??")
press_button = st.button("解析開始")
st.session_state["is_pressed"] = button_states()

if press_button:
    st.session_state["is_pressed"].update({"pressed": True})

if st.session_state["is_pressed"]["pressed"]:
    th = st.number_input("解析開始ボタン押下後に出てくるウィジェット")

◆解決法2 - cache_clearを使う

cache関数の中身を見ていたらcache_clearという如何にもなものが見つかりましたので、試してみたところ無事にクリア出来ました。
※モジュール名がlegacy_cachingとなっていたので最新のものは別にあるのかもしれません・・・?

cache.py
import streamlit as st

@st.cache(allow_output_mutation=True)
def button_states():
    return {"pressed": None}

def cache_clear():
    st.legacy_caching.caching.clear_cache()

select_button = st.selectbox("解析設定",["設定1","設定2","設定3"],
                             on_change=cache_clear)
st.markdown("##### 選択変えたらボタンOFFになります")
press_button = st.button("解析開始")
is_pressed = button_states()

if press_button:
    is_pressed.update({"pressed": True})

if is_pressed["pressed"]:
    th = st.number_input("解析開始ボタン押下後に出てくるウィジェット")

このcache_clearを使うと全てのキャッシュがクリアされてしまうので、特定のウィジェットだけに作用させたい場合などはまたやり方を考えなければいけませんね。

バージョン情報

・streamlit 2.0

15
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
7