0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【RAG入門】ソースコード付き:簡単にRAGを始めたい人へ

Posted at

はじめに

最近、AIについて調べていると、RAG(ラグと読みます)と言う単語をよく目にします。RAGについて皆さんはどの程度理解していますでしょうか?
今回は、RAGを触ってみたいという人向けにRAGのソースコードを提供します。私自身、色々とRAGに関する記事を書いていますが、結局触ってみないとわからないこともあります。
ぜひ、このコードで遊んでみてください。

※RAGのコードは1つじゃないです。あくまで1例です。
※Mac OS 15.0 での動作確認は行っていますが、Windowsで動くかは不明です。

本記事はRAGに関する知識を前提としています。RAGに関して不安のある方は、【RAG入門】RAGって何?なぜRAGが注目を浴びるのか?を読んでみてください!

ソースコードやデータはこちら (Github) で公開してます。

目次

  1. 構成
  2. ソースコード
  3. 使い方
  4. 解説

構成

以下のような構成になっています。

.
├── env ※仮想環境(任意)
├── config
│   ├── key.py ※GPT用
│   ├── path.py
│   └── prompt.py
├── data
│   ├── question.csv
│   ├── story.txt
│   ├── db_vec.faiss
│   └── db_chunk.pkl
├── store.py
└── rag.py

story.txtには今回外部知識として扱うデータがテキストベースで入っています。今回はGPT-4oで生成した3000文字程度の物語を扱います。テキストの全文はここにアクセスしてください。

story.txt (一部抜粋)
大海原に浮かぶ孤島、シエラ。数世紀にわたり古代文明の遺跡が点在するこの島には、天空にそびえる白い塔があり、その頂には世界の秘密が隠されていると伝えられていた。塔を巡る争いは絶えず、幾多の冒険者たちが挑んでは帰らぬ人となった。その中には、エリオットの父も含まれていた。
エリオットは、若き考古学者だった。失われた文明の真実を追い求め、幾度も冒険を重ねた末、このシエラ島にたどり着いた。彼の手には、父が遺した古びた地図。そこには「蒼海のエトランジェ」と記された謎の言葉が残されていた。

・・・

一歩、また一歩と門へ歩を進める。背後で少女が微笑み、手を振った。
「行って、エトランジェ。蒼海の果てへ」
扉が閉じ、エリオットの姿は消えた。蒼い光が塔を包み、やがて静寂が戻る。
港町では、今日も人々が賑やかに祭りの準備をしていた。誰も知らない、若き考古学者の冒険譚。それはいつか、語り継がれる伝説となり、再びこの地に帰ってくるかもしれない。
白い塔は、その静謐を保ったまま、天空を見上げていた。
蒼海のエトランジェ、その名と共に。

また、question.csvには次の5つの質問を事前に用意しています。

Question Answer
主人公の名前は? エリオット
エリオットの父はどこで亡くなった? シエラ
ガラバットは昔どこに所属していた? 王国の騎士団
エリオットの父ととも冒険していた人の名前は? レオ
エトランジェとはどのような少女? 銀の髪と深い蒼の瞳を持つ少女

ソースコード

langchainをベースに実装しました。
LLMにはGPT-4oを、Embedding Modelには、intfloat/multilingual-e5-largeを使用しています。

APIにお金がかかるため、「まずは無料で始めたい!」という人はLLaMAのコードを動かしてみてください。(2025/4までにはLLaMAのコードも公開します)

GPT-4ファミリーで実装する場合

rag.py
import torch
import os
import faiss
import pickle
import pandas as pd
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from config.prompt import *
from config.path import *
from config.key import API_KEY
import warnings

warnings.simplefilter('ignore')


# *******************************
# 設定
# *******************************
os.environ['OPENAI_API_KEY'] = API_KEY
llm = ChatOpenAI(model="gpt-4o-2024-11-20", temperature=0)
embedding_model = SentenceTransformer("intfloat/multilingual-e5-large")

