概要
Backlog の Wiki の内容を使って RAG (Retrieval-Augmented Generation) をしてみます。
使用言語は Python です。
主に使うパッケージは LangChain (0.2.x) と Streamlit です。
AI は OpenAI API を使いますが、OpenAI パッケージを直接使うのでなく、LangChain の機能を通して使います。
RAG とは、ChatGPT 等の生成 AI を使う時に、AI が学習していない情報 (社内資料など) を与えることで AI の知識を拡張する技術です。
参考: RAG (検索拡張生成) とは?
LangChain や OpenAI API は開発スピードが速い (頻繁に更新される) ので、この記事の内容もすぐに古くなると思います。
なるべく公式サイトへのリンクを貼るようにしたので、そちらから最新の情報を入手するようにしてください。
おおまかな手順
- Backlog API を使って Wiki の内容 (Markdown) を取得します。
- Wiki の内容をベクトル化します (VectorStore にします)。
- チャット機能を作ります。
- VectorStore を検索してから OpenAI API に問い合わせるようにします。
- 会話履歴を保存・参照するようにします。
- UI をつけて Web アプリにします。
準備するもの
- Backlog の API キー
- Backlog の「個人設定」で発行できます。
- OpenAI の API キー
- OpenAI の Dashboard で発行できます。
- 以下の環境 (括弧内は執筆時のバージョン)
- Windows 11
- Python (3.12.4)
- VSCode (1.90.2)
プロジェクトを作成する
適当なフォルダ (ここでは wiki-rag) に Python の仮想環境を作って入ります。
これ以降は仮想環境の中での作業になります。
wiki-rag> python -m venv .venv --upgrade-deps
wiki-rag> .venv\Scripts\activate
(.venv) wiki-rag> _
1. Wiki の内容を取得する
Backlog の ライブラリのページ を見ると Python 用のライブラリがあるのですが、この後の手順でインストールするパッケージと依存関係が合わなかったので、ライブラリを使わずに直に API を呼びます。
パッケージをインストールする
この手順で必要なパッケージをインストールします。
(.venv) wiki-rag> pip install python-dotenv requests tqdm
パッケージ名 | 簡単な説明 | 執筆時のバージョン |
---|---|---|
python-dotenv | .env ファイルを読み込んで環境変数にする。 | 1.0.1 |
requests | http リクエストをして結果を受け取る。 | 2.32.3 |
tqdm | コンソールに処理の進捗を表示する。 | 4.66.4 |
.env ファイルを作る
ソースコードと分離したい情報を .env ファイルに書いておきます。
BACKLOG_API_KEY=your-backlog-api-key # Backlog の API キー
BACKLOG_DOMAIN=your-company.backlog.com # Backlog のドメイン
BACKLOG_PROJ_ID=your-project-id # Backlog のプロジェクト ID
Wiki の内容を取得する関数を作る
まずコード全体を掲載して、そのあと各部の解説をします。
以下のコードを書いて retrieve_wiki()
を実行すると Wiki の内容を取得して wiki フォルダに保存します。
Wiki 全体ではなく、指定したページの子孫ページのみ取得するようにしています。
ファイル名が make_vectors_from_wiki.py となっているのは、この後の手順で VectorStore を作る処理を追加するからです。
コード全体
import os
import re
import shutil
from pathlib import Path
import requests
from dotenv import load_dotenv
from tqdm import tqdm
load_dotenv(override=True)
BACKLOG_API_KEY = os.getenv('BACKLOG_API_KEY') # Backlog の API キー
BACKLOG_DOMAIN = os.getenv('BACKLOG_DOMAIN') # Backlog のドメイン
BACKLOG_PROJ_ID = os.getenv('BACKLOG_PROJ_ID') # Backlog のプロジェクト ID
# 取得した Wiki コンテンツの保存先
wiki_dir = Path(__file__).parent / 'wiki'
# 機密情報をマスクするためのパターン
conf_masks = [
{'pattern': re.compile(r'[A-Z]:\\\S+'), 'replace': 'このパスは機密情報です。'},
{'pattern': re.compile(f'https?://{BACKLOG_DOMAIN}[^)\\s]*'), 'replace': ''},
{'pattern': re.compile(r'\([^()]*/alias/wiki/[^)]*\)'), 'replace': '()'},
]
def get_wiki_pages():
'''Backlog の Wiki ページ一覧を取得する
'''
url = f"https://{BACKLOG_DOMAIN}/api/v2/wikis"
params = {"apiKey": BACKLOG_API_KEY, "projectIdOrKey": BACKLOG_PROJ_ID}
response = requests.get(url, params=params)
return response.json()
def get_child_wiki_ids(parent_title):
'''指定した Wiki ページの子孫ページの ID を取得する
'''
ppath = parent_title + '/'
return [w['id'] for w in get_wiki_pages() if w['name'].startswith(ppath)]
def get_wiki_content(wiki_id):
'''指定した Wiki ページのコンテンツを取得する
'''
url = f"https://{BACKLOG_DOMAIN}/api/v2/wikis/{wiki_id}"
params = {"apiKey": BACKLOG_API_KEY}
response = requests.get(url, params=params)
content = response.json()['content']
# 機密情報をマスクする
for mask in conf_masks:
content = re.sub(mask['pattern'], mask['replace'], content)
return content
def retrieve_wiki(parent_title):
'''指定した Wiki ページの子孫ページのコンテンツを取得して保存する
'''
# wiki フォルダを空にする (削除して作り直す)
if wiki_dir.exists():
shutil.rmtree(wiki_dir)
wiki_dir.mkdir()
# wiki から取得したコンテンツを保存する
print('Wiki からコンテンツを取得しています...')
for wiki_id in tqdm(get_child_wiki_ids(parent_title)):
content = get_wiki_content(wiki_id)
with open(wiki_dir / f'{wiki_id}.md', 'w', encoding='utf-8') as f:
f.write(content)
コードの解説
上の方から少しずつ解説します。
load_dotenv(override=True)
BACKLOG_API_KEY = os.getenv('BACKLOG_API_KEY') # Backlog の API キー
BACKLOG_DOMAIN = os.getenv('BACKLOG_DOMAIN') # Backlog のドメイン
BACKLOG_PROJ_ID = os.getenv('BACKLOG_PROJ_ID') # Backlog のプロジェクト ID
.env ファイルの内容を環境変数に追加して、環境変数から必要な情報を Python の変数に読み込んでいます。
# 機密情報をマスクするためのパターン
conf_masks = [
{'pattern': re.compile(r'[A-Z]:\\\S+'), 'replace': 'このパスは機密情報です。'},
{'pattern': re.compile(f'https?://{BACKLOG_DOMAIN}[^)\\s]*'), 'replace': ''},
{'pattern': re.compile(r'\([^()]*/alias/wiki/[^)]*\)'), 'replace': '()'},
]
あとで Wiki の内容を読み込む時に機密情報をマスクするので、そのための正規表現パターンとそれを置き換える文字列を定義しています。
ここでは以下の情報をマスクしています。
- ローカルファイルのパス
- Backlog 内のページの URL
- Wiki 内のハイパーリンクのアドレス
OpenAI は API に送信された情報をモデルの学習に使いません (*1) が、念のため送りたくない情報は送らずに済む仕組みを作っています。
*1: 執筆時点でのヘルプセンターの情報
How your data is used to improve model performance | OpenAI Help Center
def get_wiki_pages():
'''Backlog の Wiki ページ一覧を取得する
'''
url = f"https://{BACKLOG_DOMAIN}/api/v2/wikis"
params = {"apiKey": BACKLOG_API_KEY, "projectIdOrKey": BACKLOG_PROJ_ID}
response = requests.get(url, params=params)
return response.json()
Backlog の Wiki ページ一覧を取得する API については こちら にリファレンスがあります。
JSON で返ってくるので .json()
して Python のオブジェクト (この場合はリスト) にして返しています。
def get_child_wiki_ids(parent_title):
'''指定した Wiki ページの子孫ページの ID を取得する
'''
ppath = parent_title + '/'
return [w['id'] for w in get_wiki_pages() if w['name'].startswith(ppath)]
Backlog の Wiki はページの名前が階層を表しているので、子孫ページを取得するにはページ名が 「<親ページ名>/xx」 となっているページを取得します。
先ほどの get_wiki_pages()
関数で取得したページ一覧から、リスト内包表記 で子孫ページの id を取り出しています。
def get_wiki_content(wiki_id):
'''指定した Wiki ページのコンテンツを取得する
'''
url = f"https://{BACKLOG_DOMAIN}/api/v2/wikis/{wiki_id}"
params = {"apiKey": BACKLOG_API_KEY}
response = requests.get(url, params=params)
content = response.json()['content']
# 機密情報をマスクする
for mask in conf_masks:
content = re.sub(mask['pattern'], mask['replace'], content)
return content
子孫ページの id を取り出せるようになったので、その id を使ってページのコンテンツを取得する関数を作っています。
ページのコンテンツを取得するのに こちらの API を使っています。
返ってきたデータの content
キーに Markdown のテキストが入っています。
これに上の方で定義した conf_masks
を適用して機密情報をマスクした状態にしています。
この記事では Wiki が Markdown 記法で書かれていることを想定しています。
def retrieve_wiki(parent_title):
'''指定した Wiki ページの子孫ページのコンテンツを取得して保存する
'''
# wiki フォルダを空にする (削除して作り直す)
if wiki_dir.exists():
shutil.rmtree(wiki_dir)
wiki_dir.mkdir()
# wiki から取得したコンテンツを保存する
print('Wiki からコンテンツを取得しています...')
for wiki_id in tqdm(get_child_wiki_ids(parent_title)):
content = get_wiki_content(wiki_id)
with open(wiki_dir / f'{wiki_id}.md', 'w', encoding='utf-8') as f:
f.write(content)
ここまでで作った関数を使って子孫ページのコンテンツを取得して、wiki_dir
フォルダに保存しています。
子孫ページの id のリストを tqdm
でラップすることで、for
ループの進捗を表示するようにしています。
親ページの内容も使いたければ、get_child_wiki_ids()
で取得したリストに親ページの id も追加して、そのリストを for
ループすれば良いです。
実行してみる
wiki-rag/make_vectors_from_wiki.py の最後に以下のコードを追加して実行してみます。
if __name__ == '__main__':
retrieve_wiki('親ページのタイトル')
実行するとプログレスバーで進捗が表示され、すべての子孫ページを取得します。
(.venv) wiki-rag> python .\make_vectors_from_wiki.py
Wiki からコンテンツを取得しています...
100%|███████████████████████████████████████████████████| 17/17 [00:01<00:00, 8.90it/s]
(.venv) wiki-rag> _
ちゃんと取得できているか、wiki フォルダを確認します。
(.venv) wiki-rag> ls wiki
ディレクトリ: C:\wiki-rag\wiki
Mode LastWriteTime Length Name
---- ------------- ------ ----
-a---- 2024/07/08 16:24 1944 3335633.md
-a---- 2024/07/08 16:24 2665 3335852.md
-a---- 2024/07/08 16:24 1128 3343121.md
...
2. Wiki の内容をベクトル化する
ここまでで、AI の知識を拡張するための情報を Markdown 形式で準備できるようになりました。
ここからはその情報をベクトル化するためのコードを書きます。
ベクトル化とは簡単にいうと、文書を数値で表して、内容が似ているものは数値が近くなるようにすることです。
これにより、質問に関係ありそうな文書だけを選んで OpenAI API に送信できるようになります。
パッケージをインストールする
この手順で必要なパッケージをインストールします。
(.venv) wiki-rag> pip install langchain-community langchain-openai chromadb unstructured markdown psutil
パッケージ名 | 簡単な説明 | 執筆時のバージョン |
---|---|---|
langchain-community | 文書をベクトル化するのに必要。 | 0.2.6 |
langchain-openai | 文書を OpenAI の形式でベクトル化するのに必要。 | 0.1.13 |
chromadb | LangChain がベクトルデータを保存するのに使っている。 | 0.5.3 |
unstructured | LangChain が Markdown を読むのに使っている。 | 0.14.9 |
Markdown | unstructured が Markdown を読むのに使っている。 | 3.6 |
psutil | システムのプロセス情報を取得する。unstructured が使っている。 | 6.0.0 |
以上のパッケージをインストールすれば、LangChain や OpenAI パッケージは一緒にインストールされます。
他のパッケージから import
されているパッケージも、必須でないものは依存関係に書かれていない (一緒にインストールされない) ので、自分でインストールする必要があります。
何が必要か分からない場合は、実行すればエラーが出るので分かります。
.env ファイルを変更する
.env ファイルに OpenAI API の API キーを追加します。
# 前略
OPENAI_API_KEY=your-openai-api-key # OpenAI API の API キー
OpenAI パッケージはこの環境変数を使うようになっているので、ここに書いておけば自分で Python の変数に読み込む必要はありません。
Wiki の内容をベクトル化する関数を作る
wiki-rag/make_vectors_from_wiki.py を実行した時の if
文の前に関数を追加します。
また、追加した関数をその if
文の中で実行するようにします。
コードの追加分
# 前略
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_openai import OpenAIEmbeddings
vector_dir = Path(__file__).parent / 'vectors'
# 中略
def make_vectors():
# wiki_dir の中の .md ファイルを読み込む (single mode: 1ファイル=1ドキュメント)
wiki_docs = []
print('ベクトル化の準備をしています...')
for file_path in tqdm(wiki_dir.glob('*.md')):
loader = UnstructuredMarkdownLoader(file_path, mode='single', strategy='fast')
docs = loader.load()
wiki_docs.extend(docs)
# metadata の source が Path の場合はエラーになるので除外する
filter_complex_metadata(wiki_docs)
# 保存先フォルダを削除しておく
if vector_dir.exists():
shutil.rmtree(vector_dir)
# VectorStore を作成して保存する
print('ベクトル化しています...')
Chroma.from_documents(
documents=wiki_docs,
embedding=OpenAIEmbeddings(),
persist_directory=str(vector_dir) # str にしないとエラーになる
)
if __name__ == '__main__':
retrieve_wiki('親ページのタイトル')
make_vectors()
コードの解説
make_vectors()
関数の中を細かく解説していきます。
# wiki_dir の中の .md ファイルを読み込む (single mode: 1ファイル=1ドキュメント)
wiki_docs = []
print('ベクトル化の準備をしています...')
for file_path in tqdm(wiki_dir.glob('*.md')):
loader = UnstructuredMarkdownLoader(file_path, mode='single', strategy='fast')
docs = loader.load()
wiki_docs.extend(docs)
ここでは UnstructuredMarkdownLoader
を使ってMarkdown ファイルの内容を LangChain の Document
オブジェクトに変換しています。
変数 wiki_docs
は Document
オブジェクトのリストになります。
Document
オブジェクトには情報ソース等のメタデータが含まれます。
この記事では扱いませんが、AI の返答に情報ソースを表示するといったこともできそうです。
mode
を single
にしているので1個のファイルが1個の Document
になります。
mode
を element
にすると文書の構成に合わせて分割してくれますが、検索もその単位になるのであまり細かく分けすぎると必要な情報が取り出せなくなります。
ファイル単位で取り込んで、LangChain の Text Splitters でいい感じに分割するという方法もあるようです。
# metadata の source が Path の場合はエラーになるので除外する
filter_complex_metadata(wiki_docs)
エラー回避のために Document
の metadata
を削除する関数を呼んでいます。
具体的には、metadata
の中の source
キーの値が str
ではなく Path
であるためにエラーになるので、filter_complex_metadata()
で metadata
からそういうキーを削除しているのです。
Document
オブジェクトに情報ソースを残しておきたい場合は filter_complex_metadata()
をせずに、以下のように source
キーの値を str
に変換することもできます。
for doc in wiki_docs:
source = doc.metadata['source']
if type(source) is not str:
doc.metadata['source'] = str(source)
# VectorStore を作成して保存する
print('ベクトル化しています...')
Chroma.from_documents(
documents=wiki_docs,
embedding=OpenAIEmbeddings(),
persist_directory=str(vector_dir) # str にしないとエラーになる
)
LangChain の Chroma
クラスを使って、上で作った Document
オブジェクトを VectorStore として保存しています。
この from_documents()
関数の embedding
引数にはベクトル化に使うモデルとして OpenAIEmbeddings
を与えています。
persist_directory
は VectorStore の保存先ですが、これも先ほどの Document
の metadata
と同じく Path
オブジェクトだとエラーになるので、str
に変換して渡しています。
OpenAIEmbeddings
は LangChain のクラスですが、OpenAI API を使うので料金がかかります。
OpenAIEmbeddings
をインスタンス化する時に以下のようにモデルを指定することもできます。
embedding=OpenAIEmbeddings(model="text-embedding-3-large"),
指定できるモデルと料金については OpenAI の こちらのドキュメント に書かれています。
OpenAIEmbeddings
がデフォルトで使うモデルは text-embedding-ada-002
です。
実行してみる
(.venv) wiki-rag> python .\make_vectors_from_wiki.py
Wiki からコンテンツを取得しています...
100%|███████████████████████████████████████████████████| 17/17 [00:01<00:00, 10.12it/s]
ベクトル化の準備をしています...
17it [00:02, 7.18it/s]
ベクトル化しています...
(.venv) wiki-rag> _
vectors フォルダを確認すると、バイナリファイルが入ったディレクトリと sqlite ファイルができています。
(.venv) wiki-rag> ls vectors
ディレクトリ: C:\Users\sada-k\local\study\langchain\vectors
Mode LastWriteTime Length Name
---- ------------- ------ ----
d----- 2024/07/09 10:39 668b800d-3a13-4b24-b872-d9033d6840fa
-a---- 2024/07/09 10:39 761856 chroma.sqlite3
Markdown を Document
オブジェクトにするところでプログレスバーが出ないのは、tqdm
にジェネレータを渡していて、実行してみないとループ数が分からないからです。
プログレスバーを表示したい場合は for
文の前でリスト化しておくなどの工夫が必要です。
3. チャット機能を作る
あらたに chat.py というモジュールを作って、ここまでで作った VectorStore を使ってチャットするプログラムを作ります。
ここからは LangChain の機能をたくさん使うので、いろいろと新しい概念が出てきます。
チャット機能で使う LangChain 関連のパッケージは ここまででインストールしているので、ここで新たにインストールするパッケージはありません。
LLM に問い合わせる関数を作る
まずはコード全体を掲載します。
コード全体
from textwrap import dedent
from pathlib import Path
from dotenv import load_dotenv
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
load_dotenv(override=True)
vector_dir = Path(__file__).parent / 'vectors'
# 保存済み VectorStore の retriever (類似度の高い3件を取得する) を作成する
vectors = Chroma(
persist_directory=str(vector_dir),
embedding_function=OpenAIEmbeddings())
retriever = vectors.as_retriever(search_kwargs={'k': 3})
def format_docs(docs):
'''retriever から取得したドキュメントのテキストを連結して返す
'''
return "\n\n".join(doc.page_content for doc in docs)
# retriever を使って LLM に送るプロンプトを作るテンプレート
prmp_tmpl = ChatPromptTemplate.from_template(dedent('''\
以下のコンテキストを踏まえて質問に回答してください。
Context: """
{context}
"""
Question: {question}
''')
)
# LLM を準備する
llm = ChatOpenAI(model_name='gpt-3.5-turbo', max_tokens=400)
# OutputParser を準備する
output_parser = StrOutputParser()
# ユーザの入力に関連するコンテキストを取得するチェーン
context_chain = retriever | format_docs
# プロンプトを入力して AI に問い合わせ、結果を得るチェーン
llm_chain = llm | output_parser
def chat(user_input: str) -> str:
'''ユーザの入力を受け取り、AI に問い合わせて結果を返す
'''
try:
# ユーザの入力内容に関連するコンテキストを取得する
context = context_chain.invoke(user_input)
# LLM に問い合わせるためのプロンプトを作る
prmp = prmp_tmpl.format(context=context, question=user_input)
# 作ったプロンプトをユーザの入力として LLM に問い合わせる
messages = [HumanMessage(prmp)]
result = llm_chain.invoke(messages)
except Exception as e:
result = f'エラーが起きました。\n{str(e)}'
return result
if __name__ == '__main__':
# ユーザの入力を受け取って、AI に問い合わせて結果を返す
while user_input := input('質問をどうぞ > '):
print('\n' + chat(user_input) + '\n')
コードの解説
モジュールの上の方から解説していきます。
# 保存済み VectorStore の retriever (類似度の高い3件を取得する) を作成する
vectors = Chroma(
persist_directory=str(vector_dir),
embedding_function=OpenAIEmbeddings())
retriever = vectors.as_retriever(search_kwargs={'k': 3})
ここで作っている retriever
は、LangChain の VectorStoreRetriever
オブジェクトです。
これは VectorStore の中を検索する機能で、ユーザの入力を受けるたびに使うので あらかじめチャットのループの外で作っておきます。
retriever
の元になる VectorStore には、ベクトル化した時と同じモデル (OpenAIEmbeddings
) を指定する必要があります。
def format_docs(docs):
'''retriever から取得したドキュメントのテキストを連結して返す
'''
return "\n\n".join(doc.page_content for doc in docs)
format_docs()
は retriever
から返る Document
オブジェクトのリスト (=検索結果) をまとめて1個の文字列にする関数です。
1個の文字列にする理由は、下記のテンプレートに挿入するためです。
# retriever を使って LLM に送るプロンプトを作るテンプレート
prmp_tmpl = ChatPromptTemplate.from_template(dedent('''\
以下のコンテキストを踏まえて質問に回答してください。
Context: """
{context}
"""
Question: {question}
''')
)
LLM に問い合わせる時の「プロンプト」を作るためのテンプレートです。
{context}
部分に検索した文書 (上記の format_docs()
の出力) を、{question}
部分にユーザの入力を入れて使います。
今回の使い方であれば単に文字列の format()
を使ってもプロンプトを作れますが、LangChain を理解する上で ChatPromptTemplate
クラスを使うやり方も知っておくと良いです。
このようなクラスは LCEL (LangChain Expression Language) という書き方に対応していて、処理の流れをすっきりと記述したい時に使えます。
# LLM を準備する
llm = ChatOpenAI(model_name='gpt-3.5-turbo', max_tokens=400)
# OutputParser を準備する
output_parser = StrOutputParser()
ChatOpenAI
は OpenAI の Chat Completions API を LangChain の文脈に合わせて使いやすくするクラスです。
OpenAI API で使えるモデルの一覧は こちらのページ にあります。
また、そのうち Chat Completions API で使えるものは同じページの Model endpoint compatibility というセクションを見ると分かります (/v1/chat/completions
というエンドポイント)。
StrOutputParser
は llm
から返ってきたメッセージ (AIMessage
オブジェクト) を文字列にパースするためのクラスです。
# ユーザの入力に関連するコンテキストを取得するチェーン
context_chain = retriever | format_docs
# プロンプトを入力して AI に問い合わせ、結果を得るチェーン
llm_chain = llm | output_parser
ここで、チェーン という考え方が登場します。
LangChain で扱うオブジェクトには、チェーンという 「一連の処理の流れ」 に組み込むことができる (Runnable な) オブジェクトがあります。
ここまでで準備した retriever
や llm
、output_parser
は Runnable なオブジェクト です。
また、format_docs
は普通の関数ですが、|
(パイプ) 演算の相手である retriever
が Runnable として扱ってくれます。
prmp_tmpl
も Runnable ですが、入力が複数あってややこしいのでチェーンにしていません。
この |
演算子で Runnable なオブジェクト (またはその相手となり得るオブジェクト) を連結する書き方が、先ほど少し触れた LCEL になります。
def chat(user_input: str) -> str:
'''ユーザの入力を受け取り、AI に問い合わせて結果を返す
'''
try:
# ユーザの入力内容に関連するコンテキストを取得する
context = context_chain.invoke(user_input)
# LLM に問い合わせるためのプロンプトを作る
prmp = prmp_tmpl.format(context=context, question=user_input)
# 作ったプロンプトをユーザの入力として LLM に問い合わせる
messages = [HumanMessage(prmp)]
result = llm_chain.invoke(messages)
except Exception as e:
result = f'エラーが起きました。\n{str(e)}'
return result
この章のメインとなる関数です。
上で作った context_chain
を使うことで、ユーザの入力との類似度の高い文書を取り出して文字列にするところまでを1行でやっています。
また、AI に問い合わせをして結果を文字列にするところもチェーン (llm_chain
) を使っています。
llm_chain
に渡す入力はチェーンの先頭にある llm
への入力となるので、「Message のリスト」になるようにしています。
ここでは Message は1個なので要素が1個のリストになります。
この Message はユーザの入力として AI に渡すので HumanMessage
というオブジェクトにしています。
AIMessage
や HumanMessage
という種類があるのは、LLM に複数の Message を渡す時に「誰が何を言ったか」が分かるようにです。
API は過去のやりとりを憶えていないので、問い合わせるたびに過去に誰が何を言ったかの情報を渡す必要があるのです。
実行してみる
以下、Wiki の内容が「社内の課題管理」だった場合をイメージした実行結果です。
質問するたびに AI への問い合わせが発生し、料金がかかります。
(.venv) wiki-rag> python .\chat.py
質問をどうぞ > どんな質問に答えられますか。
このコンテキストに関連する以下のような質問に答えることができます:
1. 社内での課題管理のしかたについて教えてください。
2. 課題を起票する方法について説明してください。
3. 未読のお知らせをすべて既読にする方法について教えてください。
質問をどうぞ > 未読のお知らせをすべて既読にする方法について教えてください。
未読のお知らせをすべて既読にする方法はありません。
質問をどうぞ > (Enter を空打ちすると終了します)
(.venv) wiki-rag> _
会話履歴を参照するようにする
OpenAI の Chat Completions API は履歴を記憶しないので、文脈に沿った返答を期待する場合は 単品の入力ではなく一連のやりとりを送信する必要があります。
記事の冒頭の「おおまかな手順」で「会話履歴を保存・参照するようにする」と書きましたが、ここでは参照して API に送信するところだけ作ります。
理由は、このあと別のモジュールから chat()
関数を呼ぶようにして、そっちで会話履歴の初期化や保存をしたいからです。
という訳で、ここでの変更は少しだけになります。
コードの変更箇所
# 前略
from langchain.memory import ConversationBufferMemory # これを追加
# 中略
def chat(user_input: str, memory: ConversationBufferMemory) -> str: # ここを変更
'''ユーザの入力を受け取り、AI に問い合わせて結果を返す
'''
try:
# ユーザの入力内容に関連するコンテキストを取得する
context = context_chain.invoke(user_input)
# LLM に問い合わせるためのプロンプトを作る
prmp = prmp_tmpl.format(context=context, question=user_input)
# 作ったプロンプトをユーザの入力として LLM に問い合わせる
messages = memory.buffer_as_messages + [HumanMessage(prmp)] # ここを変更
result = llm_chain.invoke(messages)
except Exception as e:
result = f'エラーが起きました。\n{str(e)}'
return result
# これ以降は削除する
コードの解説
会話履歴を保存するのに LangChain の ConversationBufferMemory
オブジェクトを使っています。
これを chat()
関数で受け取れるように引数に追加しています。
そして、LLM に渡す messages
を 「過去の履歴に今回のプロンプトを追加したもの」 にしています。
最後の if
ブロック (このモジュールを直に実行した場合の処理) はエラーになるので削除しておいてください。
4. UI をつけて Web アプリにする
Streamlit というパッケージを使います。
Streamlit は Web フレームワークというよりは、対話的なプレゼンテーションのためのツールです。
といっても私も使うのは初めてなのですが、とにかく短いコードでプロっぽいものが作れると思います。
完成イメージ
パッケージをインストールする
(.venv) wiki-rag> pip install streamlit
執筆時のバージョンは 1.36.0 です。
チャット画面を作る
Streamlit には「要素を描画するための関数」がたくさん用意されていて、これをコマンドのように Python モジュールの中に記述していきます。
会話履歴のように再描画が起こった時にも憶えておきたい情報は、Session State に入れておくだけで保持できます。
では、新しく app.py というモジュールを作ります。
コードの内容については下で解説します。
コード全体
from textwrap import dedent
import streamlit as st
from langchain_core.messages import SystemMessage
from langchain.memory import ConversationBufferMemory
from chat import llm, chat
st.title('社内の課題管理マニュアルボット')
# 会話履歴を初期化する
if 'memory' not in st.session_state:
memory = ConversationBufferMemory(llm=llm)
memory.chat_memory.add_message(SystemMessage(dedent('''\
課題管理マニュアルを参照して質問に回答してください。
質問に関係ありそうなマニュアルの抜粋をコンテキストとして渡します。
回答の中では「コンテキスト」ではなく「マニュアル」と言ってください。
マニュアルには以下の内容が書かれています。
- 社内での課題管理の概要
- 課題を起票する
- 課題にコメントする
- 課題を完了する
「回答:」等の接頭辞を付けずに回答だけを返してください。
ただし丁寧な言葉遣いを心がけてください。
回答は200文字以内になるように、長い場合は要約してください。
''')
))
memory.chat_memory.add_ai_message(
'こんにちは! 課題管理マニュアルに書かれた範囲内で質問にお答えします。'
)
st.session_state.memory = memory
# 画面更新時に、チャットコンテナに会話履歴を表示する
for message in st.session_state.memory.buffer_as_messages:
if message.type in ('human', 'ai'):
with st.chat_message(message.type):
st.markdown(message.content)
# ユーザの入力を受け付ける
if user_input := st.chat_input('質問をどうぞ。'):
# 入力内容をチャットコンテナに表示する
with st.chat_message('human'):
st.markdown(user_input)
# AI からの返答を生成する
response = chat(user_input, st.session_state.memory)
# AI からの返答をチャットコンテナに表示する
with st.chat_message('ai'):
st.markdown(response)
# 入力内容と AI からの返答を会話履歴に追加する
st.session_state.memory.chat_memory.add_user_message(user_input)
st.session_state.memory.chat_memory.add_ai_message(response)
コードの解説
st.title('社内の課題管理について答えるボット')
st.title()
は、大き目のフォントでタイトルを表示したい時に使うコマンドです。
# 会話履歴を初期化する
if 'memory' not in st.session_state:
memory = ConversationBufferMemory(llm=llm)
memory.chat_memory.add_message(SystemMessage(dedent('''\
課題管理マニュアルを参照して質問に回答してください。
<< 中略 >> (上の「コード全体」を参照してください。)
回答は200文字以内になるように、長い場合は要約してください。
''')
))
memory.chat_memory.add_ai_message(
'こんにちは! 課題管理マニュアルに書かれた範囲内で質問にお答えします。'
)
st.session_state.memory = memory
st.session_state
は辞書のように使えるので、'memory' というキーで会話履歴を記憶しています。
つまり、st.session_state
にそのキーがなければ作って初期化します。
上で chat()
関数でやったように、これは ConversationBufferMemory
オブジェクトです。
インスタンスを作ったら最初に AI への指示として SystemMessage
を追加しています。
そして次に AI に最初に言わせたいこととして AIMessage
を追加しています。
# 画面更新時に、チャットコンテナに会話履歴を表示する
for message in st.session_state.memory.buffer_as_messages:
if message.type in ('human', 'ai'):
with st.chat_message(message.type):
st.markdown(message.content)
会話履歴は、memory
の中の Message を順番に描画していきます。
各 Message には type
プロパティがあって、これを使って st.chat_message()
で描画します。
OpenAI の 'human'
や 'ai'
というメッセージタイプに対応しているので、それっぽいアイコンを表示してくれます。
st.chat_message()
がアイコンつきのコンテナを描画して、そのコンテキストの中で st.markdown()
で Message の内容を描画しています。
st.markdown()
を使うのは OpenAI API のモデルが Markdown で答えることがあるからです。
また、message.type
で分岐しているのは、AI への指示 (SystemMessage) を表示したくないからです。
# ユーザの入力を受け付ける
if user_input := st.chat_input('質問をどうぞ。'):
不思議な書き方ですけど、これで入力フィールドの描画 (フロント側) と入力された時の処理 (サーバ側) を指示することができます。
# 入力内容をチャットコンテナに表示する
with st.chat_message('human'):
st.markdown(user_input)
# AI からの返答を生成する
response = chat(user_input, st.session_state.memory)
# AI からの返答をチャットコンテナに表示する
with st.chat_message('ai'):
st.markdown(response)
# 入力内容と AI からの返答を会話履歴に追加する
st.session_state.memory.chat_memory.add_user_message(user_input)
st.session_state.memory.chat_memory.add_ai_message(response)
今回入力されたばかりでまだ履歴にないユーザの入力と、AI から返ってきた返答を st.chat_message()
で描画しています。
このタイミングで st.chat_message()
しても、ちゃんと会話履歴を表示しているコンテナの中に描画してくれます。
最後に、今回の入力内容と AI の返答を履歴に追加して、次回から表示されるようにしています。
chat()
関数の中ではコンテキストを含めたプロンプトを生成して LLM に送っていますが、会話履歴に追加するのはそのプロンプトではなくユーザが入力したとおりの文言にします。
実行してみる
以下のコマンドでローカルサーバが起動してブラウザでアプリが開きます。
(.venv) wiki-rag> streamlit run app.py
ローカルサーバを止めるには Ctrl-C
を押します。
宿題
1. ストリーミング対応
まだ試していませんが、Streamlit の st.write_stream()
を使うと、ストリーミング表示ができるようです。
AI の返答をタイピングするかのように少しずつ表示する、あれです。
単にエフェクトとしてそのように表示するのでなく、API からストリーミングでレスポンスを受け取って受け取った分だけ表示するというのが、ChatOpenAI
オブジェクトの stream()
関数を使えば割と簡単にできるはずです。
2. 会話履歴の切り詰め
会話履歴に使っている ConversationBufferMemory
はシンプルに Massege を溜めていくオブジェクトです。
履歴が長くなると API に送信できる限界に達したり、料金がたくさんかかるようになったりするので、ある程度の長さに切り詰めるようにしたいです。
ConversationTokenBufferMemory
を使うと記憶するトークン数を制限しながら履歴を溜めていけるようです。
また、ConversationSummaryBufferMemory
を使うと古い履歴を自動的に要約するようです。
これらの Memory を使って API に送るトークン数を押さえつつ、画面に表示する履歴は ConversationBufferWindowMemory
を使ってシンプルにメッセージ数で切り詰めるのが良いのではと思っています。
会話履歴の切り詰めをする場合、最初の SystemMessage
は残すようなロジックにする必要があります。
Memory の種類は こちら に一覧があります。