進化の激しいLangchain(以前は一日単位でマニュアルにアクセスしても変化してました)、久々にマニュアル見たらエージェントも追加されてました。
関連記事はこちら。
なお、サンプルは今日時点ではlangchain==0.0.246
でないと動作確認できませんでした。
%pip install databricks-sql-connector langchain==0.0.246
dbutils.library.restartPython()
OpenAI APIのキーを設定します。
import os
os.environ["OPENAI_API_KEY"] = dbutils.secrets.get("demo-token-takaaki.yayoi", "openai_api_key")
カタログとスキーマ名を指定してデータベースに接続します。以下の例ではカタログsamples
にあるタクシー乗降データを格納しているスキーマnyctaxi
にアクセスしています。
# SQLDatabaseラッパーによるDatabricksへの接続
from langchain import SQLDatabase
db = SQLDatabase.from_databricks(catalog="samples", schema="nyctaxi")
LLMラッパーの生成。
# OpenAI Chat LLMラッパーの作成
from langchain.chat_models import ChatOpenAI
llm = ChatOpenAI(temperature=0, model_name="gpt-4")
langchain==0.0.246
でないと以下でエラーになっていました。この辺を参考に。
from langchain import SQLDatabaseChain
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
問い合わせ。
db_chain.run(
"What is the average duration of taxi rides that start between midnight and 6am?"
)
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
> Entering new SQLDatabaseChain chain...
What is the average duration of taxi rides that start between midnight and 6am?
SQLQuery:
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
SELECT AVG(UNIX_TIMESTAMP(tpep_dropoff_datetime) - UNIX_TIMESTAMP(tpep_pickup_datetime)) AS average_duration
FROM trips
WHERE HOUR(tpep_pickup_datetime) BETWEEN 0 AND 6
LIMIT 5
SQLResult: [(933.2543103448276,)]
Answer:The average duration of taxi rides that start between midnight and 6am is approximately 933.25 seconds.
> Finished chain.
Out[5]: 'The average duration of taxi rides that start between midnight and 6am is approximately 933.25 seconds.'
日本語でも聞いてみます。
db_chain.run(
"深夜と6amの間にスタートしたタクシー乗降の平均乗車期間は何ですか?"
)
秒と分の違いはあれども同じ答えが返ってきます。ロジックも同じ。
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
> Entering new SQLDatabaseChain chain...
深夜と6amの間にスタートしたタクシー乗降の平均乗車期間は何ですか?
SQLQuery:
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
SELECT AVG(TIMESTAMPDIFF(MINUTE, tpep_pickup_datetime, tpep_dropoff_datetime)) AS average_duration
FROM trips
WHERE HOUR(tpep_pickup_datetime) BETWEEN 0 AND 6
SQLResult: [(15.060344827586206,)]
Answer:深夜と6amの間にスタートしたタクシー乗降の平均乗車期間は約15分です。
> Finished chain.
Out[8]: '深夜と6amの間にスタートしたタクシー乗降の平均乗車期間は約15分です。'
エージェントも試します。
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent = create_sql_agent(llm=llm, toolkit=toolkit, verbose=True)
最も距離の長い移動はどれで、どれだけかかりましたか?
agent.run("What is the longest trip distance and how long did it take?")
以下のような挙動をして回答を得ています。エージェント、触ったことなかったのですがこれも面白いですね。
- 質問に適切なテーブル
trips
を特定し、必要なカラムがあるのかをチェック。 -
trips
にはtrip_distance
、tpep_pickup_datetime
、tpep_dropoff_datetime
のカラムがあるので移動期間を特定できます。最長の距離と時間を計算するクエリーを記述します。 - クエリーには問題が無いようです。距離が最も長かった移動とその時間を計算するためにクエリーを実行します。
- ファイナルアンサーを取得しました。最長の移動は30.6マイルで43分31秒を要しました。
The 'trips' table seems to be the most relevant for this question. I should check its schema to see what columns it has.
Action: sql_db_schema
Action Input: "trips"
Observation:
CREATE TABLE trips (
tpep_pickup_datetime TIMESTAMP,
tpep_dropoff_datetime TIMESTAMP,
trip_distance FLOAT,
fare_amount FLOAT,
pickup_zip INT,
dropoff_zip INT
) USING DELTA
/*
3 rows from trips table:
tpep_pickup_datetime tpep_dropoff_datetime trip_distance fare_amount pickup_zip dropoff_zip
2016-02-14 16:52:13 2016-02-14 17:16:04 4.94 19.0 10282 10171
2016-02-04 18:44:19 2016-02-04 18:46:00 0.28 3.5 10110 10110
2016-02-17 17:13:57 2016-02-17 17:17:55 0.7 5.0 10103 10023
*/
Thought:The 'trips' table has the columns 'trip_distance' and 'tpep_pickup_datetime' and 'tpep_dropoff_datetime' which I can use to calculate the duration of the trip. I will write a query to get the longest trip distance and its duration.
Action: sql_db_query_checker
Action Input: "SELECT trip_distance, tpep_dropoff_datetime - tpep_pickup_datetime as trip_duration FROM trips ORDER BY trip_distance DESC LIMIT 1"
Observation: The original query seems to be correct as it doesn't have any of the common mistakes mentioned. Here is the reproduction of the original query:
SELECT trip_distance, tpep_dropoff_datetime - tpep_pickup_datetime as trip_duration
FROM trips
ORDER BY trip_distance DESC
LIMIT 1
Thought:
DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set.
The query is correct. Now I will execute it to get the longest trip distance and its duration.
Action: sql_db_query
Action Input: "SELECT trip_distance, tpep_dropoff_datetime - tpep_pickup_datetime as trip_duration FROM trips ORDER BY trip_distance DESC LIMIT 1"
Observation: [(30.6, '0 00:43:31.000000000')]
Thought:I now know the final answer
Final Answer: The longest trip distance is 30.6 miles and it took 43 minutes and 31 seconds.
> Finished chain.
Out[7]: 'The longest trip distance is 30.6 miles and it took 43 minutes and 31 seconds.'