このページを閲覧しているということは、皆さんが会社や自宅でのDIYプロジェクトにおいてRAG(Retriever-Augmented Generator)を活用している、または導入を検討していることと思います。カスタムローダーの作成に至るということは、既存のLlamaIndexやLangchainが標準で提供するRetriever機能だけでは要件を満たさない、あるいは私のケースのように、既存のRetrieverが十分ではなく、さらなるカスタマイズを求めているのでしょう。
インターネットで調べても、カスタムローダーの作成方法に関する情報はあまり見つかりません。私の検索方法が不適切なのか、それとも情報が少ないのかは分かりませんが、様々なキーワードで試してもあまり変わりませんでした。
そこで、この記事ではカスタムローダーの作り方について詳しく説明していきます。すぐに作成方法を知りたい方は、下記の出来上がったテンプレート
までスクロールして、コピーして使ってくださいね。
既存のローダーの観察
作り方どこにも記載がなかったので、既存のコードを参考にするしかありません。
今回は、WikipediaのRetrieverを参考に作っていきたいと思います。
Toolsを作ったときもWikipediaが一番簡素に作ってあったイメージですね。
まずは、観察ということでInputとOutputをどのような形でやればいいのか、見ていきたいと思います。(ここが肝)
from typing import List
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
from langchain.utilities.wikipedia import WikipediaAPIWrapper
class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
"""`Wikipedia API` retriever.
It wraps load() to get_relevant_documents().
It uses all WikipediaAPIWrapper arguments without any change.
"""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load(query=query)
こちらがRetriever側のコードです。ここでInputはquery
があることがわかります。重要な部分は、Wikipediaのデータを取得するWikipediaAPIWrapper
に含まれています。具体的には、_get_relevant_documents
という内部メソッドで、継承されているWikipediaAPIWrapper
のload
メソッドが利用されています。
次に、このWikipediaAPIWrapper
の詳細に焦点を当てて見ていきましょう。
def load(self, query: str) -> List[Document]:
"""
Run Wikipedia search and get the article text plus the meta information.
See
Returns: a list of documents.
"""
page_titles = self.wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
docs = []
for page_title in page_titles[: self.top_k_results]:
if wiki_page := self._fetch_page(page_title):
if doc := self._page_to_document(page_title, wiki_page):
docs.append(doc)
return docs
この部分のコードで__page_to_document
を呼んでいる事がわかります。
そのコードを見てみると以下のようになっています。
def _page_to_document(self, page_title: str, wiki_page: Any) -> Document:
main_meta = {
"title": page_title,
"summary": wiki_page.summary,
"source": wiki_page.url,
}
add_meta = (
{
"categories": wiki_page.categories,
"page_url": wiki_page.url,
"image_urls": wiki_page.images,
"related_titles": wiki_page.links,
"parent_id": wiki_page.parent_id,
"references": wiki_page.references,
"revision_id": wiki_page.revision_id,
"sections": wiki_page.sections,
}
if self.load_all_available_meta
else {}
)
doc = Document(
page_content=wiki_page.content[: self.doc_content_chars_max],
metadata={
**main_meta,
**add_meta,
},
)
return doc
ここの中の
doc = Document(
page_content=wiki_page.content[: self.doc_content_chars_max],
metadata={
**main_meta,
**add_meta,
},
)
page_contentとmetadataを持ったDICTがOutputであることがわかります。
観察結果
上記を観察した結果下記のようにInput/Outputを書けば良いことがわかりました。
Input(入力) | Output(出力) |
---|---|
query: str 文字型のqueryというパラメータ | [Document(page_content="Retriveした内容",metadata={'source':'どこから取ってきたかなど'})] |
※入力のqueryというパラメータ名は変更する事ができます。
※出力のmetadataのsource部分は任意に書き換えれます。
出来上がったテンプレート
下記のようにCustomLoaderを作るときのテンプレートを作りました。
注意
現在は固定値で返すようにしていますが、この部分を任意の場所からRetriveするように作り替えてください。
ドキュメントローダー(Retriver)
from typing import List
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
class CustomRetriever(BaseRetriever):
"""
Custom retriever.
"""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
doc = [Document(page_content="島根県には出雲大社があります。",metadata={'summary':'test'})]
return doc
使い方
import os
import openai
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from typing import List
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
class CustomRetriever(BaseRetriever):
"""
Custom retriever.
"""
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
doc = [Document(page_content="島根県には出雲大社があります。",metadata={'summary':'test'})]
return doc
os.environ["OPENAI_API_KEY"] = #あなたのAPIKEY
retriever = CustomRetriever()
# retriever.get_relevant_documents(query='島根県には何がありますか?') #Retriverがどう拾って来るかみる
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
verbose=True
)
qa.run("島根県には何がありますか?")
こんな感じになります。Class部分は別のPythonファイルに分けるとスッキリすると思いますね。
実行結果
> Entering new RetrievalQA chain...
> Finished chain.
島根県には出雲大社があります。
まとめ
以上、実装方法とテンプレートになります。みなさまよきRAGライフを!