10
8

はじめに

最近、Retrieval Augmented Generation (RAG)の分野で興味があるGraph RAGについて少し試してみました。日本語のWikipedia記事を題材に、Neo4jとLangChainを使ってGraph RAGを実装してみます。

この記事は、LangChainのブログ記事「Enhancing RAG-based application accuracy by constructing and leveraging knowledge graphs」や、下記のYoutube動画を参考にさせて頂きました。

初学者につき、内容に誤りを含む場合があります。

RAGとは

RAG(Retrieval-Augmented Generation)は、大規模言語モデル(LLM)の能力を外部知識で拡張する手法です。従来のLLMは学習済みの汎用的な知識のみに基づいて回答を生成しますが、RAGでは質問に関連する情報を外部のデータソースから取得し、それを基に回答を生成します。これにより、以下のような利点があります:

  1. 最新の情報を含めた回答が可能
  2. ソースを明示できるため、回答の信頼性が向上
  3. 特定のドメインや組織固有の知識を活用可能(LLMが学習していない社内データなどから回答を得られる)

RAGの一般的なプロセスは次のとおりです:

  1. 質問を受け取る
  2. 関連する情報を外部データソースから検索
  3. 検索結果とオリジナルの質問をLLMに入力
  4. LLMが検索結果を参考に回答を生成

Graph RAGとは

Graph RAGは、従来のRAGをさらに発展させた手法です。通常のRAGが主にテキストベースの検索を行うのに対し、Graph RAGはグラフデータベースを活用して情報を検索します。

グラフデータベースでは、データをノード(実体)とエッジ(関係)として表現します。この構造により、エンティティ間の複雑な関係や階層構造を効果的に表現・検索することができます。

0_rm_fRSPovV1wfTqH.jpg

引用:LangChain

Graph RAGの主な利点は以下の通りです:

  1. 複雑な関係性の把握:エンティティ間の関係を直接的に表現・検索できるため、複雑な質問により適切に対応できます。

  2. コンテキストの理解向上:関連する情報をグラフ構造で表現することで、LLMがより広い文脈を理解しやすくなります。

  3. 推論能力の強化:グラフ構造を利用することで、直接的な関係だけでなく、間接的な関係も考慮した回答が可能になります。

  4. 効率的な情報検索:グラフデータベースの特性を活かし、関連情報をより効率的に検索できます。

Graph RAGのプロセスは以下のようになります:

  1. 質問を受け取る
  2. 質問から関連するエンティティを抽出
  3. グラフデータベースで関連エンティティとその関係を検索
  4. 検索結果と質問をLLMに入力
  5. LLMがグラフ構造を考慮しつつ回答を生成

本記事では、この Graph RAG を日本語のWikipedia記事に適用し、Neo4jグラフデータベースとLangChainを使用して実装していきます。

環境設定

今回のトライアルはGoogle Colabで行いました。フリーで構いません。

LLMはOpenAI APIを用いています。あらかじめAPIキーを取得します。こちらは従量課金(有料)です。

また、グラフデータベースとしてNeo4jを使用します。Neo4jは、データを節(ノード)と関係(エッジ)として保存する強力なグラフデータベースです。今回はNeo4j AuraのFreeプランを利用しました。インスタンスを立てると次に使う環境変数3つを取得できます。

環境変数の設定は以下のように行います:

OpenAIのAPIキーとNeo4j各種パラメータを入手したら、GoogleColaboratory左ペインのシークレットから下記の各パラメータをシークレットに登録しておきます

  • OPENAI_API_KEY
  • NEO4J_URI
  • NEO4J_USERNAME
  • NEO4J_PASSWORD
from google.colab import userdata

os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')
os.environ["NEO4J_URI"] = userdata.get('NEO4J_URI')
os.environ["NEO4J_USERNAME"] = userdata.get('NEO4J_USERNAME')
os.environ["NEO4J_PASSWORD"] = userdata.get('NEO4J_PASSWORD')

必要なライブラリのインストールとインポート

%%capture
%pip install --upgrade --quiet  wikipedia neo4j langchain langchain-community langchain-openai langchain-experimental tiktoken yfiles_jupyter_graphs
#MeCabのインストール
!apt-get -q -y install sudo file mecab libmecab-dev mecab-ipadic-utf8 git curl python-mecab > /dev/null
!pip install mecab-python3 unidic-lite > /dev/null

まずは必要なライブラリをインポートします:

