LoginSignup
4
3

Streamlit & LangChain & ConversationalRetrievalChain でメモリの内容を覚える

Last updated at Posted at 2023-05-15

LangChain & ConversationalRetrievalChain & langchain.memory を使うと、簡単に過去の会話を覚えてくれているチャットボット的なものを作れるのですが、Streamlitを使用した場合、submitを押すとメモリの中身を忘れて、想定していた動きになりません。

すぐに風化してしまいそうな気もしますが、これの対応方法がネット見ても見当たらなかったので書いておきます。

元のコード

CLI等で下記のように動かすと、問題なくchat_historyをつけて自動的に過去に送信したメッセージを付けてくれます(下記の例はFAISSに独自のデータを入れていますが、違うものを使っていれば、その辺りは適宜読み替えてください)。いやぁ簡単。

が、これをそのままStreamlitで実装すると、POSTを押すたびにmemoryが初期化されて思ったように動いてくれません。。

def chat_chain(query):
    embeddings = OpenAIEmbeddings()
    db = FAISS.load_local("faiss_index", embeddings)
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    qa = ConversationalRetrievalChain.from_llm(
        ChatOpenAI(temperature=0), db.as_retriever(), memory=memory
    )
    chat = qa({"question": query})
    return chat

def main():
    while True:
        query = input("質問を入力してください: ")
        chat = chat_chain(query)
        print(chat)

公式ドキュメント
https://python.langchain.com/en/latest/modules/chains/index_examples/chat_vector_db.html

こんな風にすると動く

@st.cache_resource(ver 1.18.1以上で使えます)がポイントぽいのは検索してすぐに分かったのですが、

@st.cache_resource
def chat_chain(query):
   ...

とすると、queryを含んだものをキャッシュするのでqueryが毎回違う今回の用途では動いてくれません。

下記のようにすれば思ったように動いてくれました。

@st.cache_resource
def save_memory()
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    return memory

def chat_chain(query):
    embeddings = OpenAIEmbeddings()
    db = FAISS.load_local("faiss_index", embeddings)

    # キャッシュされたメモリに保存する
    memory = save_memory()

    qa = ConversationalRetrievalChain.from_llm(
        ChatOpenAI(temperature=0), db.as_retriever(), memory=memory
    )
    chat = qa({"question": query})
    return chat

def main():
    query = st.text_input("質問を入力してください: ")
    submit_btn = st.button("送信")
    if submit_btn:
        chat = chat_chain(query)
        st.write(chat)

下記は試していませんが、上記の場合は複数ユーザで使用した場合も一律覚えてしまうので、実際にはユーザIDやセッションID単位などでキャッシュさせると良い気がします。

@st.cache_resource
def save_memory(userid)
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    return memory

def chat_chain(query):
    ....

    userid = 123

    # userid単位でキャッシュされたメモリに保存する
    memory = save_memory(userid)

    ....

st.session_stateを使う?

おそらく、st.session_stateを使ってhistoryをためていく方法もありそうなのですが、どうやれば良いかが分かりませんでした(挑戦はしてみたのですが、上記と同じ動きにならない)。上記の例で、具体的にどう書けば良いか分かる方いらっしゃいましたら、コメントください。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3