LoginSignup
5
11

ローカルLLMを使った全部盛り(streaming, RAG, Streamlit, ...)の作り方

Last updated at Posted at 2024-05-19

 前回の投稿で一部しか紹介出来ておらず、詳細は別途記事にする予定がだいぶ遅くなってしまいました。今回はStreaming, RAG, Streamlitを使って、アップロードしたPDFファイルの内容に関する質問に、ストリーミング形式で回答してくれるチャットボットの作り方を紹介をさせてもらいます。
各個別の機能の詳細は既にQiitaを始め多くの記事で紹介されていますので、本記事ではちゃんと動かせることに焦点を絞り、全体像を掴んでもらえることが出来ればと考えてます。

1. 前提となる環境
・Python 3.9以上(ただしGPUで使用する場合はtorchの関係で3.11を推奨)。
・必要なパッケージ: Pythonのバージョンに合わせて下記を修正して下さい。

Python3.11又は3.12の場合
chromadb==0.5.0
langchain==0.1.20
langchain_community==0.0.38
langchain_text_splitters==0.0.1
llama-index==0.9.34
pypdf[crypto]
streamlit
streamlit-chat
sentence-transformers
#Embedding modelにLUKEモデルを使う場合
sentencepiece
#Embedding modelにBERTモデルを使う場合は下記2パッケージを追加
#fugashi
#ipadic

・Ollamaをインストール: Win/Macは公式サイトからアプリをインストール。linuxは次の通りです。

curl -fsSL https://ollama.com/install.sh | sh

 必要なPython系の環境は以上ですが、使用するローカルLLMはollamaフォルダ傘下に./modelsフォルダを作成し、そこにダウンロードしておきます。その後ollamaフォルダ直下に使用する各LLMモデル用のModelfile_XXXを作成し、下記のような設定をしておきます(一例です)。

FROM ./models/XXX-Q4_K_M.gguf
PARAMETER temperature 0.3
PARAMETER num_ctx 2048
PARAMETER num_predict -1
PARAMETER repeat_last_n -1
PARAMETER repeat_penalty 2.1
PARAMETER top_k 10
PARAMETER top_p 0.5
PARAMETER tfs_z 2.0
PARAMETER num_gpu 5 #CPUの場合は削除

詳細は下記リンクを参照して下さい。
 リンク:OllamaのModelfileのパラメータ 

2. 各機能のコード
 まずは各モジュールを読み込んで、Streamlitのタイトルを設定します。

main.py (1)
import os, tempfile
import json
from pathlib import Path
import streamlit as st
st.title("Your Page Title Here")

from typing import Any, List #コメント頂き追加しました
from llama_index import (
   download_loader,
   VectorStoreIndex,
   ServiceContext,
   StorageContext,
   SimpleDirectoryReader,
)
from llama_index.postprocessor import SentenceEmbeddingOptimizer
from llama_index.prompts.prompts import QuestionAnswerPrompt
from llama_index.readers import WikipediaReader, Document
from langchain_community.chat_models import ChatOllama
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.callbacks.base import BaseCallbackHandler
#from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import CharacterTextSplitter
import chromadb
from llama_index.vector_stores import ChromaVectorStore

 langchainやllamaindexはバージョンアップに伴いインポート元が頻繁に変更になるので、上記コードはあくまでも、上記パッケージバージョンでの前提です(langchainやllamaindexはバージョンを固定しておいた方が良いと思います)。
次は今回一番苦労したストリーミングの処理ですが、独自のコールバックハンドラーを用意します(後から知ったのですがLangchainにはStreamlitCallbackHandlerが用意されているので、こちらを使うのも良いかもしれません)。これをmain.py(4)の初期化時のself.llm設定で、ChatOllamaのパラメータとしてcallbacks=[stream_handler]で呼び出します。

