LoginSignup
2
2

PDFと対話しよう。 LangChainとStreamlitを使ったChatGPTアプリの作成

Last updated at Posted at 2023-07-31

概要

ここでは、ChatGPT APIを活用して、ChatGPTをはじめてとする大規模言語モデル(LLM)を利用したアプリケーションの開発を支援するのに多くの方が利用しているLangChainと、Webアプリを容易に作成・共有できるPythonベースのOSSフレームワークであるStreamlitを用いた、PDFと対話するアプリを作成します。

Screenshot from 2023-07-29 17-09-32.png

ML_Bearさんが非常によいテキストを公開されています。是非、こちらの記事を参考にChatGPTのアプリ作成を楽しんでください。(私のコードは、ML_Bearさんの記事を参考にさせていただきました。ありがとうございます。)

はじめに

2022年末にChatGPTがリリースされ、2023年3月にChatGPT APIの公開、その後続々といろいろなLLMがリリースされ、またそれらを容易に利用するためのLangChainやLlamaIndexなどのモジュールがリリースされるなど、手を付けるてもすぐに技術が更新されるので、どこから手をつけてよいのかわかりませんでした。
そんなとき、ML_Bearさんの記事に出会い、多くの方が試されているChatGPT APIによるアプリを作成すれば、次のステップにつながるかな、と思い学習しました。
私が興味あったのは、個々の文書(PDF等)の書かれている内容を理解をサポートするためのツールであり、ちょうど「PDFに質問しよう」がこれに該当しましたのでよいタイミングでした。 こちらの記事の最後に、”「返ってきた質問に対してさらに質問できないの?」と思われた方もいらっしゃるでしょう。(中略)(こちらの機能については後日追記するかもしれません)”とあり、公開される前に自分でつくり”答え合わせをしよう”と思ったのがきっかけです。
OpenAアカウント、 Streamlit, LangchainおよびQdrantの環境設定については、同じくML_Bearさんの記事をご参考ください。丁寧に書かれていますので、理解しやすいと思います。
前置きはこれぐらいにし、コードについて説明します。コード全体については、以下をご参考ください。

完全コード
# pip install pycryptodome, pymupdf
from glob import glob
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.callbacks import get_openai_callback
from langchain.memory import ConversationBufferMemory
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.callbacks.base import BaseCallbackHandler

from PyPDF2 import PdfReader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Qdrant
from langchain.chains import RetrievalQA
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import (
    SystemMessage,
    HumanMessage,
    AIMessage
)
from langchain.callbacks import get_openai_callback

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

import json

#QDRANT_PATH = "./local_qdrant"
COLLECTION_NAME = "my_collection_2"

import os
os.environ["OPENAI_API_KEY"] =  "sk-**************************************"
os.environ['QDRANT_CLOUD_ENDPOINT'] = "https://***********************************"
os.environ['QDRANT_CLOUD_API_KEY'] = "***********************************************"

def init_page():
    st.set_page_config(
        page_title="Ask My PDF(s)",
        page_icon="🤗"
    )
    st.sidebar.title("Nav")
    st.session_state.costs = []


def select_model():
    model = st.sidebar.radio("Choose a model:", ( "GPT-3.5-16k", "GPT-4", "GPT-3.5"))
    if model == "GPT-3.5":
        st.session_state.model_name = "gpt-3.5-turbo"
    elif model == "GPT-3.5-16k":
        st.session_state.model_name = "gpt-3.5-turbo-16k"
    else:
        st.session_state.model_name = "gpt-4"
    
    # 300: 本文以外の指示のトークン数 (以下同じ)
    st.session_state.max_token = OpenAI.modelname_to_contextsize(st.session_state.model_name) - 300
    return ChatOpenAI(temperature=0, model_name=st.session_state.model_name)

def init_messages():
    clear_button = st.sidebar.button("Clear Conversation", key="clear")
    if clear_button or "messages" not in st.session_state:
        st.session_state.messages = [
            SystemMessage(content="You are a helpful assistant.")
        ]
        st.session_state.costs = []

        # Init memory
        st.session_state.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        )

def get_pdf_text(uploaded_file):
    if uploaded_file:
        pdf_reader = PdfReader(uploaded_file)
        text = '\n\n'.join([page.extract_text() for page in pdf_reader.pages])
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            model_name="text-embedding-ada-002",
            # 適切な chunk size は質問対象のPDFによって変わるため調整が必要
            # 大きくしすぎると質問回答時に色々な箇所の情報を参照することができない
            # 逆に小さすぎると一つのchunkに十分なサイズの文脈が入らない
            chunk_size=500,
            chunk_overlap=0,
        )
        return text_splitter.split_text(text)
    else:
        return None


