0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

OSS(Qdrant)にPull Requestしてみた

Posted at

はじめに

実務でLLMアプリケーションを開発しています。RAG※の手法の一つで、GraphRAGを導入する時にベクトルDBのQdrant(OSS)とグラフDBのNep4jをこちらの記事を参考に連携させました。

※そもそもRAGとは?

その際、Qdrantの実装の一部に懸念点を見つけ、初めてOSSにPull Requestを出したので記録として残します。

Pull Requestの内容

本件のアーキテクチャは以下となります。

(1). チャンクに含まれるノードとエッジをLLMにより抽出する

e.g.
Bob and Carol are friends.
ノード:Bob, Carol
エッジ:friends

(2). 抽出したノードとエッジをNeo4jに保存する
(3). 抽出したノードに対応するノードIDと各チャンクのベクトルデータをQdrantに保存する
(4). Qdrantでベクトル検索した結果のノードIDを用いてNeo4Jでグラフ検索をする

(1)について、現在の実装では各チャンクを全て繋げて一つの文章からノードとエッジをLLMにより抽出しています。結果としてサンプルのチャンク数25に対して、ノード数21となります。

graphrag.py
def extract_graph_components(raw_data):
    prompt = f"Extract nodes and relationships from the following text:\n{raw_data}"

    parsed_response = openai_llm_parser(prompt)  # Assuming this returns a list of dictionaries
    parsed_response = parsed_response.graph  # Assuming the 'graph' structure is a key in the parsed response

    nodes = {}
    relationships = []

    for entry in parsed_response:
        node = entry.node
        target_node = entry.target_node  # Get target node if available
        relationship = entry.relationship  # Get relationship if available

        # Add nodes to the dictionary with a unique ID
        if node not in nodes:
            nodes[node] = str(uuid.uuid4())

        if target_node and target_node not in nodes:
            nodes[target_node] = str(uuid.uuid4())

        # Add relationship to the relationships list with node IDs
        if target_node and relationship:
            relationships.append({
                "source": nodes[node],
                "target": nodes[target_node],
                "type": relationship
            })

    return nodes, relationships

そして(3)で抽出したノードに対応するノードIDと各チャンクのベクトルデータをfor-zip関数で繰り返し処理をしています。

graphrag.py
def ingest_to_qdrant(collection_name, raw_data, node_id_mapping):
    embeddings = [openai_embeddings(paragraph) for paragraph in raw_data.split("\n")]

    qdrant_client.upsert(
        collection_name=collection_name,
        points=[
            {
                "id": str(uuid.uuid4()),
                "vector": embedding,
                "payload": {"id": node_id}
            }
            for node_id, embedding in zip(node_id_mapping.values(), embeddings)
        ]
    )

ここでfor-zip関数は対応するペア(ノードIDとチャンクのベクトルデータ)が少ない方の数に合わせて繰り返し処理をします。よって、現在の実装ではサンプルのチャンク数25に対してノード数21となるため、チャンクの内4つの情報が抜け落ちてしまいます。
そのため、チャンク数とノード数を同数にし、情報が抜け落ちないように改修したコードが、今回のPull Requestの内容になります。

改修点は各チャンク毎にノードとエッジを抽出し、一つのチャンクに複数のノードがあり得るため、ノードはリストにし、チャンク数とノード数を同数にしてfor-zip関数に渡しています。また、ベクトル検索した一つの結果に対して複数のノードIDがあり得るため、取りこぼさないようにベクトル検索などを定義している既存のQdrantNeo4jRetrieverを継承してCustomQdrantNeo4jRetrieverを作成しています。

コードの一部を載せます。全体はこちらにあります。

graphrag.py
def extract_graph_components(chunks):
    nodes_list = []
    relationships_list = []
    for chunk in chunks:
        prompt = f"Extract nodes and relationships from the following text:\n{chunk}"

        parsed_response = openai_llm_parser(prompt)  # Assuming this returns a list of dictionaries
        parsed_response = parsed_response.graph  # Assuming the 'graph' structure is a key in the parsed response

        nodes = {}
        relationships = []

        for entry in parsed_response:
            node = entry.node
            target_node = entry.target_node  # Get target node if available
            relationship = entry.relationship  # Get relationship if available

            # Add nodes to the dictionary with a unique ID
            if node not in nodes:
                nodes[node] = str(uuid.uuid4())

            if target_node and target_node not in nodes:
                nodes[target_node] = str(uuid.uuid4())

            # Add relationship to the relationships list with node IDs
            if target_node and relationship:
                relationships.append({
                    "source": nodes[node],
                    "target": nodes[target_node],
                    "type": relationship
                })
        
        nodes_list.append(nodes)
        relationships_list.append(relationships)

    return nodes_list, relationships_list