import os
import re
import wikipedia
import MeCab
from typing import List, Tuple
from google.colab import userdata
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.output_parsers import StrOutputParser
from langchain.output_parsers import PydanticOutputParser
from langchain_community.graphs import Neo4jGraph
from langchain.text_splitter import TokenTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_community.vectorstores import Neo4jVector
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain.schema import Document

try:
    import google.colab
    from google.colab import output
    output.enable_custom_widget_manager()
except:
    pass

テキストの形態素解析

MeCabを使用して日本語テキストを形態素解析します。特段やらなくてもLLMがよしなにやってくれるような気もしましたが...元の英語記事がクエリをスペース区切りにしていたのでそれに対応すべく一応。

#MeCabを使用してテキストを形態素解析し、固有表現を抽出する関数を定義
def tokenize_and_extract_entities(text: str) -> Tuple[List[str], List[str]]:
    """
    テキストを形態素解析し、トークンと固有表現を抽出する
    
    :param text: 解析対象のテキスト
    :return: トークンのリストと固有表現のリスト
    """
    tagger = MeCab.Tagger()
    parsed = tagger.parse(text)
    
    tokens = []
    entities = []
    
    for line in parsed.split('\n'):
        if line == 'EOS':
            break
        
        parts = line.split('\t')
        if len(parts) < 2:
            continue  # 無効な行はスキップ
        
        surface = parts[0]
        feature = parts[1]
        
        features = feature.split(',')
        
        tokens.append(surface)
        
        # 固有名詞の抽出
        if len(features) > 1 and features[0] == '名詞' and features[1] in ['固有名詞', '人名', '組織', '地名']:
            entities.append(surface)
    
    return tokens, entities

# テキスト分割関数の定義
def split_text(text: str, max_length: int = 1000) -> List[str]:
    """
    テキストを指定された最大長で分割する
    
    :param text: 分割対象のテキスト
    :param max_length: 分割後の最大文字数
    :return: 分割されたテキストのリスト
    """
    return [text[i:i+max_length] for i in range(0, len(text), max_length)]

Wikipediaデータの取得と前処理

今回は「この素晴らしい世界に祝福を!」のWikipedia記事を使用します。

article = "この素晴らしい世界に祝福を"
# Wikipediaの日本語記事を用いる
wikipedia.set_lang("ja")
page = wikipedia.page(article, auto_suggest=False)
content = re.sub('(.+?)', '', page.content)  # ふりがなを除去

# テキストを分割
chunks = split_text(content)

all_tokens = []
all_entities = []

for chunk in chunks:
    try:
        tokens, entities = tokenize_and_extract_entities(chunk)
        all_tokens.extend(tokens)
        all_entities.extend(entities)
    except Exception as e:
        print(f"チャンク処理中にエラーが発生しました: {e}")
        continue

# MeCabで処理したトークンとエンティティを使用して、より適切なDocumentオブジェクトを作成
processed_content = " ".join(all_tokens)  # トークンを空白で結合

# エンティティ情報をメタデータとして追加
metadata = {"entities": list(set(all_entities))}  # 重複を除去

# Documentオブジェクトの作成
raw_documents = [Document(page_content=processed_content, metadata=metadata)]

# チャンクサイズの調整(必要に応じて)
text_splitter = TokenTextSplitter(chunk_size=1024, chunk_overlap=64)

# ドキュメントを分割
documents = text_splitter.split_documents(raw_documents)

おまけ)形態素解析を用いない場合contentのままDocumentオブジェクトに変換

# raw_documents = [Document(page_content=content)]
# text_splitter = TokenTextSplitter(chunk_size=1024, chunk_overlap=64) #チャンクサイズと前後の重なり部分を指定 
# documents = text_splitter.split_documents(raw_documents) #ドキュメントを分割

グラフの生成と保存

ここの処理は数分程度必要でした。LangChainのLLMGraphTransformerを使用してグラフを生成し、Neo4jに保存します。
なお、LLMのモデルは安いもの(GPT-3.5)を使用しています。

llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0125") #モデルはお好みで指定
llm_transformer = LLMGraphTransformer(llm=llm) #LLMTransformerオブジェクトを作成

#ドキュメントをグラフに変換。documents全部入れると時間もコストもかかるので、documents[0:10]など小規模で試してみると良い
graph_documents = llm_transformer.convert_to_graph_documents(documents[0:10])

