導入
前回に引き続き、以下のLangChain Blogより、Self-RAGのCookbookを改変してウォークスルーしてみます。
なお、検証はDatabricks on AWSで行いました。
DBRは14.3ML LTS、g5.xlargeのクラスタを利用しています。
Self-RAGとは
上記Blogより、一部邦訳。
Self-RAGは、他のいくつかの興味深いRAGのアイデア(Paper)と関連するアプローチです。このフレームワークは、RAGプロセスのさまざまな段階を管理する自己反射トークンを生成するようにLLMをトレーニングします。
うん、さすがにこれだけだとわからないですね。
簡略化した流れが以下の図として提示されています。
わたしの理解だと、「Retrieverで取得した関連文書が質問と関連するか」というチェックや、LLMが生成した回答についても「取得文書や質問と関連するか」などを適切に自己チェックし、不十分な場合はクエリを修正した上で前の工程からやり直すようなフローを繰り返すことで、十分な回答品質を保つようにするアプローチという認識です。
CRAGのようにWeb検索するノードを加えるなどといったことも出来るため、こちらもユースケースに応じて様々な拡張ができそうです。
今回は以下のCookbookで示された内容を基にまた魔改造して実装してみます。
このCookbookで示されているフローは以下の図のようになります。
CRAGの時に比べると、Conditional Edge(条件分岐部分)が3か所になり、自己チェックに合格しない場合はクエリを修正した上でRetrievからやり直すような動きになっています。
では、実際に動かして挙動を見ていきましょう。
Step1~Step4. もろもろ準備
パッケージのインストールから、ベクトルストアの準備までは、まるまるCRAG準備編と同じです。
1か所のみ、Retrieverの取得文書数を1に設定しておきます。
# Retriever取得。検索件数は1件固定
retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
Step5. 状態クラスの定義
ここもCRAG実践編と同様です。
全ての状態をDict型で保管する構成となります。
from typing import Dict, TypedDict
class GraphState(TypedDict):
    """
    グラフの状態を表します。
    Attributes:
        keys: 各キーが文字列である辞書
    """
    keys: Dict[str, any]