graphrag.py
def ingest_to_qdrant(collection_name, chunks, node_id_mapping_list):
    embeddings = [openai_embeddings(chunk) for chunk in chunks]

    qdrant_client.upsert(
        collection_name=collection_name,
        points=[
            {
                "id": str(uuid.uuid4()),
                "vector": embedding,
                "payload": {"id": list(node_id_mapping.values())}
            }
            for node_id_mapping, embedding in zip(node_id_mapping_list, embeddings)
        ]
    )
custom_qdrant_neo4j_retriever.py
from __future__ import annotations

import logging
from typing import Any, Optional

import neo4j
from neo4j_graphrag.exceptions import EmbeddingRequiredError, SearchValidationError
from neo4j_graphrag.retrievers import QdrantNeo4jRetriever
from neo4j_graphrag.retrievers.external.utils import get_match_query
from neo4j_graphrag.types import RawSearchResult, VectorSearchModel
from pydantic import ValidationError

logger = logging.getLogger(__name__)


class CustomQdrantNeo4jRetriever(QdrantNeo4jRetriever):
    """
    Custom retriever inheriting from QdrantNeo4jRetriever.
    Handles cases where the external ID in Qdrant payload might be a list.

    Inherits initialization and other methods from QdrantNeo4jRetriever.
    Only overrides the get_search_results method for custom logic.
    """

    def get_search_results(
        self,
        query_vector: Optional[list[float]] = None,
        query_text: Optional[str] = None,
        top_k: int = 5,
        **kwargs: Any,
    ) -> RawSearchResult:
        try:
            validated_data = VectorSearchModel(
                query_vector=query_vector,
                query_text=query_text,
                top_k=top_k,
            )
        except ValidationError as e:
            raise SearchValidationError(e.errors()) from e

        if validated_data.query_text:
            if self.embedder:
                query_vector = self.embedder.embed_query(validated_data.query_text)
                logger.debug("Locally generated query vector: %s", query_vector)
            else:
                logger.error("No embedder provided for query_text.")
                raise EmbeddingRequiredError("No embedder provided for query_text.")

        points = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=top_k,
            with_payload=[self.id_property_external],
            **kwargs,
        ).points

        # Custom logic
        result_tuples = []
        for point in points:
            assert point.payload is not None
            target_ids = point.payload.get(self.id_property_external, [point.id])
            result_tuples = [[target_id, point.score] for target_id in target_ids]

        search_query = get_match_query(
            return_properties=self.return_properties,
            retrieval_query=self.retrieval_query,
        )

        parameters = {
            "match_params": result_tuples,
            "id_property": self.id_property_neo4j,
        }

        logger.debug("Qdrant Store Cypher parameters: %s", parameters)
        logger.debug("Qdrant Store Cypher query: %s", search_query)

        records, _, _ = self.driver.execute_query(
            search_query,
            parameters,
            database_=self.neo4j_database,
            routing_=neo4j.RoutingControl.READ,
        )

        return RawSearchResult(records=records)

Pull Requestの手順

まずはOSSへのフィードバックについてこちらの書籍でお作法を学びます。Pull Requestの文面(英語)は3stepに分けると作成しやすかったです。

  1. Steps to reproduce(再現手順)
  2. Actual result(実際に得られる結果)
  3. Expected result(期待される結果)

実際の文面です。

Steps to reproduce

ターミナル
python graphrag.py

Actual result

ターミナル
Ingesting to Qdrant...
len(embeddings):  25
len(node_id_mapping):  21  ←It needs to be the same as the number of embeddings.
Qdrant ingestion complete

Expected result

ターミナル
Ingesting to Qdrant...
len(embeddings):  25
len(node_id_mapping_list):  25
Qdrant ingestion complete

次にPull Requestの出し方ですが、これはOSSの方針に従います。もし方針が示されていない場合は、まずはIssueに挙げるのが無難かと思います。今回はQdrantから手順が示されており直接Pull Requestを出しました。

readme.md
Contributions are welcome! Follow these steps to contribute:

Fork the repository.
Create a new branch for your feature (git checkout -b feature/your-feature).
Commit your changes (git commit -m 'Add new feature').
Push to your branch (git push origin feature/your-feature).
Open a Pull Request describing your changes.

おわり

これからもOSSにContributeできるよう精進します。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?