# データの準備 (prepare the data)
# data.py
# %%
# Download the RAG QA dataset (question / context / answer triples) from the
# Hugging Face Hub.
from datasets import load_dataset
ds = load_dataset("neural-bridge/rag-dataset-12000")
ds
# %%
import pandas as pd
df_train = pd.DataFrame(ds['train'])
df_train
# %%
df_test = pd.DataFrame(ds['test'])
df_test
# %%
# Merge the train and test splits into one frame.  ignore_index=True gives the
# result a clean 0..n-1 index instead of duplicated labels from both splits,
# so positional and label-based indexing agree downstream.
df = pd.concat([df_train, df_test], ignore_index=True)
df
# %%
# Rename 'answer' to 'reference' to match the vocabulary used by the RAG
# pipeline later on.
df.rename(columns={'answer': 'reference'}, inplace=True)
df
# %%
# Persist only the first 100 rows (a small demo subset); quote every field so
# free-text contexts containing commas/newlines round-trip through CSV.
import csv
df.iloc[:100].to_csv('data.csv', index=False, quoting=csv.QUOTE_ALL)
# %%
# データをMilvusへinsert (insert the data into Milvus)
# milvus.py
# %%
# Load the prepared 100-row subset produced by data.py.
import pandas as pd
filepath = "data.csv"
df = pd.read_csv(filepath)
df
# %%
# Wrap each row as a LangChain Document: the context becomes the embedded
# page_content, while question/reference ride along as metadata for later
# inspection.  itertuples avoids the per-row Series overhead of iterrows and
# the original loop's unused index variable.
from langchain_core.documents import Document
documents = [
    Document(
        page_content=row.context,
        metadata={'question': row.question, 'reference': row.reference},
    )
    for row in df.itertuples(index=False)
]
documents
# %%
len(documents)
# %%
# Multilingual E5 embedding model (handles both Japanese and English text).
from langchain_huggingface import HuggingFaceEmbeddings
embedding = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
# %%
# Embed the documents and (re)build a local Milvus Lite collection.
# drop_old=True wipes any previous data so repeated runs start clean.
from langchain_milvus import Milvus
URI = "milvus.db"
connection_args = {"uri": URI}
vector_store = Milvus.from_documents(documents=documents, embedding=embedding, connection_args=connection_args, drop_old=True)
# %%
# RAGの実装 (implement RAG)
# rag.py
# %%
# Embedding model: must match the one used at ingestion time (milvus.py),
# otherwise query vectors live in a different space than the stored vectors.
from langchain_huggingface import HuggingFaceEmbeddings
embedding = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
# %%
# Connect to the existing local Milvus Lite collection built by milvus.py
# (read-only here: no drop_old, no re-ingestion).
from langchain_milvus import Milvus
URI = "milvus.db"
connection_args = {'uri': URI}
vector_store = Milvus(embedding_function=embedding, connection_args=connection_args)
# %%
# Pull Azure credentials / endpoint settings from a local .env file.
from dotenv import load_dotenv
load_dotenv()
# %%
# Service-principal authentication for Azure OpenAI -- no API key in code.
import os
from azure.identity import ClientSecretCredential, get_bearer_token_provider
tenant_id = os.environ.get('AZURE_TENANT_ID')
client_id = os.environ.get('AZURE_CLIENT_ID')
client_secret = os.environ.get('AZURE_CLIENT_SECRET')
credential = ClientSecretCredential(
    tenant_id=tenant_id,
    client_id=client_id,
    client_secret=client_secret
)
# %%
# The token provider transparently refreshes bearer tokens for the
# Cognitive Services scope on the LLM client's behalf.
scopes = "https://cognitiveservices.azure.com/.default"
azure_ad_token_provider = get_bearer_token_provider(credential, scopes)
# %%
# Azure OpenAI chat model.  temperature=0 for reproducible answers; the token
# cap bounds the size of a single completion.
from langchain_openai import AzureChatOpenAI
azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
azure_deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
api_version = os.environ.get("API_VERSION")
temperature = 0
max_tokens = 4096
llm = AzureChatOpenAI(
    azure_endpoint=azure_endpoint,
    api_version=api_version,
    azure_deployment=azure_deployment,
    azure_ad_token_provider=azure_ad_token_provider,
    temperature=temperature,
    max_tokens=max_tokens,
)
# %%
# Prompt: instruct the model to answer strictly from the retrieved context.
from langchain_core.prompts import PromptTemplate
template = """# System:
Answer the question based on the given context.
# Context:
{context}
# Question:
{question}
# Answer:
"""
prompt_template = PromptTemplate.from_template(template=template)
# %%
# %%
# Webアプリケーションの実装 (implement the web application)
# app.py
# %%
import gradio as gr
from rag import vector_store, prompt_template, llm
# %%
def rag(question):
    """Retrieve relevant contexts from Milvus and generate an answer.

    Parameters
    ----------
    question : str
        The user's question from the UI.

    Returns
    -------
    tuple[str, str]
        (generated answer, concatenated retrieved context).  Both are empty
        strings when retrieval fails, so the UI degrades gracefully instead
        of crashing.
    """
    try:
        results = vector_store.similarity_search(query=question)
    except Exception:
        # Best-effort: a retrieval failure (e.g. missing collection or DB
        # file) blanks the outputs rather than taking down the app.
        return "", ""
    # Number each retrieved chunk ("# 1", "# 2", ...) so the UI shows where
    # passages begin; join builds the string in one pass instead of += loops.
    context = "".join(
        f"# {i + 1}\n{doc.page_content}\n" for i, doc in enumerate(results)
    )
    prompt = prompt_template.format(context=context, question=question)
    response = llm.invoke(input=prompt)
    return response.content, context
# %%
# Minimal Gradio front-end: one question box, two read-only output boxes.
with gr.Blocks() as app:
    question = gr.Textbox(label="question")
    generated_text = gr.Textbox(label="generated_text")
    context = gr.Textbox(label="context")
    button = gr.Button()
    # Run the same RAG pipeline from both the button click and pressing
    # Enter in the question box.
    rag_outputs = [generated_text, context]
    for trigger in (button.click, question.submit):
        trigger(fn=rag, inputs=[question], outputs=rag_outputs)
# %%
# Bind to all interfaces so the app is reachable from outside the host
# (e.g. when running inside a container).
app.launch(server_name="0.0.0.0")
# %%