def load_qdrant():
    #client = QdrantClient(path=QDRANT_PATH)

    # 以前こう書いていたところ: client = QdrantClient(path=QDRANT_PATH)
    # url, api_key は Qdrant Cloud から取得する
    client = QdrantClient(
        url=os.environ['QDRANT_CLOUD_ENDPOINT'],
        api_key=os.environ['QDRANT_CLOUD_API_KEY']
    )

    # すべてのコレクション名を取得
    collections = client.get_collections().collections
    collection_names = [collection.name for collection in collections]

    # コレクションが存在しなければ作成
    if COLLECTION_NAME not in collection_names:
        # コレクションが存在しない場合、新しく作成します
        client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
        )
        print('collection created')

    return Qdrant(
        client=client,
        collection_name=COLLECTION_NAME, 
        embeddings=OpenAIEmbeddings()
    )


def build_vector_store(pdf_text):
    qdrant = load_qdrant()
    qdrant.add_texts(pdf_text)

    # 以下のようにもできる。この場合は毎回ベクトルDBが初期化される
    # LangChain の Document Loader を利用した場合は `from_documents` にする
    # Qdrant.from_texts(
    #     pdf_text,
    #     OpenAIEmbeddings(),
    #     path="./local_qdrant",
    #     collection_name="my_documents",
    # )


def build_qa_model(llm, memory):
    qdrant = load_qdrant()
    retriever = qdrant.as_retriever(
        # "mmr",  "similarity_score_threshold" などもある
        search_type="similarity",
        # 文書を何個取得するか (default: 4)
        search_kwargs={"k":4}
    )

    return ConversationalRetrievalChain.from_llm(
       llm, retriever=retriever, memory=memory, verbose=True
    )

    

def page_pdf_upload_and_build_vector_db():
    # PDFファイルをアップロードする欄を作ります。
    uploaded_file = st.sidebar.file_uploader(
            label='Upload your PDF here😇',
            type='pdf'
        )
    if not uploaded_file:
            st.info("PDFファイルをアップロードしてください")
            st.stop()

    container = st.container()
    with container:
        pdf_text = get_pdf_text(uploaded_file)
        if pdf_text:
            with st.spinner("Loading PDF ..."""):
                build_vector_store(pdf_text)


def ask(qa, query):
    with get_openai_callback() as cb:
        # query / result / source_documents
        answer = qa.run(query)

    return answer, cb.total_cost


def page_ask_my_pdf():
    #st.title("Ask My PDF(s)")

    llm = select_model()

    # 会話のコンテキストを管理するメモリを設定する
    memory = st.session_state.memory 

    qa_chain = build_qa_model(llm, memory)

    # ユーザーの入力を監視
    if user_input := st.chat_input("聞きたいことを入力してね!"):
        print(user_input)
        #print(st.session_state.messages)
        with st.spinner("ChatGPT is typing ..."""):
            #answer = qa_chain.run(user_input)
            answer, cost = ask(qa_chain, user_input)
            st.session_state.costs.append(cost)

        #memory.chat_memory.add_user_message(user_input)
        #memory.chat_memory.add_ai_message(answer)

        st.session_state.messages.append(HumanMessage(content=user_input))
        st.session_state.messages.append(AIMessage(content=answer))

        messages = st.session_state.get('messages', [])
        for message in messages:
            if isinstance(message, AIMessage):
                with st.chat_message('assistant'):
                    st.markdown(message.content)
            elif isinstance(message, HumanMessage):
                with st.chat_message('user'):
                    st.markdown(message.content)
            else:  # isinstance(message, SystemMessage):
                st.write(f"System message: {message.content}")

        #メモリの内容をTeaminalで確認します。
        print(memory)

# メッセージオブジェクトを辞書に変換する関数
def message_to_dict(message):
    return {
        'type': type(message).__name__,
        'content': message.content,
    }


def main():
    init_page()
    init_messages()

    st.title("PDFと対話しよう!")

    page_pdf_upload_and_build_vector_db()

    page_ask_my_pdf()

    costs = st.session_state.get('costs', [])
    st.sidebar.markdown("## Costs")
    st.sidebar.markdown(f"**Total cost: ${sum(costs):.5f}**")
    for cost in costs:
        st.sidebar.markdown(f"- ${cost:.5f}")


    messages_as_dicts = [message_to_dict(message) for message in st.session_state.messages]

    # ensure_ascii=Falseを設定
    messages_str = json.dumps(messages_as_dicts, ensure_ascii=False)

    st.download_button(
        label="Download messages",
        data=messages_str.encode(),  # bytesに変換
        file_name='messages.json',
        mime='application/json',
    )

if __name__ == '__main__':
    main()

ConversationRetrievalChainによる会話の記録

Chat GPTの対話は、ConvesationRetrievalChainの記憶機能を用います。そのため、ML_Bearさんの”PDFに質問しよう”のRetrievqalQAの部分を以下に変更します。

