参考記事
こちらの方の記事がベースになっています。
その上でアレンジしたこと
1 :LangChainは、様々なデータベースを扱えるはずなのに、
ほとんどの記事はSQliteにしか言及がない
MS SQL
MySQL
MariaDB
PostgreSQL
Oracle SQL
Databricks
SQLite
→ SQL SERVERへの接続に挑戦
2 :SQLDatabaseChainというLangchainの機能の一つを用いている。
この関数は、ユーザーからの質問について、
「質問の復唱」、「SQLにおけるQueryの生成」、「実行結果の取得」、「AIによるデータの回答・見解」
の4つの情報を返却してくれる優れもの。
しかし、データ量が多くなることから高確率でトークン上限のエラーに陥ることが多い。
→「SQLにおけるQueryの生成」のみでいいのではと思い、
それだけがピンポイントで得られる関数の探索と工夫を実施
完成品 概要
・Streamlitでのデモシステムです
・いくつかあるうちのテーブルをセレクトボックスから選択します(今回は1つです)
・文章で自分の欲しい情報を依頼します。
例:製品の一覧を表示して
例:2022年の製品別の売上金額を取得して 等
・自然言語を通して、得られたSQLのクエリを用いて
pandas ライブラリの操作でクエリによる結果データを
取得します
まずは全文
# Import libraries
import os
import urllib
import pyodbc
import pandas as pd
import streamlit as st
from sqlalchemy import text, create_engine
from langchain import PromptTemplate, OpenAI
from llama_cpp import Llama
from langchain.chains import create_sql_query_chain
# 独自設定箇所
SECRET_KEY = "your API token"
server = "your sql server"
db = "your db name"
driver = r'ODBC Driver 17 for SQL Server'#例として表示しましたがここも環境に合わせて
# OPENAI keyの環境変数への設定
os.environ["OPENAI_API_KEY"] = SECRET_KEY
# SQL SERVER用の接続文字列を生成
odbc_connect = urllib.parse.quote_plus(
'DRIVER={%s};SERVER=%s;DATABASE=%s;Trusted_Connection=yes' % (driver, server, db))
connection_string = "mssql+pyodbc:///?odbc_connect=%s" % odbc_connect
# LLMモデルの定義
LLM = OpenAI(#model_name = "text-davinci-003", # 利用するモデル
model_name = "gpt-3.5-turbo",# 利用するモデル
temperature = 0, # 出力する単語のランダム性(0から2の範囲) 0であれば毎回返答内容固定
verbose = False, # プロンプトの動的表示有無
)
#テーブル名を取得
def fetch_table_names():
conn = pyodbc.connect(
f"Driver=SQL Server;Server={server};Database={db};Trusted_Connection=yes;")
cursor = conn.cursor()
query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';"
cursor.execute(query)
tables = cursor.fetchall()
return [table.TABLE_NAME for table in tables]
#LLMによるデータ
def search_db(text, db_chain):
result = db_chain.invoke({"question": "%s" % text})
return {
'Question': text,
'SQLQuery': result
}
def main():
from langchain import SQLDatabase
table_names = fetch_table_names()
st.title("検索システム")
selected_table = st.selectbox("テーブルを選択してください", table_names)
user_question = st.text_input("質問を入力してください:", "")
if st.button("クエリを実行"):
tables = [selected_table]
SQLDatabase = SQLDatabase.from_uri(
connection_string, include_tables=tables)
db_chain = create_sql_query_chain(LLM, SQLDatabase)
result = search_db(user_question, db_chain)
st.write("質問:", result["Question"])
st.write("SQL クエリ:", result["SQLQuery"])
conn = pyodbc.connect(
f"Driver=SQL Server;Server={server};Database={db};Trusted_Connection=yes;")
result_df = pd.read_sql_query(result["SQLQuery"], conn)
conn.close()
st.table(result_df)
if __name__ == "__main__":
main()
言語モデル 1,000トークンあたりの料金
gpt-3.5-turbo :$0.002(およそ0.26円)
text-davinci-003:$0.02(およそ2.6円)
なので、コードではgpt-3.5-turboにしてます。
ポイント
SQL SERVERについて
SQLDatabase = SQLDatabase.from_uri(
connection_string, include_tables=tables)
URIの定義をしています。
引数はそれぞれ、
①connection_string : 接続文字列
②include_tables=tables:今回推論に用いてもらうテーブルの指定
②SQL SERVERの接続文字列の生成
server = "your sql server"
db = "your db name"
driver = r'ODBC Driver 17 for SQL Server'#例として表示しましたがここも環境に合わせて
#SQL SERVER用の接続文字列を生成
odbc_connect = urllib.parse.quote_plus(
'DRIVER={%s};SERVER=%s;DATABASE=%s;Trusted_Connection=yes' % (driver, server, db))
connection_string = "mssql+pyodbc:///?odbc_connect=%s" % odbc_connect
テーブル名のUIでの選択
#テーブル名を取得
def fetch_table_names():
conn = pyodbc.connect(
f"Driver=SQL Server;Server={server};Database={db};Trusted_Connection=yes;")
cursor = conn.cursor()
query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';"
cursor.execute(query)
tables = cursor.fetchall()
return [table.TABLE_NAME for table in tables]
SQLのクエリ生成について
SQLDatabaseChainは使用せず、
クエリ生成に限定して「create_sql_query_chain」という関数を使います。
この記事の「Case 1: Text-to-SQL query」の記述を参考にしています。
db_chain = create_sql_query_chain(LLM, SQLDatabase)
result = search_db(user_question, db_chain)
def search_db(text, db_chain):
result = db_chain.invoke({"question": "%s" % text})
return {
'Question': text,
'SQLQuery': result
}
生成されたクエリをPandasライブラリに通す
これによって、モデル側が担っていた役割を分散して、出力プロンプト数を減らす
st.write("質問:", result["Question"])
st.write("SQL クエリ:", result["SQLQuery"])
conn = pyodbc.connect(
f"Driver=SQL Server;Server={server};Database={db};Trusted_Connection=yes;")
result_df = pd.read_sql_query(result["SQLQuery"], conn)
conn.close()
デモ
都合により削除
考察
・ユーザーのリクエストした文章を基にカラム名から情報を推測して
クエリを作ってくれてます。
・ユーザーがカラムを指定すればそのように用意もしてくれます
(正規のカラム名でなくても、大丈夫のようですが精度は下がると思います。)
・複数のテーブルを渡せば、テーブル結合もやってくれそうですが、
入力プロンプトのトークン上限によってそこまではできずにいます。
→次の課題
CHATGPT編があるということは・・・
Llama2でもやってみたのでその記事を載せたいと思います。