最初に
以下の記事の続編です。
今回試すこと
前回の記事では英語の文章をハードコーディングしていました。
今回はファイルを埋め込みして、なおかつ日本語対応してみようと思います。
Python環境の構築
今回もAnaconda環境での構築を想定しています。
condaだと新しいライブラリが使えないことがあるので、前回と同様にpipでパッケージを導入しています。
conda create -n embtest python=3.8
conda activate embtest
pip install chromadb streamlit
pip install sentence-transformers
pip install sentence-transformers
が前回からの差分です。
日本語対応している埋め込みモデルを自分で選ぶためにSentence Transformersというライブラリを利用します。
日本語のドキュメントの準備
どういうものを使ってもいいのですが、今回は以下のサイトからお店の情報を手動でコピペしたテキストファイルを5~6個用意しました。
Pythonプログラムの実装
まず、日本語のドキュメントを1ファイル1データとしてChromaに追加します。その後に検索するプログラムを作成します。
import chromadb
import streamlit as st
import glob
import os
# デフォルト以外の埋め込みモデルを読み込む
@st.cache_resource
def create_embedding_functions():
# Sentence Transformers (sbert) の多言語対応の埋め込みモデルを使う
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# 例えば下記のものはモデルサイズは小さいが日本語に対応していない。そのためピンと来ない検索結果になる
# model_name = "sentence-transformers/all-MiniLM-L6-v2"
ef = chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
return ef
# Chromaのコレクションを作成する
@st.cache_resource
def create_collection():
# ChromeDBのクライアントを取得する
chroma_client = chromadb.Client()
# コレクションを作成する。既にある場合はそれを参照する
collection = chroma_client.get_or_create_collection(name="test_collection")
return collection
# Chromaのコレクションにドキュメントを追加する
def add_collection(collection, ef, filename, document):
# 与えられた埋め込みモデルでEmbeddingを計算
embedding = ef([content])[0]
# 計算したEmbeddingを指定してドキュメントを追加する
collection.add(
embeddings=[embedding],
documents=[document],
metadatas=[{"source": filename}],
ids=[filename]
)
# 参考値としてEmbeddingを返す
return embedding
# ファイルを列挙する
def glob_files():
return glob.glob("data/*.txt")
# 画面レイアウトを設定
st.set_page_config(layout="wide")
col1, col2 = st.columns(2)
# 初期化
collection = create_collection()
# 画面左側:ベクトルデータベースの構築
with col1:
st.header("構築")
# 埋め込みモデルを取得
ef = create_embedding_functions()
# ドキュメントとして使うファイルを列挙
filelist = glob_files()
# 各ファイルをベクトルデータベースに追加する
for filepath in filelist:
with open(filepath, encoding="utf8") as f:
content = f.read()
filename = os.path.basename(filepath)
# ベクトルデータベースに追加
embedding = add_collection(collection, ef, filename, content)
# ベクトルデータベースに追加した内容を表示する
st.subheader(filename)
st.code(content)
st.write(embedding)
st.divider()
# 以下のコードでコレクションの中身を確認できる
# head = collection.peek()
# count = collection.count()
# 画面右側で実行されるコード:ユーザ入力があったときはこの関数内の処理だけが再実行される
@st.experimental_fragment
def search():
# 埋め込みモデルを取得
ef = create_embedding_functions()
# 検索文字列を取得してEmbeddingを計算
query_text = st.text_input("検索文字列", value="エビ餃子が食べたい。")
embedding = ef([query_text])[0]
# 検索文字列とEmbeddingを表示
st.subheader("入力内容")
st.code(query_text)
st.write(embedding)
# 検索して上位2件を取得する
result = collection.query(
query_embeddings=[embedding],
n_results=2
)
# 検索結果を表示
st.subheader("検索結果")
st.write(result)
# 画面右側:ベクトルデータベースからの検索
with col2:
st.header("検索")
search()
このプログラムをstreamlit run embed_file.py
で起動してみると、何となく検索できてそうです。
一方で、色々と入力してみると必ずしもヒットしてほしいものがトップにならないケースもあるということが分かります。
埋め込みモデルの能力もあるのでしょうが、例えば「1個35gの餃子は大きいのか小さいのか?」など、埋め込みモデルに対して渡したデータの外にしか意味が存在しないケースもあるので限界もあるのかなと解釈しています。