はじめに
SnowflakeとChatGPTを使ってRAG環境を構築してみました。
今回はStreamlitを使用せず、python上で動かしてみた使用感について記していきます。
事前準備
用意するもの
snowflakeで利用するデータベース、ウェアハウスなどのリソース
(弊社ではTelegramの脅威情報の収集を行っており、今回はそのリソースの一部利用します。)
注意:snowflakeのリソース利用については料金がかかります。
参考資料
前回ご紹介したこちらのコードを参考に進めていきます。
実際に作ってみた
コード
rag_test.py
import os
import pandas as pd
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import OpenAI
from sqlalchemy import create_engine
from snowflake.sqlalchemy import URL
# OpenAI APIキーの設定
os.environ["OPENAI_API_KEY"] = "your_api_key"
# 1. snowflake
engine = create_engine(
URL(
# Snowflake接続情報
account="your_account_name",
user="your_username",
password="your_password",
database="your_database",
schema="your_schema",
warehouse="your_warehouse",
)
)
query = f"SELECT * FROM database_name.schema_name.table_name"
# データ取得
df = pd.read_sql(query, engine)
# 2. Documentに変換
documents = [
Document(
page_content=" ".join(map(str, row.values)),
metadata=row.to_dict()
)
for _, row in df.iterrows()
]
# 3. テキストを小さなチャンクに分割
splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = splitter.split_documents(documents)
# 4. 埋め込みの生成
embeddings = OpenAIEmbeddings()
vectorstore = FAISS.from_documents(texts, embeddings)
# 5. RAGチェーンの作成
llm = OpenAI()
retriever = vectorstore.as_retriever()
rag_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever)
# 6. 質問に対する応答を生成
question = "channel_nameの□□□□□□について概要を教えてください。"
chat_history = []
response = rag_chain.run({"question": question, "chat_history": chat_history})
print(response)
# 出力結果
□□□□□□ is a member of the public channel □□□□□□ on Telegram. This channel is part of the larger group □□□□□□ and is a private channel within that group. The channel was created on August 17, 2024 and will be active until November 29, 2024. The channel has the □□□□□□ and can be found at □□□□□□.
機密情報なので出力結果の一部はぼかしていますが、チャンネル(ハッカーグループ)について概要を教えてくれました。
質問を変えてみました。
rag_test.py
# 出力結果
question = "日本を攻撃対象としたchannel_nameを教えてください。"
↓↓↓
# 出力結果
□□□□□□, ○○○○○○, and ×××××× are all channels that have targeted Japan for cyber attacks.
異なる3つのチャンネルから日本が攻撃対象とされていることがわかりました。(怖っ)
最後に
本記事では、Snowflakeを活用してRAG構築の手順を記しました。
今後もSnowflakeを活用した技術検証を行っていきたいと思います。