導入
Retrieval Augumented Generation(RAG)において、最近はRAG Fusionなどの技法を使うようなAdvanced RAGと、ナレッジグラフを使って検索を行うGraphRAGに人気があるようです。
特に後者のGraphRAGにおいてはMicrosoft社がOSS実装を公開したりと少し前に大きく話題となっていました。
ちなみにGraphRAGについては以下の記事がわかりやすいのではないかと思います。
このGraphRAGを実装するための軽量なパッケージとしてnano-graphragというものがあることを少し前に知りました。
これを使うとDatabricks Mosaic AI Model Servingで公開しているLLM Endpointの利用もシンプルに実装できそうでしたので、単純なGraphRAG処理を試験実装してみます。
試験はDatabricks on AWS上で行いました。ノートブックのクラスタはサーバレスです。
Step1. パッケージインストール&初期設定
ノートブックを作成し、nano-graphragのインストールとmlflowの最新化を行います。
%pip install nano-graphrag
%pip install -q -U "mlflow-skinny[databricks]>=2.17.1"
dbutils.library.restartPython()
加えて、nest_asyncio
を使ってイベントループのネストを可能にしておきます。
これをしておかないと、ノートブック内でnano-graphragを使う場合、イベントループ関連でエラーが起きます。
import nest_asyncio
nest_asyncio.apply()
Step2. データの準備
GraphRAGで利用する検索用ドキュメントを準備します。
今回はwikipediaより以下の内容をお借りしました。
import requests
def get_ja_wikipedia_page(title: str):
URL = "https://ja.wikipedia.org/w/api.php"
params = {
"action": "query",
"format": "json",
"titles": title,
"prop": "extracts",
"explaintext": True,
}
headers = {"User-Agent": "sample_graphRag"}
response = requests.get(URL, params=params, headers=headers, verify=False)
data = response.json()
page = next(iter(data["query"]["pages"].values()))
return page["extract"] if "extract" in page else None
# Wikipediaからデータを取得して、テキストファイルに保管
full_document = get_ja_wikipedia_page("ジャパン・プロフェッショナル・バスケットボールリーグ")
with open("./book.txt", 'w', encoding='utf-8') as f:
f.write(full_document)
Step3. GraphRAG用の処理関数定義
nano-graphrag中に使用するEmbedding処理とLLMによる処理を定義します。
OpenAI APIを使う場合においては必須ではないのですが、今回はDatabricks Mosaic AI Model Servingのエンドポイントを利用してみたかったので、それらを使うための関数を定義しています。
なお、エンドポイントはDatabricksのsystem.ai
スキーマ内にあるモデルから、Embedding用のモデルとしてbge_m3
、LLMとしてmeta_llama_v3_1_70b_instruct
をエンドポイントにデプロイしたものを利用しました。
import numpy as np
import os
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag.base import BaseKVStorage
from nano_graphrag._utils import compute_args_hash, wrap_embedding_func_with_attrs
import mlflow.deployments
@wrap_embedding_func_with_attrs(
embedding_dim=1024,
max_token_size=8192,
)
async def bge_m3_embedding(texts: list[str]) -> np.ndarray:
inputs = {"inputs": texts}
client = mlflow.deployments.get_deploy_client("databricks")
response = client.predict(
endpoint="embedding_bge_m3_endpoint",
inputs=inputs,
)
return [e["embedding"] for e in response["predictions"]["data"]]
async def databricks_model(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# 注:簡易化のためにkwargsのパラメータを利用していませんが、実際には利用するようにしたほうがよい
inputs = {
"messages": messages,
"max_tokens": 4000,
}
client = mlflow.deployments.get_deploy_client("databricks")
response = client.predict(
endpoint="llama3_1_70b_instruct_endpoint",
inputs=inputs,
)
return response["choices"][0]["message"]["content"]
# 不要なファイルを削除するためのユーティリティ関数
def remove_if_exist(file):
if os.path.exists(file):
os.remove(file)
Step4. ナレッジグラフの作成
では、準備したデータや処理定義を利用してナレッジグラフを作成します。
WORKING_DIR = "./sample_graphrag"
def insert():
from time import time
# 保存しているテキストデータを読み込み
with open("./book.txt") as f:
FAKE_TEXT = f.read()
# 既に存在しているファイルを削除
remove_if_exist(f"{WORKING_DIR}/vdb_entities.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_full_docs.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_text_chunks.json")
remove_if_exist(f"{WORKING_DIR}/kv_store_community_reports.json")
remove_if_exist(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml")
rag = GraphRAG(
working_dir=WORKING_DIR,
# enable_llm_cache=True,
best_model_func=databricks_model,
cheap_model_func=databricks_model,
embedding_func=bge_m3_embedding,
)
start = time()
rag.insert([FAKE_TEXT])
print("indexing time:", time() - start)
insert()
処理に数分程度かかりますが、問題なければナレッジグラフ関連のファイルがWORKIND_DIR
で指定した場所に作成されます。
Step5. GraphRAGで問い合わせ・回答生成
では、作成したナレッジグラフを使って問い合わせをしてみます。
まず、問い合わせ用の単純なラッパー関数を定義。
rag.query
のパラメータとして、mode
を指定していますが、これを変更することで通常のRAG処理(naive)だったり、GraphRAGにおけるlocal/globalの切り替えができるようです。
今回はglobalでのGraphRAG検索・回答生成を行います。
def query(query:str, mode:str="global"):
rag = GraphRAG(
working_dir=WORKING_DIR,
best_model_func=databricks_model,
cheap_model_func=databricks_model,
embedding_func=bge_m3_embedding,
)
print(
rag.query(query, param=QueryParam(mode=mode)
)
)
というわけで問い合わせをしてみます。
まずは単純な質問をGraphRAGのGlobal Searchで。
query("滋賀にあるバスケットチームの名前は?日本語で答えてください")
INFO:nano-graphrag:Load KV full_docs with 1 data
INFO:nano-graphrag:Load KV text_chunks with 10 data
INFO:nano-graphrag:Load KV llm_response_cache with 0 data
INFO:nano-graphrag:Load KV community_reports with 12 data
INFO:nano-graphrag:Loaded graph from ./sample_graphrag/graph_chunk_entity_relation.graphml with 204 nodes, 114 edges
INFO:nano-vectordb:Load (182, 1024) data
INFO:nano-vectordb:Init {'embedding_dim': 1024, 'metric': 'cosine', 'storage_file': './sample_graphrag/vdb_entities.json'} 182 data
INFO:nano-graphrag:Revtrieved 12 communities
INFO:nano-graphrag:Grouping to 1 groups for global search
------------------------------------------------------------------------------------------
滋賀レイクスターズです。
最後の行が出力結果ですが、取得したWikipediaの内容から回答されていますね。
*正確には滋賀レイクスですが、Wikipediaには滋賀レイクスターズの方が先に記載されているのでしょうがないかな。
Local Searchでもやってみます。
query("滋賀にあるバスケットチームの名前は?日本語で答えてください", mode="local")
------------------------------------------------------------------------------------------
INFO:nano-graphrag:Using 20 entites, 4 communities, 3 relations, 3 text units
滋賀にあるバスケットボールチームの名前は、滋賀レイクスです。
こっちはレイクスで回答してくれました。詳細な情報を引き出したい場合はLocal Searchの方がやはりよさそう。
では、もう少し全体的なことを聞いてみます。
まずはGlobal Search。
query("Bリーグの歴史を簡単に日本語で説明してください")
------------------------------------------------------------------------------------------
**Bリーグの歴史**
=====================
Bリーグは、日本のプロバスケットボールリーグで、日本バスケットボール協会(JBA)が統轄しています。その歴史は2016年に発足して以来、段階的に成長してきました。
**リーグの構成**
----------------
リーグはB1、B2、B3の3つのディビジョンに分かれており、各ディビジョンには複数のチームが参加しています。 JPBL(ジャパン・プロフェッショナル・バスケットボール・リーグ)は、Bリーグの運営管理団体であり、リーグのルールや規則を決めています。
**競争力の向上**
-----------------
リーグは、参加チームの競争力を高めるため、東アジアスーパーリーグやNBAとのパートナーシップを結んでいます。これにより、チームは、高レベルのトレーニングと試合を通じてスキルを磨くことができます。
**加盟クラブの増加**
----------------------
リーグの加盟クラブ数は年々増えており、現在は50以上のクラブが加盟しています。このodynは、Bリーグの成長と人気の高まりを示しています。
**試合の公開**
-----------------
Bリーグによる試合は、日本の主要新聞やインターネットで掲載されているほか、インターネット配信サービスでも視聴可能です。これにより、ファンは、各地で行われる試合を followすることができます。
次に同じ問い合わせをLocal Searchで。
query("Bリーグの歴史を簡単に日本語で説明してください", mode="local")
------------------------------------------------------------------------------------------
**Bリーグの歴史**
Bリーグは日本のプロバスケットボールリーグで、2016年に設立されました。リーグは日本バスケットボール協会(JBA)によって運営されており、現在はB1リーグとB2リーグの2つのディビジョンで構成されています。
**設立の背景**
Bリーグは前身の日本プロバスケットボールリーグ(bjリーグ)と日本バスケットボールリーグ(JBL)が統合して設立されました。bjリーグはスポーツ庁から公益社団法人認定を受けており、JBLとの統合により、プロバスケットボールリーグの再編が行われました。
**歴史的経過**
* 2016年 - Bリーグ設立
* 2017年 - B1リーグとB2リーグが誕生
* 2018年 - 東アジアスーパーリーグとパートナーシップ契約を締結
* 2020年 - ライセンス制度を新設
* 2022年 - オンザコートルールを導入
* 2023年 - 将来構想を「B.革新」に変更し、2026年から新ロゴを使用
* 2024年 - タイトルパートナーがりそなグループに決定
**現在の状況**
Bリーグは現在、B1リーグとB2リーグの2つのディビジョンで構成されており、各ディビジョンに複数のチームが参加しています。また、東アジアスーパーリーグとのパートナーシップ契約を締結しているほか、ライセンス制度を新設しています。
このように、Bリーグは日本のプロバスケットボールリーグとして、不断の進化と改善を重ねてきました。
こっちの方が具体の情報が多いですね。
temperature
を指定してないので実行のたびに回答傾向は結構変わったりするのですが、Global Searchは全体傾向を、Local Searchはより詳細な情報を取得できる傾向にあります。想定通り。
Step6. グラフを可視化
こちらのサンプルコードを参考に、構築したナレッジグラフを可視化します。
import networkx as nx
import json
import os
# load GraphML file and transfer to JSON
def graphml_to_json(graphml_file):
G = nx.read_graphml(graphml_file)
data = nx.node_link_data(G)
return json.dumps(data)
# create HTML file
def save_as_html(html_path, graph_json):
html_content = '''
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Graph Visualization</title>
<script src="https://d3js.org/d3.v7.min.js"></script>
<style>
body, html {
margin: 0;
padding: 0;
width: 100%;
height: 100%;
overflow: hidden;
}
svg {
width: 100%;
height: 100%;
}
.links line {
stroke: #999;
stroke-opacity: 0.6;
}
.nodes circle {
stroke: #fff;
stroke-width: 1.5px;
}
.node-label {
font-size: 12px;
pointer-events: none;
}
.link-label {
font-size: 10px;
fill: #666;
pointer-events: none;
opacity: 0;
transition: opacity 0.3s;
}
.link:hover .link-label {
opacity: 1;
}
.tooltip {
position: absolute;
text-align: left;
padding: 10px;
font: 12px sans-serif;
background: lightsteelblue;
border: 0px;
border-radius: 8px;
pointer-events: none;
opacity: 0;
transition: opacity 0.3s;
max-width: 300px;
}
.legend {
position: absolute;
top: 10px;
right: 10px;
background-color: rgba(255, 255, 255, 0.8);
padding: 10px;
border-radius: 5px;
}
.legend-item {
margin: 5px 0;
}
.legend-color {
display: inline-block;
width: 20px;
height: 20px;
margin-right: 5px;
vertical-align: middle;
}
</style>
</head>
<body>
<svg></svg>
<div class="tooltip"></div>
<div class="legend"></div>
<script>
[GRPH DATA];
const graphData = graphJson;
const svg = d3.select("svg"),
width = window.innerWidth,
height = window.innerHeight;
svg.attr("viewBox", [0, 0, width, height]);
const g = svg.append("g");
const entityTypes = [...new Set(graphData.nodes.map(d => d.entity_type))];
const color = d3.scaleOrdinal(d3.schemeCategory10).domain(entityTypes);
const simulation = d3.forceSimulation(graphData.nodes)
.force("link", d3.forceLink(graphData.links).id(d => d.id).distance(150))
.force("charge", d3.forceManyBody().strength(-300))
.force("center", d3.forceCenter(width / 2, height / 2))
.force("collide", d3.forceCollide().radius(30));
const linkGroup = g.append("g")
.attr("class", "links")
.selectAll("g")
.data(graphData.links)
.enter().append("g")
.attr("class", "link");
const link = linkGroup.append("line")
.attr("stroke-width", d => Math.sqrt(d.value));
const linkLabel = linkGroup.append("text")
.attr("class", "link-label")
.text(d => d.description || "");
const node = g.append("g")
.attr("class", "nodes")
.selectAll("circle")
.data(graphData.nodes)
.enter().append("circle")
.attr("r", 5)
.attr("fill", d => color(d.entity_type))
.call(d3.drag()
.on("start", dragstarted)
.on("drag", dragged)
.on("end", dragended));
const nodeLabel = g.append("g")
.attr("class", "node-labels")
.selectAll("text")
.data(graphData.nodes)
.enter().append("text")
.attr("class", "node-label")
.text(d => d.id);
const tooltip = d3.select(".tooltip");
node.on("mouseover", function(event, d) {
tooltip.transition()
.duration(200)
.style("opacity", .9);
tooltip.html(`<strong>${d.id}</strong><br>Entity Type: ${d.entity_type}<br>Description: ${d.description || "N/A"}`)
.style("left", (event.pageX + 10) + "px")
.style("top", (event.pageY - 28) + "px");
})
.on("mouseout", function(d) {
tooltip.transition()
.duration(500)
.style("opacity", 0);
});
const legend = d3.select(".legend");
entityTypes.forEach(type => {
legend.append("div")
.attr("class", "legend-item")
.html(`<span class="legend-color" style="background-color: ${color(type)}"></span>${type}`);
});
simulation
.nodes(graphData.nodes)
.on("tick", ticked);
simulation.force("link")
.links(graphData.links);
function ticked() {
link
.attr("x1", d => d.source.x)
.attr("y1", d => d.source.y)
.attr("x2", d => d.target.x)
.attr("y2", d => d.target.y);
linkLabel
.attr("x", d => (d.source.x + d.target.x) / 2)
.attr("y", d => (d.source.y + d.target.y) / 2)
.attr("text-anchor", "middle")
.attr("dominant-baseline", "middle");
node
.attr("cx", d => d.x)
.attr("cy", d => d.y);
nodeLabel
.attr("x", d => d.x + 8)
.attr("y", d => d.y + 3);
}
function dragstarted(event) {
if (!event.active) simulation.alphaTarget(0.3).restart();
event.subject.fx = event.subject.x;
event.subject.fy = event.subject.y;
}
function dragged(event) {
event.subject.fx = event.x;
event.subject.fy = event.y;
}
function dragended(event) {
if (!event.active) simulation.alphaTarget(0);
event.subject.fx = null;
event.subject.fy = null;
}
const zoom = d3.zoom()
.scaleExtent([0.1, 10])
.on("zoom", zoomed);
svg.call(zoom);
function zoomed(event) {
g.attr("transform", event.transform);
}
</script>
</body>
</html>
'''
html_content = html_content.replace("[GRPH DATA]", graph_json)
with open(html_path, 'w', encoding='utf-8') as f:
f.write(html_content)
def create_json(json_data, json_path):
json_data = "var graphJson = " + json_data.replace('\\"', '').replace("'", "\\'").replace("\n", "")
return json_data
# main function
def visualize_graphml(graphml_file, html_path, port=8000):
json_data = graphml_to_json(graphml_file)
html_dir = os.path.dirname(html_path)
if not os.path.exists(html_dir):
os.makedirs(html_dir)
json_path = os.path.join(html_dir, 'graph_json.js')
save_as_html(html_path, create_json(json_data, json_path))
visualize_graphml(f"{WORKING_DIR}/graph_chunk_entity_relation.graphml", "./graph.html")
実行するとgraph.html
が作成されるので、それをエクスポートしてブラウザで表示すると以下のようにグラフが表示されます。
見た感じ、正しいような正しくないような。。。
Llama 3.1はそこまで日本語性能が高くないため、日本語性能が高い他のモデルを使う方がよりきちんとしたグラフが作れるんじゃないかと思います。
まとめ
nano-graphragを使ったGraphRAGを実装してみました。
チャンキング関連の設定など詰めるところはまだまだ多そうですが、割とシンプルにGraphRAGが構築できました。
個人的にGraphRAGにおける最大の勘所はナレッジグラフを作るところで、適したLLMの選定やチューニングが大きな課題かなと思っています。
なるべく性能のよいモデル(プロプライエタリなAPI)を利用する方がいいのですが、結構トークン数を使うのでコストは少し気にする必要があります。
*Llama 3.2 3Bモデルを使ったナレッジグラフ作成も試してみましたが、性能不足なのか途中でエラーになりました。
いろいろ難しいところはありますが、うまくGraphRAGが適したユースケースを選定してうまく使いこなせるとハマる仕組みだなと感じています。
今回はnano-graphragを使いましたが、nano-graphragの派生であるLightRAGも気になっています。
こちらの方が精度良さそうという話もあるので、おいおい触ってみようと思います。