def get_embeddings(texts) -> torch.Tensor:
    """
    文章を1024次元のベクトルに変換し、L2正規化を適用
    """
    embedding = embedding_model.encode(texts)
    embedding = F.normalize(torch.tensor(embedding), p=2, dim=1)

    return embedding.numpy()

def retrieve(query, k=4)-> list[str]:
    """
    - query: 検索クエリ
    - k: 取得する文の数
    """
    # クエリの埋め込みを取得
    query_embedding = get_embeddings([query])
    _, indices = faiss_index.search(query_embedding.reshape(1, -1), k)
    results = []
    for idx in indices[0]:
        chunk = chunk_store[idx]
        results.append(chunk)
    return results

# *******************************
# 外部知識のロード
# *******************************
faiss_index = faiss.read_index(FAISS_FILE)
with open(PICKLE_FILE, "rb") as f:
    chunk_store = pickle.load(f)


# *******************************
# Chainを作成
# *******************************
augmented_prompt = PromptTemplate(
    template=AUGMENT_TEMPLATE,
    input_variables=["context1", "context2", "context3", "context4", "question"]
)
chain = LLMChain(llm=llm, prompt=augmented_prompt)


# *******************************
# 回答生成
# *******************************

df = pd.read_csv(INPUT_FILE)

for index, row in df.iterrows():
    # 検索
    docs = retrieve(row["Question"])

    input_data = {
        "context1": docs[0],
        "context2": docs[1],
        "context3": docs[2],
        "context4": docs[3],
        "question": row["Question"]
    }
    # 回答生成
    ans = chain.run(input_data)

    # 書き込み
    df.at[index, "Response"] = ans
    df.at[index, "Prompt"] = augmented_prompt.format(**input_data)

# 出力
df.to_csv(OUTPUT_FILE, index=False)

使い方

必要なライブラリをインストールします。

ライブラリのインストール
pip install sentence_transformers langchain-community faiss-cpu pandas

sentence_transformersもしくは、その依存関係(torchtransformers)がNumpy 2.xに対応していないため、Numpy 1.xをインストールします。

numpyのインストール
pip install numpy==1.26.4

また、GPT-4ファミリーで実行する際には、OpenAIのライブラリのインストールが必要になります。

openaiのインストール
pip install openai

次に、config/key.pyAPI_KEYを設定します。この設定もGPT-4ファミリーで実行する場合のみ必要です。

key.py
API_KEY = 'XXX'

最後に、config/path.pyに各種パスを設定してください。ここの設定は任意です。

path.py
PICKLE_FILE = "data/db_chunk.pkl" チャンクの保存先(後述)
FAISS_FILE = "data/db_vec.faiss" 埋め込みベクトルの保存先(後述)
OUTPUT_FILE = "answer.csv" 回答の生成結果の保存先
INPUT_FILE = "data/question.csv" テストデータ
STORY_TEXT = "data/story.txt" 外部知識として使う小説データ

これで、設定が終わりです。次にStoreという作業を行います。
RAGは外部データを参照して回答生成するため、参照先のデータベースを作ってあげる必要があります。
今回は以下の設定でデータを作ります。

  • チャンクサイズ:200文字
  • オーバーラップ:50文字

例えば、チャンクサイズ15文字、オーバーラップ5文字の場合、
「扉が閉じ、エリオットの姿は消えた。蒼い光が塔を包み、やがて静寂が戻る。」

  • 扉が閉じ、エリオットの姿は消え
  • の姿は消えた。蒼い光が塔を包み
  • が塔を包み、やがて静寂が戻る。

という3つのデータができるといった感じです。
また、検索時にコサイン類似度を用いて検索をするため、データを日本語のままではなく、Embedding Modelを活用してベクトル化します。
store.pyを実行することで作成できますが、時間がかかるため面倒な人はGithubから作成済みのファイルをダウンロードしてください。ベクトル化したデータはdb_vec.faiss、日本語のデータはdb_chunk.pklに保存されています。

store.py
import faiss
import pickle
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from config.path import *