Step6. ノードとエッジの定義
再掲ですが、以下のフローのようなグラフノードとなる処理(関数)と、グラフの分岐条件を制御する関数を定義します。最終的に、これらの関数を利用してLangGraphでグラフを構築します。
まずは、各ノードから。
retrieve
関連文書を検索するノード処理の定義です。
Step4で作成したRetrieverを使って文書を取得し、グラフの状態オブジェクトに格納して返すだけのシンプルな内容です。
import json
import operator
from typing import Annotated, Sequence, TypedDict
from langchain.schema import Document
import sglang as sgl
def retrieve(state):
    """
    文書を検索する
    Args:
        state (dict): 現在のグラフ状態
    Returns:
        state (dict): 検索された文書を含んだグラフ状態
    """
    print("---RETRIEVE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = retriever.get_relevant_documents(question)
    return {"keys": {"documents": documents, "question": question}}
grade_documents
retrieveで取得した文書が質問と関連するかどうかをLLMに判定させ、関連する文書だけにフィルタするノード処理の定義です。
オリジナルのCookbookではOpenAI GPT-4を利用して関連を判定する処理となっていましたが、今回はロード済みのローカルLLM(OpenChat v1.5 7B)を利用するように改変しています。
CRAG実践編ではSGLangのJSONデコード機能を利用しましたが、今回はchoice機能でyes/noをシンプルに返すよう制御しています。
def grade_documents(state):
    """
    文書が質問と関連するかどうかを判定する
    Args:
        state (dict): 現在のグラフ状態
    Returns:
        state (dict): 関連文書のみを残したグラフ状態
    """
    print("---CHECK RELEVANCE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    @sgl.function
    def grade_gen(s, context, question):
        template = (
            "You are a grader assessing relevance of a retrieved document to a user question. \n"
            f"Here is the retrieved document: \n\n{context} \n\n"
            f"Here is the user question: {question} \n"
            "If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n"
            "Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\n"
            "Please return 'no' even if the document is an unintelligible sentence.\n\n"
        )
        s += sgl.user(template)
        s += sgl.assistant(sgl.gen("output", choices=["yes", "no"]))
    # Score
    filtered_docs = []
    for d in documents:
        s = grade_gen.run(
            question=question,
            context=d.page_content,
            temperature=0,
        )
        grade = s.get_var("output")
        if grade == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            continue
    return {
        "keys": {
            "documents": filtered_docs,
            "question": question,
        }
    }
transform_query
検索クエリを、より検索に適した形へ変換するノード処理の定義です。
CRAG実践編ではWeb検索に特化した形へ変換しましたが、今回はRetrieverに与えるクエリとして変換します。
def transform_query(state):
    """
    クエリをより検索に適した形に変換する
    Args:
        state (dict): 現在のグラフ状態
    Returns:
        state (dict): クエリを適した形に置き換えたグラフ状態
    """
    print("---TRANSFORM QUERY---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    @sgl.function
    def query_gen(s, question):
        template = (
            "You are generating questions that is well optimized for retrieval. \n"
            "Look at the input and try to reason about the underlying sematic intent / meaning. \n"
            "Here is the initial question:"
            "\n ------- \n"
            f"{question}"
            "\n ------- \n"
            "You must reply only improved question sentence."
        )
        s += sgl.user(template)
        s += sgl.assistant(
            # "Formulate an improved question:\n\n"
            sgl.gen("output", max_tokens=128, frequency_penalty=1.1)
        )
    s = query_gen.run(
        question=question,
        temperature=0,
    )
    better_question = s.get_var("output")
    return {"keys": {"documents": documents, "question": better_question}}
generate
検索文書を基にクエリの回答をLLMで生成するノード処理の定義です。
オリジナルのCookbookではLangChainのLCELで処理を記述していましたが、今回はSGLang単体での推論処理にしています。
def generate(state):
    """
    回答を生成する
    Args:
        state (dict): 現在のグラフ状態
    Returns:
        state (dict): LLMの回答を追加したグラフ状態
    """
    print("---GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    @sgl.function
    def answer_gen(s, context, question):
        template = (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer the question. "
            "If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n"
            f"Question:\n{question} \n\n"
            f"Context:\n{context} \n"
        )
        s += sgl.user(template)
        s += sgl.assistant(sgl.gen("output", max_tokens=256, frequency_penalty=1.1))
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)
    # Run
    generation = answer_gen.run(
        question=question,
        context=format_docs(documents),
        temperature=0,
    ).get_var("output")
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }
prepare_for_final_grade
最終グレードの分岐条件の手前に設定するノードです。
行っていることは、入力されたグラフの状態をほぼそのまま次に流しているだけです。
私の理解だと、現状LangGraphは分岐条件用のエッジを連続で指定できないため、間にこのノードを挟んでいるという認識です。(理解が間違っていたらツッコミください)
def prepare_for_final_grade(state):
    """
    最終グレードの判定に向けて、状態をそのままスルーする
    Args:
        state (dict): 現在のグラフ状態
    Returns:
        state (dict): 現在のグラフ状態
    """
    print("---FINAL GRADE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]
    return {
        "keys": {"documents": documents, "question": question, "generation": generation}
    }
次に、条件分岐を行うエッジ用の処理です。
decide_to_generate
検索文書の関連性評価をした後に、LLMによる回答生成(generate)に遷移してよいかどうかの条件分岐を行うエッジ処理の定義です。
後ほどこの処理を利用し、関連性のあると判定された文書が1件以上あれば生成するノードへ遷移し、そうでなければtransform_queryへ遷移するようなグラフを構築します。
def decide_to_generate(state):
    """
    検索文書のグレード判定から、回答生成に遷移するかの判定
    Args:
        state (dict): 現状のグラフ状態
    Returns:
        str: 決定した次のノード名
    """
    print("---DECIDE TO GENERATE---")
    state_dict = state["keys"]
    question = state_dict["question"]
    filtered_documents = state_dict["documents"]
    if not filtered_documents:
        # All documents have been filtered check_relevance
        # We will re-generate a new query
        print("---DECISION: TRANSFORM QUERY---")
        return "transform_query"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"
grade_generation_v_documents
生成された回答の根拠が検索文書の中に含まれているかどうかを判定するエッジ処理の定義です。
後ほどこの処理を利用し、適切でないと判定されればtransform_queryへ遷移するようなグラフを構築します。
def grade_generation_v_documents(state):
    """
    生成された回答の根拠が検索文書の中に含まれているかどうかを判定する
    Args:
        state (dict): 現在のグラフ状態
    Returns:
        str: 分岐先
    """
    print("---GRADE GENERATION vs DOCUMENTS---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]
    @sgl.function
    def grade_gen(s, documents, generation):
        template = (
            "You are a grader assessing whether an answer is grounded in / supported by a set of facts. \n"
            f"Here are the facts: \n\n{documents} \n\n"
            f"Here is the answer: {generation} \n"
            "Give a binary score 'yes' or 'no' to indicate whether the answer is grounded in / supported by a set of facts."
        )
        s += sgl.user(template)
        s += sgl.assistant(sgl.gen("output", choices=["yes", "no"]))
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)
    s = grade_gen.run(
        generation=generation,
        documents=format_docs(documents),
        temperature=0,
    )
    grade = s.get_var("output")
    if grade == "yes":
        print("---DECISION: SUPPORTED, MOVE TO FINAL GRADE---")
        return "supported"
    else:
        print("---DECISION: NOT SUPPORTED, GENERATE AGAIN---")
        return "not supported"
grade_generation_v_question
生成された回答が質問を解消する内容となっているかどうかを判定するエッジ処理の定義です。
後ほどこの処理を利用し、適切でないと判定されればtransform_queryへ遷移するようなグラフを構築します。
def grade_generation_v_question(state):
    """
    生成された回答が質問を解消する内容となっているかどうかを判定する
    Args:
        state (dict): 現在のグラフ状態
    Returns:
        str: 分岐先
    """
    print("---GRADE GENERATION vs QUESTION---")
    state_dict = state["keys"]
    question = state_dict["question"]
    documents = state_dict["documents"]
    generation = state_dict["generation"]
    # JSONテンプレート
    grade_regex = r"""\{\n""" + r"""    "score": "(yes|no)"\n""" + r"""\}"""
    @sgl.function
    def grade_gen(s, documents, generation):
        template = (
            "You are a grader assessing whether an answer is useful to resolve a question. \n"
            f"Here are the facts: \n\n{documents} \n\n"
            f"Here is the question: {question} \n"
            "Give a binary score 'yes' or 'no' to indicate whether the answer is useful to resolve a question."
        )
        s += sgl.user(template)
        s += sgl.assistant(sgl.gen("output", choices=["yes", "no"]))
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)
    s = grade_gen.run(
        generation=generation,
        documents=format_docs(documents),
        temperature=0,
    )
    grade = s.get_var("output")
    if grade == "yes":
        print("---DECISION: USEFUL---")
        return "useful"
    else:
        print("---DECISION: NOT USEFUL---")
        return "not useful"
以上がノードと条件分岐エッジの定義です。
Step7. グラフの構築
定義したノード用関数等を利用し、グラフを構築します。
こちらはCookbookのままです。フロー図に従ってノードをエッジで接続します。
import pprint
from langgraph.graph import END, StateGraph
workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generatae
workflow.add_node("transform_query", transform_query)  # transform_query
workflow.add_node("prepare_for_final_grade", prepare_for_final_grade)  # passthrough
# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "retrieve")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents,
    {
        "supported": "prepare_for_final_grade",
        "not supported": "generate",
    },
)
workflow.add_conditional_edges(
    "prepare_for_final_grade",
    grade_generation_v_question,
    {
        "useful": END,
        "not useful": "transform_query",
    },
)
# Compile
app = workflow.compile()
Step8. 実行
グラフが構築できたので、推論してみます。
まず、ちゃんとしたクエリを基に実行してみます。
inputs = {"keys": {"question": "葬送のフリーレンの原作者は誰?"}}
output = app.invoke(inputs)
print(output["keys"]["generation"])
---RETRIEVE---
---CHECK RELEVANCE---
---GRADE: DOCUMENT RELEVANT---
---DECIDE TO GENERATE---
---DECISION: GENERATE---
---GENERATE---
---GRADE GENERATION vs DOCUMENTS---
---DECISION: SUPPORTED, MOVE TO FINAL GRADE---
---FINAL GRADE---
---GRADE GENERATION vs QUESTION---
---DECISION: USEFUL---
葬送のフリーレンの原作者は、山田鐘人です。
辿ったノードの過程と最終アウトプットを出力しています。
全てのグレード判定がパスしており、一直線で回答生成に至っていることがわかります。
では、より雑なクエリを投げてみましょう。
inputs = {"keys": {"question": "原作は?"}}
output = app.invoke(inputs)
print(output["keys"]["generation"])
---RETRIEVE---
---CHECK RELEVANCE---
---GRADE: DOCUMENT NOT RELEVANT---
---DECIDE TO GENERATE---
---DECISION: TRANSFORM QUERY---
---TRANSFORM QUERY---
---RETRIEVE---
---CHECK RELEVANCE---
---GRADE: DOCUMENT NOT RELEVANT---
---DECIDE TO GENERATE---
---DECISION: TRANSFORM QUERY---
---TRANSFORM QUERY---
---RETRIEVE---
---CHECK RELEVANCE---
---GRADE: DOCUMENT RELEVANT---
---DECIDE TO GENERATE---
---DECISION: GENERATE---
---GENERATE---
---GRADE GENERATION vs DOCUMENTS---
---DECISION: SUPPORTED, MOVE TO FINAL GRADE---
---FINAL GRADE---
---GRADE GENERATION vs QUESTION---
---DECISION: USEFUL---
原作作者は、山田鐘人です。
最終的に回答を得られていますが、先ほどに比べて、何度かグレードチェックでNGになったのがわかります。
これだけだとどのような動作になったのかわかりづらいので、詳細を確認するためにノード実行後の状態を出力して確認してみます。
inputs = {"keys": {"question": "原作は?"}}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    print()
# Final generation
pprint.pprint(value["keys"]["generation"])
---RETRIEVE---
{'documents': [Document(page_content='== 用語 ==')], 'question': '原作は?'}
---CHECK RELEVANCE---
---GRADE: DOCUMENT NOT RELEVANT---
{'documents': [], 'question': '原作は?'}
---DECIDE TO GENERATE---
---DECISION: TRANSFORM QUERY---
---TRANSFORM QUERY---
{'documents': [], 'question': '何作品の原作は?'}
---RETRIEVE---
{'documents': [Document(page_content='== 用語 ==')], 'question': '何作品の原作は?'}
---CHECK RELEVANCE---
---GRADE: DOCUMENT NOT RELEVANT---
{'documents': [], 'question': '何作品の原作は?'}
---DECIDE TO GENERATE---
---DECISION: TRANSFORM QUERY---
---TRANSFORM QUERY---
{'documents': [], 'question': '何作品の原作作者は?'}
---RETRIEVE---
{ 'documents': [ Document(page_content='『葬送のフリーレン』(そうそうのフリーレン)は、山田鐘人(原作)、アベツカサ(作画)による日本の漫画。『週刊少年サンデー』(小学館)にて、2020年22・23合併号より連載中。\n第14回マンガ大賞、第25回手塚治虫文化賞新生賞、第69回小学館漫画賞受賞作。')],
  'question': '何作品の原作作者は?'}
---CHECK RELEVANCE---
---GRADE: DOCUMENT RELEVANT---
{ 'documents': [ Document(page_content='『葬送のフリーレン』(そうそうのフリーレン)は、山田鐘人(原作)、アベツカサ(作画)による日本の漫画。『週刊少年サンデー』(小学館)にて、2020年22・23合併号より連載中。\n第14回マンガ大賞、第25回手塚治虫文化賞新生賞、第69回小学館漫画賞受賞作。')],
  'question': '何作品の原作作者は?'}
---DECIDE TO GENERATE---
---DECISION: GENERATE---
---GENERATE---
{ 'documents': [ Document(page_content='『葬送のフリーレン』(そうそうのフリーレン)は、山田鐘人(原作)、アベツカサ(作画)による日本の漫画。『週刊少年サンデー』(小学館)にて、2020年22・23合併号より連載中。\n第14回マンガ大賞、第25回手塚治虫文化賞新生賞、第69回小学館漫画賞受賞作。')],
  'generation': '原作作者は、山田鐘人です。',
  'question': '何作品の原作作者は?'}
---GRADE GENERATION vs DOCUMENTS---
---DECISION: SUPPORTED, MOVE TO FINAL GRADE---
---FINAL GRADE---
{ 'documents': [ Document(page_content='『葬送のフリーレン』(そうそうのフリーレン)は、山田鐘人(原作)、アベツカサ(作画)による日本の漫画。『週刊少年サンデー』(小学館)にて、2020年22・23合併号より連載中。\n第14回マンガ大賞、第25回手塚治虫文化賞新生賞、第69回小学館漫画賞受賞作。')],
  'generation': '原作作者は、山田鐘人です。',
  'question': '何作品の原作作者は?'}
---GRADE GENERATION vs QUESTION---
---DECISION: USEFUL---
{ 'documents': [ Document(page_content='『葬送のフリーレン』(そうそうのフリーレン)は、山田鐘人(原作)、アベツカサ(作画)による日本の漫画。『週刊少年サンデー』(小学館)にて、2020年22・23合併号より連載中。\n第14回マンガ大賞、第25回手塚治虫文化賞新生賞、第69回小学館漫画賞受賞作。')],
  'generation': '原作作者は、山田鐘人です。',
  'question': '何作品の原作作者は?'}
'原作作者は、山田鐘人です。'
雑なクエリだったため、最初のRetrieveの時点では関連した文書が取得できていないのがわかります。
その後、グレードチェックが入り、判定NGだったためtransform_queryのノードに遷移しています。
ここで、より検索性能を上げるために、 「原作は?」 というクエリから 「何作品の原作は?」 というクエリに変換が行われています。
その後、再度Retrieveからやり直されており、残念ながらここでもNGになり、さらにクエリが 「何作品の原作作者は?」 に変換されています。ようやくこれで関連文書が取得されるようになりました。
結果として生成結果含めて以後のグレードチェックも通り、正しい回答を得られています。
このように、いくつかの自己品質チェックを行いながら必要に応じてクエリをアップデートし、最終的に目的にそった回答を得ることができるようになります。
ちなみに、CRAGと異なりWeb検索との併用を今回はしていないため、文書に存在しない内容を聞くと反復回数の上限に達して例外が発生します。実務的には一定のループ回数を越えたら回答不能な旨を返すような実装にするべきでしょう。
まとめ
Self-RAGをLangChainのCookbookを基に実践してみました。
RAGで正しい結果を得るためには、そもそも入力となるクエリの品質が高いことを前提にしがちです。ただ、ユーザによってクエリの品質はばらつくため、このように品質チェックとなるゲートを設けながら自己で精度をあげるような仕組は重要だと思います。
しかし、LangGraph面白いなあ。シンプルな処理はLCELで書く方が簡単ですが、拡張性等を考えるとLangGraphで処理作る方がよさげに感じています。