main.py (2)
class StreamHandler(BaseCallbackHandler):
   def __init__(self, initial_text="wait for moment"):
       self.initial_text = initial_text
       self.text = initial_text
       self.flag = True
   def on_llm_start(self, *args: Any, **kwargs: Any):
       self.text = self.initial_text
       with st.chat_message("assistant"):
           self.container = st.empty()
       self.container.markdown(self.text+"  "+"  ")
   def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
       if self.flag == True:
           print("Stream start: ",datetime.now())
           self.flag = False
       self.text += token
       self.container.markdown(self.text)
   def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
       st.session_state.messages.append({
           "role": "assistant",
           "content": self.text
       })

次はPDFファイルの読み込みClassを作成します。

main.py (3)
class PDFReader:
    def __init__(self):
        self.pdf_reader = download_loader("PDFReader", custom_path="local_dir")()
    def load_data(self, file_name):
        return self.pdf_reader.load_data(file=Path(file_name))

 次がローカルLLM及び埋め込みLLMを設定し、pdfファイルを読み込んでベクトルインデックスを生成してChromaDBに保存するとともに、このインデックスデータを使ってLLMに推論させるRAGの処理を行う、主要なClassになります。self.QA_PROMPT_TMPLは、各モデルに応じて適切なプロンプトを設定して下さい。また、下記は埋め込みモデルをCPUで動作させる前提の記載なので、GPUを使用したい場合はtorchをインストールして、設定を変更して下さい。また、モデル選択対象のモデル名("XXX", "YYY", "ZZZ")は、ollama createで作成した際のモデル名と一致させて下さい。

main.py (4)
ollama_url = "http://localhost:11434"
class QAResponseGenerator:
   def __init__(self, selected_model, pdf_reader):
       stream_handler = StreamHandler()
       self.llm = ChatOllama(base_url=ollama_url, model=selected_model, streaming=True, callbacks=[stream_handler], verbose=True)
       self.pdf_reader = pdf_reader
       self.QA_PROMPT_TMPL ='### 指示:{query_str}\n ### 応答:' #一例です
       self.device = "cpu"
       self.embed_model = HuggingFaceEmbeddings(model_name="Your embedding model from HF", model_kwargs={"device": self.device})
       self.service_context = ServiceContext.from_defaults(llm=self.llm, embed_model=self.embed_model)

   def generate(self, question, file_name, uploaded_file):
       try:
           db2 = chromadb.PersistentClient(path="./chroma_db")
           chroma_collection = db2.get_collection(file_name)
           vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
           pdf_index = VectorStoreIndex.from_vector_store(
               vector_store,
               service_context=self.service_context,
           )
       except: #FileNotFoundError
           pdf_documents = self.pdf_reader.load_data(file_name)
           db = chromadb.PersistentClient(path="./chroma_db")
           chroma_collection = db.get_or_create_collection(
               name=uploaded_file.name,
               metadata={"hnsw:space": "cosine"}
           )
           vector_store = ChromaVectorStore(
               chroma_collection=chroma_collection,
           )
           storage_context = StorageContext.from_defaults(vector_store=vector_store)
           pdf_index = VectorStoreIndex.from_documents(
               pdf_documents, storage_context=storage_context, service_context=self.service_context
           )
           pdf_engine = pdf_index.as_query_engine(
           similarity_top_k=2,
           text_qa_template=QuestionAnswerPrompt(self.QA_PROMPT_TMPL),
           node_postprocessors=[SentenceEmbeddingOptimizer(embed_model=self.service_context.embed_model, threshold_cutoff=0.4)],
       )
           try:
               pdf_result = pdf_engine.query(question)
               pdf_result.get_formatted_sources(1000)
               doc_to_update = chroma_collection.get()
               return pdf_result.response
           except ValueError as e:
               print("PDF Error: ",e)

PDFファイルのアップロードと保存処理部分です。

main.py (5)
def save_uploaded_file(uploaded_file, save_dir):
    try:
        with open(os.path.join(save_dir, uploaded_file.name), "wb") as f:
            f.write(uploaded_file.getvalue())
        return True
    except Exception as e:
        st.error(f"Error: {e}")
        return False