# Embeddingモデルをロード
embedding_model = SentenceTransformer("intfloat/multilingual-e5-large", device="cpu")

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size = 200,      # チャンクの文字数
    chunk_overlap = 50,    # チャンクオーバーラップの文字数
)

def get_embeddings(texts: list[str]) -> torch.Tensor:
    """
    文章を1024次元のベクトルに変換し、L2正規化を適用
    """
    embedding = embedding_model.encode(texts)
    embedding = F.normalize(torch.tensor(embedding), p=2, dim=1)

    return embedding.numpy()

# ファイルを開いて読み込む
with open(STORY_TEXT, "r", encoding="utf-8") as file:
    content = file.read()
# 改行をなくす
content = content.replace("\n","")

# 200文字でテキストを分割
texts = text_splitter.split_text(content)

# faissを利用してベクトルを保存
faiss_index = faiss.IndexFlatIP(1024)
faiss_index.add(get_embeddings(texts))

# 保存
faiss.write_index(faiss_index, FAISS_FILE)
with open(PICKLE_FILE, "wb") as f:
    pickle.dump(texts, f)

データベースを作り終えました。最後に質問をもとに回答を生成します。

rag.pyの実行
python rag.py

config/path.pyOUTPUT_FILEで指定したところにcsvファイルができているはずです!(既定では、rag.pyと同じ場所にanswer.csvという名前で作成されます)
ヘッダーは

  • Question:質問文
  • Answer:正解
  • Response:RAGによる回答
  • Prompt:回答生成時に使用したプロンプト

の4つです。
ちなみに、GPT-4oで実行したところ、

Question Answer Response
主人公の名前は? エリオット 主人公の名前は「エリオット」です。
エリオットの父はどこで亡くなった? シエラ Contextの中には、エリオットの父がどこで亡くなったかについての明確な記述はありません。ただし、彼が白い塔の頂にある異界への門を目指していたことや、激しい戦いと焦りが手記に記されていることから、塔を目指す過程で命を落とした可能性が高いと推測されます。しかし、具体的な場所については明示されていません。
ガラバットは昔どこに所属していた? 王国の騎士団 ガラハッドは昔、王国の騎士団に所属していました。
エリオットの父ととも冒険していた人の名前は? レオ エリオットの父とともに冒険していた人の名前は「レオ」です。
エトランジェとはどのような少女? 銀の髪と深い蒼の瞳を持つ少女 エトランジェとは、銀の髪と深い蒼の瞳を持つ少女であり、「異界の守り人」として描かれています。彼女は異界への門である塔を守り、エリオットのような訪問者に対して、その願いが真実であるかを見極め、扉を開く役割を担っています。彼女は静かで神秘的な存在感を持ち、異界とこの世界を結ぶ重要な存在として描かれています。

となりました。チャンクが計20個しかないので、答えやすかったかもしれないです。何がともあれ、4/5で正解できました。
余談になりますが、RAGの検索はまあまあ難しい印象です。その中でもHydeやRerankerは有効な手法です。RAGには多くのテクニックが存在します。これを機に、いろいろ調べてみてください!

解説

rag.pyについて細かく解説していきます
大まかにいうと、OpenAI APIFAISS (Facebook AI Similarity Search) を使って、指定された質問に対して文書検索を行い、関連する情報を元に回答を生成し、CSVファイルに出力しています。
各セクションに分けて解説していこうと思います。


1. ライブラリのインポート

import torch
import os
import faiss
import pickle
import pandas as pd
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from langchain_community.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from config.prompt import *
from config.path import *
from config.key import API_KEY
import warnings

warnings.simplefilter('ignore')
  • torch: PyTorchライブラリ
  • faiss: ベクトル検索ライブラリ。高速な類似検索を実行
  • pickle: チャンクの読み込みに使用
  • pandas: CSVファイルの読み書きやデータ処理
  • sentence_transformers: テキストをベクトルに変換
  • langchain: LLM (Large Language Model) を使ったチェーン処理を実行
  • warnings: 警告メッセージの非表示
  • config: APIキーやファイルパス、プロンプトのテンプレートを外部ファイルから読み込み

2. 設定

