今度はこちらのノートブックをウォークスルーします。
以下のコードでは、Sparkデータフレームエージェントのサンプルを説明します。
要件
- このノートブックを使うには、お使いのOpenAI APIトークンを指定してください。
- Databricks Runtime 13.3 ML以降
ライブラリのインストール
最新のlangchain
とdatabricks-sql-connector
を使うことをお勧めします。
注意
もとのノートブックのままではエラーとなるのでlangchain_experimental
を追加します。
%pip install --upgrade langchain databricks-sql-connector langchain_experimental
dbutils.library.restartPython()
import os
os.environ["OPENAI_API_KEY"] = "<OpenAI APIキー>"
Sparkデータフレームエージェント
このSparkデータフレームエージェントによって、必要に応じてSparkデータフレームを操作できるようになります。シンプルに、LLMと対象のデータフレームを指定してcreate_spark_dataframe_agent
をコールするだけです。
from langchain.llms import OpenAI
#from langchain.agents import create_spark_dataframe_agent
from langchain_experimental.agents.agent_toolkits import create_spark_dataframe_agent
df = spark.read.csv("/databricks-datasets/COVID/coronavirusdataset/Region.csv", header=True, inferSchema=True)
display(df)
agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)
問い合わせます。
agent.run("How many rows are there?")
> Entering new AgentExecutor chain...
Thought: I need to find out how many rows are in the dataframe
Action: python_repl_ast
Action Input: df.count()
Observation: 244
Thought: I now know the final answer
Final Answer: There are 244 rows in the dataframe.
> Finished chain.
'There are 244 rows in the dataframe.'
OpenAIなので日本語でもOKです。
agent.run("何行ありますか?")
> Entering new AgentExecutor chain...
Thought: I need to know how many rows are in the dataframe
Action: python_repl_ast
Action Input: df.count()
Observation: 244
Thought: I now know the final answer
Final Answer: データフレームには244行があります。
> Finished chain.
'データフレームには244行があります。'