GradioでChromaにコレクションを作成したり、削除したり、PDFのドキュメントを追加したり、検索したりする簡単なWebアプリケーションを作ってみた。
# %%
import gradio as gr
import chromadb
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_chroma.vectorstores import Chroma
from langchain_community.document_loaders.pdf import PDFPlumberLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
# %% Chromaのクライアントを作成
path = './chroma'
client = chromadb.PersistentClient(path=path)
# %% HuggingFaceよりEmbeddingモデルを取得
embedding_model = HuggingFaceEmbeddings(model_name='intfloat/multilingual-e5-large')
# %% コレクションのリストを取得する関数
def get_collection_names():
collection_names = []
for collection in client.list_collections():
collection_names.append(collection.name)
return collection_names
# %% コレクションのリストをテキストにするための関数
def display_collections():
collections = get_collection_names()
if collections:
return '\n'.join(collections)
else:
return '現在、利用可能なコレクションはありません。'
# %% コレクションを作成するための関数
def create_collection(collection_name):
collection_names = get_collection_names()
if collection_name in collection_names:
return f"コレクション '{collection_name}' はすでに存在します。"
else:
client.create_collection(collection_name)
return f"コレクション '{collection_name}' が作成されました。"
# %% PDFの文書を追加するための関数
def add_document(collection_name, pdf_file):
collection_names = get_collection_names()
if collection_name not in collection_names:
return f"コレクション '{collection_name}' が見つかりません。"
else:
loader = PDFPlumberLoader(pdf_file)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
tmp = text_splitter.split_documents(documents=documents)
db = Chroma(collection_name=collection_name, embedding_function=embedding_model, persist_directory=path)
db.add_documents(documents=tmp)
return f"文書が '{collection_name}' に追加されました。"
# %% コレクションを削除するための関数
def delete_collection(collection_name):
collection_names = get_collection_names()
if collection_name not in collection_names:
return f"コレクション '{collection_name}' が見つかりません。"
else:
client.delete_collection(collection_name)
return f"コレクション '{collection_name}' が削除されました。"
# コレクションから検索するための関数
def search_collection(collection_name, query, k):
collection_names = get_collection_names()
if collection_name not in collection_names:
return f"コレクション '{collection_name}' が見つかりません。"
else:
db = Chroma(collection_name=collection_name, embedding_function=embedding_model, persist_directory=path)
results = db.similarity_search_with_relevance_scores(query=query, k=k)
tmp = ''
for result, score in results:
tmp += result.page_content.replace('\n', '')
tmp += '\n'
tmp += f'スコア: {score:.4f}'
tmp += '\n\n'
return tmp
# %% Webアプリケーションの設定
with gr.Blocks() as demo:
with gr.Tab('RAG'):
gr.Markdown('## コレクションの表示')
with gr.Tab('データ登録'):
# コレクション一覧セクション
gr.Markdown('## コレクションの表示')
display_button = gr.Button('コレクションの表示')
collection_list_output = gr.Textbox(label='コレクション')
display_button.click(fn=display_collections, inputs=[], outputs=collection_list_output)
# コレクション作成セクション
gr.Markdown('## コレクションの作成')
collection_name_input = gr.Textbox(label='コレクション名')
create_button = gr.Button('コレクションの作成')
create_output = gr.Textbox(label='作成結果')
create_button.click(create_collection, inputs=collection_name_input, outputs=create_output)
# 文書追加セクション
gr.Markdown('## 文書の追加')
collection_name_for_add = gr.Textbox(label='コレクション名')
pdf_file_input = gr.File(label='PDFファイルをアップロード', file_types=['.pdf'])
add_button = gr.Button('文書の追加')
add_output = gr.Textbox(label='追加結果')
add_button.click(fn=add_document, inputs=[collection_name_for_add, pdf_file_input], outputs=add_output)
# コレクション削除セクション
gr.Markdown('## コレクションの削除')
collection_name_for_delete = gr.Textbox(label='コレクション名')
delete_button = gr.Button('コレクションの削除')
delete_output = gr.Textbox(label='削除結果')
delete_button.click(fn=delete_collection, inputs=collection_name_for_delete, outputs=delete_output)
# コレクション検索セクション
gr.Markdown('## 検索')
collection_name_for_search = gr.Textbox(label='コレクション名')
query_input = gr.Textbox(label='クエリ')
k_slider = gr.Slider(label='検索上位件数', minimum=1, maximum=16, step=1, value=4)
search_button = gr.Button('検索')
search_output = gr.Textbox(label='検索結果')
search_button.click(fn=search_collection, inputs=[collection_name_for_search, query_input, k_slider], outputs=search_output)
# %% Webアプリケーションの起動
demo.launch()
# %%