os.environ['OPENAI_API_KEY'] = API_KEY
llm = ChatOpenAI(model="gpt-4o-2024-11-20", temperature=0)
embedding_model = SentenceTransformer("intfloat/multilingual-e5-large")
  • APIキー設定: OpenAI APIを使用するためにAPIキーを環境変数に設定
  • モデル設定: ChatOpenAIでGPT-4を指定。回答の一貫性を高めるために temperature=0に設定
  • 埋め込みモデル: SentenceTransformerを使い、テキストを1024次元のベクトルに変換

3. 埋め込み取得関数

def get_embeddings(texts) -> torch.Tensor:
    """
    文章を1024次元のベクトルに変換し、L2正規化を適用
    """
    embedding = embedding_model.encode(texts)
    embedding = F.normalize(torch.tensor(embedding), p=2, dim=1)

    return embedding.numpy()
  • 目的:
    テキストを1024次元のベクトルに変換し、L2正規化でベクトルの長さを1に揃える。
  • 返り値:
    NumPy配列形式の正規化済みベクトル。

4. 文書検索関数

def retrieve(query, k=4) -> list[str]:
    """
    - query: 検索クエリ
    - k: 取得する文の数
    """
    # クエリの埋め込みを取得
    query_embedding = get_embeddings([query])
    _, indices = faiss_index.search(query_embedding.reshape(1, -1), k)
    results = []
    for idx in indices[0]:
        chunk = chunk_store[idx]
        results.append(chunk)
    return results
  • 目的:
    指定されたクエリに最も類似したk個の文書をFAISSを使って検索。
  • FAISSの使い方:
    最も類似した文書のインデックス (indices) を取得。
  • 返り値:
    検索結果の文書リスト

5. 外部知識のロード

faiss_index = faiss.read_index(FAISS_FILE)
with open(PICKLE_FILE, "rb") as f:
    chunk_store = pickle.load(f)
  • FAISSインデックスの読み込み:
    faiss.read_indexでインデックスファイル (FAISS_FILE) をロード
  • 文書データの読み込み:
    pickleで保存された文書のチャンク (PICKLE_FILE) を読み込み

6. プロンプトテンプレートの作成

augmented_prompt = PromptTemplate(
    template=AUGMENT_TEMPLATE,
    input_variables=["context1", "context2", "context3", "context4", "question"]
)
final_chain = LLMChain(llm=llm, prompt=augmented_prompt)
  • 目的:
    LangChainでGPT-4に渡すプロンプトのテンプレートを作成。
  • テンプレート:
    外部知識 (context1context4) と質問 (question) を組み合わせてプロンプトを構築

7. 回答生成

df = pd.read_csv(INPUT_FILE)

for index, row in df.iterrows():
    # 検索
    docs = retrieve(row["Question"])

    input_data = {
        "context1": docs[0],
        "context2": docs[1],
        "context3": docs[2],
        "context4": docs[3],
        "question": row["Question"]
    }
    # 回答生成
    ans = chain.run(input_data)

    # 書き込み
    df.at[index, "Response"] = ans
    df.at[index, "Prompt"] = augmented_prompt.format(**input_data)
  • 入力:
    CSV (INPUT_FILE) から質問を読み込み
  • 検索:
    retrieve関数で関連する4つの文書を取得
  • 回答生成:
    chain.runでGPT-4にプロンプトを渡し、回答を生成
  • 書き込み:
    回答 (Response) と使用したプロンプト (Prompt) をCSVに保存

8. CSVへの書き込み

df.to_csv(OUTPUT_FILE, index=False)
  • 目的:
    最終結果(質問、回答、プロンプト)を指定のCSVファイル (OUTPUT_FILE) に保存
  • index=False:
    インデックスを含めずに書き込む

終わりに

いかがでしたでしょうか?
座学でも多くのことを身につけることは可能ですが、触ってみないと感覚やセンスを養うことはできないと思っています。エンジニアにとって、感覚やセンスは非常に大切なところですので、ぜひ、たくさんRAGで遊んでみてください。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?