graph = Neo4jGraph() #Neo4jGraphオブジェクトを作成

#Neo4jGraphオブジェクトにドキュメントを追加
graph.add_graph_documents(
    graph_documents, 
    baseEntityLabel=True,
    include_source=True
)

グラフを可視化してみます

# 指定されたCypherクエリから得られるグラフを直接表示する関数 ノードの数はここで調整します。増やすと重いので注意
default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 100"

def showGraph(cypher: str = default_cypher):
    # Neo4jセッションを作成してクエリを実行するための準備
    driver = GraphDatabase.driver(
        uri = os.environ["NEO4J_URI"],
        auth = (os.environ["NEO4J_USERNAME"],
                os.environ["NEO4J_PASSWORD"]))
    session = driver.session()

    # GraphWidgetを使用してグラフを可視化
    widget = GraphWidget(graph = session.run(cypher).graph())

    # ノードのラベルとしてidプロパティを使用
    widget.node_label_mapping = 'id'

    # ウィジェットを返す(他の場所で表示や操作が可能)
    return widget

#可視化の実行
showGraph()

スクリーンショット 2024-06-22 20.16.30.png

英語で書いているエッジ(ノード間の関係)はLLMで自動で割り振られたものです。これを記事から手作業で作ろうとしたらどれだけ大変でしょうか・・・この時点で試してみてよかったと思うくらいです。
さて、細かく確認してみましょう。アクアのVOICE_ACTOR(声優)が雨宮天さんであることやめぐみんが爆裂魔法がFAVORITEであること、ダクネスは本名がララティーナとNAMED_ASされていることなど、概ね正しそうです。

ハイブリッド検索の設定

ベクトル検索とキーワード検索を組み合わせたハイブリッド検索を設定します

vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(),  # OpenAIの埋め込みモデルを使用してテキストをベクトル化
    search_type="hybrid",  # ベクトル検索とキーワード検索を組み合わせたハイブリッド検索を使用
    node_label="Document",  # 'Document'ラベルを持つノードを対象とする
    text_node_properties=["text"],  # 'text'プロパティの内容をベクトル化の対象とする
    embedding_node_property="embedding"  # ベクトル(埋め込み)を'embedding'プロパティに格納
)
graph.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")

エンティティ抽出とグラフ検索

ユーザーの質問からエンティティを抽出し、グラフデータベースから関連情報を取得します

class Entities(BaseModel):
    names: List[str] = Field(..., description="テキスト内に出現する全ての人物、組織エンティティ")

prompt = ChatPromptTemplate.from_messages([
    ("system", "You are extracting organization and person entities from the text. Output should be in JSON format with a 'names' key containing a list of extracted entities."),
    ("human", "Extract entities from the following input: {question}")
])

parser = PydanticOutputParser(pydantic_object=Entities)
entity_chain = prompt | llm | parser

def generate_full_text_query(input: str) -> str:
    tagger = MeCab.Tagger()
    nodes = tagger.parseToNode(input)
    
    important_words = []
    while nodes:
        if nodes.feature.split(',')[0] in ['名詞', '動詞', '形容詞']:
            important_words.append(nodes.surface)
        nodes = nodes.next
    
    if not important_words:
        return input
    
    return ' OR '.join(f'"{word}"' for word in important_words)

def structured_retriever(question: str) -> str:
    try:
        entities = entity_chain.invoke({"question": question})
        if not entities.names:
            return "質問に関連するエンティティが見つかりませんでした。"
        
        result = ""
        for entity in entities.names:
            query = generate_full_text_query(entity)
            if query:
                try:
                    response = graph.query(
                        """CALL db.index.fulltext.queryNodes('entity', $query, {limit:20})
                        YIELD node,score
                        CALL {
                          WITH node
                          MATCH (node)-[r:!MENTIONS]->(neighbor)
                          RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
                          UNION ALL
                          WITH node
                          MATCH (node)<-[r:!MENTIONS]-(neighbor)
                          RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS output
                        }
                        RETURN output LIMIT 1000
                        """,
                        {"query": query},
                    )
                    result += "\n".join([el['output'] for el in response])
                except Exception as e:
                    print(f"クエリ実行中にエラーが発生しました: {e}")
        
        return result if result else "関連情報が見つかりませんでした。"
    except Exception as e:
        print(f"エンティティ抽出中にエラーが発生しました: {e}")
        return "エンティティの抽出中にエラーが発生しました。"