def upload_pdf_file():
    uploaded_file = st.sidebar.file_uploader("upload file", type=["pdf", "txt"])
    print("uploaded_file",uploaded_file)
    if uploaded_file is not None:
        st.success(f"{uploaded_file.name} has been uploaded")
        return uploaded_file

最後はStremlitのmain処理部分になります。

main.py (6)
def main():
    pdf_reader = PDFReader()
    uploaded_file = st.sidebar.file_uploader("upload file", type=["pdf"])
    if uploaded_file is not None:
        st.sidebar.success(f"{uploaded_file.name} has been uploaded")
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            tmp_file.write(uploaded_file.read())
    selected_model = st.sidebar.selectbox("select model", ["XXX", "YYY", "ZZZ"])
    question = st.text_input("質問入力")
    response_generator = QAResponseGenerator(selected_model, pdf_reader)

    submit_question = st.button("質問")
    clear_chat = st.sidebar.button("履歴消去")
    st.session_state.last_updated = ""
    st.session_state.last_updated_json = []

    # save history
    if "chat_history" not in st.session_state:
        st.session_state["chat_history"] = []

    if clear_chat:
        st.session_state["chat_history"] = []
        st.session_state.last_updated = ""
        st.session_state.last_updated_json = []

    if submit_question:
        print("pushed question button!")
        if question:
            response = response_generator.generate(question, tmp_file.name, uploaded_file)
            st.session_state["chat_history"].append({"user": question})
            st.session_state["chat_history"].append({"assistant": response})
            st.session_state.last_updated += json.dumps(st.session_state["chat_history"],indent=2, ensure_ascii=False)
            with open("./history/chat_history.txt","w") as o:
                print(st.session_state.last_updated,sep="",file=o)
            with open("./history/chat_history.txt", "r") as f:
                history_str = f.read()
                history_json = json.loads(history_str)
                #print(history_json)
                with open("./history/chat_history.json", "w") as o:
                    json.dump(history_json, o, ensure_ascii=False)

if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        if "'tmp_file' referenced before assignment" in str(e) :
            st.success("ファイルがアップロードされてません。先にアップロードして下さい。")
        else:
            st.exception(f"An error occurred: {e}")

 以上でRAGを使いアップロードしたPDFファイルに関するQ&Aをストリーミングのチャット形式でやり取りすることが可能になります。

3. まとめと関連情報
 今回はGPU無しかApple SiliconレベルのGPUでお手軽にRAGを試すコードを紹介しました。はまりどころが多いので、一度動くコードを作成してから、それを土台にいろいろ手を加えて行くのが良いと思いました(ダメなら土台に戻れば良いので)。そういった観点でお役に立てば幸いです。
 当初は「全部盛り」なので、RAGに検索機能を付加したり、要約文や会話履歴をフィードバックする機能も掲載する予定でしたが、今回は結構長くなってしまうので端折らせてもらいました。
 最後に役に立つかもしれないと思われる関連情報を共有します。
① Streamlitは作成したWebサイトを無料で公開してくれる コミュニティクラウドサービス があります。無料なので今回のようなリソースを食い潰すサイトは頻繁にクラッシュしますが、軽めのサイトであれば簡単にインターネットに公開できます。
② langchainとllamaindexは、どちらも提供している機能が多くあり選択に迷うことが多いですが、どちらかにしか無い必須機能を持っている方に極力片寄した方が、後々のメンテナンスが楽だと思います。
③ ChromaDBを用いたテキストの類似度比較の手法には、上記コードにある"cosine"以外にも"l2"や"ip"が設定可能です。詳細は下記リンクを参照して下さい。
 リンク:ChromaDBの類似度距離関数の変更 

参考にさせてもらったサイトは下記の通りです。
 参考1: LangchainをStreamlit上で文字ストリーミングする方法 
 参考2: PDF Reader chatbot using langchain and open ai in 15 mins 
 参考3: 意外と簡単!GPTを使ってPDFとQ&Aできるアプリを作ってみた 

5
11
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
11