def build_qa_model(llm, memory):
    qdrant = load_qdrant()
    retriever = qdrant.as_retriever(
        # "mmr",  "similarity_score_threshold" などもある
        search_type="similarity",
        # 文書を何個取得するか (default: 4)
        search_kwargs={"k":4}
    )

    return ConversationalRetrievalChain.from_llm(
       llm, retriever=retriever, memory=memory, verbose=True
    )

多くの部分は同じですが、RetievalQAをConvesationRetrievalChainを用いるため更新しています。
ConversationRetrievalChainにはmemoryが引数であり、ここにこれまでの会話を記録します。
今回はmemoryを用いますので、初期化の関数にmemoryの初期化も追加します。

def init_messages():
    clear_button = st.sidebar.button("Clear Conversation", key="clear")
    if clear_button or "messages" not in st.session_state:
        st.session_state.messages = [
            SystemMessage(content="You are a helpful assistant.")
        ]
        st.session_state.costs = []

        # Init memory
        st.session_state.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        )

これでmemory機能で会話を記録する準備はできました。次に、実際に会話を記憶していきます。
memory機能についてはこちらの記事にあるとおり、memoryにHumanMessageとAIMessageのタイトルで辞書形式で保存されます。会話を実行する度に過去の会話履歴をすべてつかって会話を行っています。
今回は、ML_Bearさんの記事と同じくStreamlitのsessaion_stateを用いて会話を記録します。
会話の記録のため実行部分をいかに更新しました。

def page_ask_my_pdf():
    #st.title("Ask My PDF(s)")

    llm = select_model()

    # 会話のコンテキストを管理するメモリを設定する
    memory = st.session_state.memory 

    qa_chain = build_qa_model(llm, memory)

    # ユーザーの入力を監視
    if user_input := st.chat_input("聞きたいことを入力してね!"):
        with st.spinner("ChatGPT is typing ... """):
            answer, cost = ask(qa_chain, user_input)
            st.session_state.costs.append(cost)

        st.session_state.messages.append(HumanMessage(content=user_input))
        st.session_state.messages.append(AIMessage(content=answer))

        messages = st.session_state.get('messages', [])
        for message in messages:
            if isinstance(message, AIMessage):
                with st.chat_message('assistant'):
                    st.markdown(message.content)
            elif isinstance(message, HumanMessage):
                with st.chat_message('user'):
                    st.markdown(message.content)
            else:  # isinstance(message, SystemMessage):
                st.write(f"System message: {message.content}")

        #メモリの内容をTeaminalで確認します。
        print(memory)

最初にmemoryを宣言します。こちらは空のdicとなります。その後、ChatGPTとの対話での質問と回答をst.session_state.messages.appendにて追加記録します。これは”PDFに質問しよう”と同じですね。
最後にChatGPTとの対話がどのように記録されいるのか、ローカルで確認するために出力させます。確認用ですので、アプリとしては不要です。実行させるかどうかはお好みとしてください。
これで、任意のPDFファイルとの対話するChatGPTアプリの設定ができました。対話を記録したいときのための、簡単なモジュールを作りました。json形式で保存されるため、よろしければお使いください。

# メッセージオブジェクトを辞書に変換する関数
def message_to_dict(message):
    return {
        'type': type(message).__name__,
        'content': message.content,
    }

def main():
    init_page()
    init_messages()

    st.title("PDFと対話しよう!")

    page_pdf_upload_and_build_vector_db()

    page_ask_my_pdf()

    costs = st.session_state.get('costs', [])
    st.sidebar.markdown("## Costs")
    st.sidebar.markdown(f"**Total cost: ${sum(costs):.5f}**")
    for cost in costs:
        st.sidebar.markdown(f"- ${cost:.5f}")


    messages_as_dicts = [message_to_dict(message) for message in st.session_state.messages]

    # ensure_ascii=Falseを設定
    messages_str = json.dumps(messages_as_dicts, ensure_ascii=False)

    st.download_button(
        label="Download messages",
        data=messages_str.encode(),  # bytesに変換
        file_name='messages.json',
        mime='application/json',
    )

アプリのUIは、ML_Bearさんの”PDFに質問しよう”とは異なり1画面構成としました。こちらの方が、自分が何のPDFファイルと対話しているのかファイル名を確認することができます。こちらも好みですね。対話履歴を消去したいときは、”Clear Conversation”を実行してください。

最後に

LLMの技術進歩がとても速く、何から手をつけるのがよいのか迷っていましたが、ML_Bearさんの記事をみて、まずはベーシックにChatGPTをLangchainで動かすことから始めようと思い、こちらのアプリに挑戦しました。
今後は、利用者の質問力によらず知りたいことを自動で抽出するためにはどうしたらよいのか、LangChainのAgent機能が有効と思いますので、こちらを学んでいきます。アウトプットがまとまりましたら、記事にして公開する予定です。
こちらの記事がみなさまのお役に立てましたら非常に嬉しいです。

参考記事

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2