最終的なRAGチェーンの構築

構造化データと非構造化データを組み合わせて最終的なRAGチェーンを構築します。

工夫ポイント
ナレッジグラフで対応できない質問を受けるとエラーになるので、関連情報がないことを述べた上で一般論で回答するようにしています。「わかりません」と回答するのでもよいでしょう。

def retriever(question: str):
    print(f"Search query: {question}")
    structured_data = structured_retriever(question)
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    final_data = f"""Structured data:
    {structured_data}
    Unstructured data:
    {"#Document ". join(unstructured_data)}
    """
    return final_data

_search_query = RunnableLambda(lambda x: x["question"])

template = """あなたは優秀なAIです。下記のコンテキストを利用してユーザーの質問に丁寧に答えてください。
必ず文脈からわかる情報のみを使用して回答を生成してください。
コンテキストに関連情報がない場合は、その旨を述べた上で一般的な回答を提供してください。

コンテキスト:
{context}

ユーザーの質問: {question}"""

prompt = ChatPromptTemplate.from_template(template)

chain = (
    RunnableParallel(
        {
            "context": _search_query | retriever,
            "question": RunnablePassthrough(),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)

動作確認

最後に、構築したRAGチェーンを使って質問に答えてみます

chain.invoke({"question": "めぐみんの好きなことは?"})
# めぐみんの好きなことは爆裂魔法です。
chain.invoke({"question": "カズマは何故異世界にきたのですか?"})
# カズマは、アクアを腹いせに異世界に持っていくモノとして指定し、転移に巻き込んだことから異世界にやってきました。その後、異世界「アクセルの街」に降り立ち、運勢だけが異常にいい平凡な冒険者となりました。

シンプルな質問には、なかなか正確に回答できているようです。

chain.invoke({"question": "アクアの友達の友達は誰?"})
# アクアの友達の友達は、アクアの友達であるカズマの友達であり、ダストという人物です。
chain.invoke({"question": "アクアとゆんゆんの関係は?"})
# アクアとゆんゆんの関係は、ゆんゆんがめぐみんと親しい友人であり、めぐみんとアクアも友人関係にあることから、間接的につながっています。ゆんゆんはめぐみんと親しい友人であるため、アクアとも一定の関係性が存在していると言えます。

友達の友達、というクエリはグラフデータベースの強みが生きるシーンだと思います。

chain.invoke({"question": "アクアについて教えてください"})
# アクアは、異世界に転生した日本人の高校生であり、水を司る女神です。彼女はアクセルの街でアクア、めぐみん、ダクネス、そして冒険者のカズマというパーティーを結成し、様々な事件に巻き込まれています。戦闘では盗賊やアーチャー、リッチーのスキルを駆使し、他のメンバーに的確な指示やサポートを行います。また、アクアは信仰心が深く、エリス教の信者であり、毎日エリス教会を訪れて祈りを捧げています。一部からは嫌われることもありますが、仲間からは概ね信頼されています。

アクアについての回答にはいくつか大きな誤りが見られます。高校生ではないですし、盗賊やアーチャー、リッチーのスキルを駆使するのはカズマです。なによりアクシズ教の女神なのでエリス教の信者では絶対にありません...
オープンな質問だとこうなってしまうのか、モデルをGPT-4oなどにすればうまくいくのか、検証しがいはありそうです。ただ、誤った情報もすべて「このすば」の範囲内の情報であることは重要なポイントです。例えば次のような質問をすると...

chain.invoke({"question": "日本で一番高い山は?"})
#申し訳ありませんが、提供された情報からは日本で一番高い山に関する情報は得られませんでした。一般的に知られている情報として、日本で一番高い山は富士山です。

ナレッジグラフで検索できない場合はそれを明言します。生成AIのハルシネーション(でたらめ)をある程度抑え込むことができるのは実務適用する際に重要なポイントだと思います。

まとめ

この記事では、日本語のWikipedia記事を使ってGraph RAGを実装する方法を紹介しました。
Graph RAGは従来のRAGと比べて、エンティティ間の関係性をより明確に捉えることができるため、複雑な質問に対してもより正確な回答を生成できる可能性があります。
今回の実装では完璧と呼べる結果にはなりませんでしたが、今後さらに改良を重ねることで、便利な日本語テキスト解析や質問応答システムなどの構築が可能になるかもしれません。

この記事を書いた人

10
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
8