ChatGPTとLangChainでデータ分析が楽々!自然言語で簡単にデータ探索ツール
この記事では、ChatGPTとLangChainを使って、自然言語を使って簡単にデータセットを探索できるツールを作成する方法を紹介します。
目次
- はじめに
- 必要なパッケージのインストール
- ChatGPTとLangChainを使ったデータ探索ツールの構築
- Dashアプリケーションの作成
- データフレームとクエリを処理する関数の実装
結論
ChatGPTとLangChainを組み合わせることで、自然言語を用いたデータ分析ツールを構築することができました。このツールは、Dashを使用してインタラクティブなWebアプリケーションとして実装されており、分析者が直感的にデータセットに関する質問を投げかけることができます。また、EDAを自動化し、簡単にデータセットの概要や特徴を把握できる機能も実装しました。このツールは、データ分析や機械学習に従事する専門家だけでなく、データに関心のある一般ユーザーにも有益であり、様々なデータセットに対応するための良いスタートポイントとなります。この記事を通じて、Pythonプログラミングや自然言語処理を活用したデータ分析の可能性を広げ、より多くの人々がデータを活用して新たな知見や価値を見つけられることを願っています。
色々と課題もありますが、ご指摘いただけると幸いです。
はじめに
データ分析では、データセットに対する探索的データ分析(EDA)が重要なステップです。しかし、EDAを行う際に、データセットに対するクエリをプログラムで記述することは、煩雑で手間がかかることがあります。そこで、自然言語でデータセットに対する質問を行い、データ分析を効率化する事ができないか考えました。
この記事では、ChatGPTとLangChainを利用して、自然言語でデータセットに対するクエリを行うことができるデータ探索ツールを作成します。具体的には、PythonのDashライブラリを使ってインタラクティブなWebアプリケーションを構築し、データセットをアップロードして、自然言語で質問を行うことができるようにします。
デモ
ファイル構成
project_directory/
│
├── app.py
└── chat.py
必要なパッケージのインストール
まず、以下のコマンドを実行して、必要なパッケージをインストールします。
pip install dash pandas plotly langchain openai
ChatGPTとLangChainを使ったデータ探索ツールの構築
データ探索ツールの構築は、以下の手順で行います。
データフレームとクエリを処理する関数の実装
次に、ChatGPTとLangChainを使って、日本語のクエリを英語に翻訳(これは必ずしも必要ないがChatGPTは3.5系を使用しているのとトークン節約を意識して挿入)し、データフレームに対してクエリを実行する関数(create_pandas_dataframe_agent)を実装します。
ここでの工夫点としては、データフレームのインデックスを出力させるようにすることです。ただし、このあたりは課題もあり、どのようにクエリを処理させて出力させるかは検討の余地があると考えています。
def translate_to_english(text: str) -> str:
llm = OpenAI(temperature=0)
prompt = PromptTemplate(
input_variables=["text"],
template="Translate the following Japanese text to English: {text}",
)
chain = LLMChain(llm=llm, prompt=prompt)
translated_text = chain.run(text)
return translated_text
def chat_tool_with_pandas_df(df, query):
query = query + "そのインデックスは?" # 入力クエリに対してのインデックスを出力させるように
# Translate the query to English
translated_query = translate_to_english(query)
print(query, translated_query)
agent = create_pandas_dataframe_agent(
OpenAI(temperature=0),
df,
verbose=True,
max_iterations=2,
early_stopping_method="generate",
)
# Run the agent with the translated query
raw_result = agent.run(translated_query)
print(raw_result)
# Create a prompt for the LLM to get a short answer in Python list format
llm = OpenAI(temperature=0)
prompt = PromptTemplate(
input_variables=["query", "raw_result"],
template="Q: {query}\nA: {raw_result}\nAnswer in python list format, using as few characters as possible. No explanatory text is required.", # このあたりも要改善必要か
)
# Create a chain to run the prompt with the LLM
chain = LLMChain(llm=llm, prompt=prompt)
# Run the chain with the translated query and raw_result as inputs
formatted_result = chain.run({"query": translated_query, "raw_result": raw_result})
formatted_result = json.loads(formatted_result.replace("\n", ""))
print(formatted_result)
return df.loc[formatted_result, :]
Dashアプリケーションの作成
まず、dash, pandas, plotly, langchain などの必要なライブラリをインポートします。
Dashについて
Dashは、PythonでインタラクティブなWebアプリケーションを簡単に作成できるオープンソースのフレームワークです。データビジュアル化ライブラリであるPlotlyをベースにしており、データ分析や機械学習の結果を視覚化し、ユーザーと対話的に操作できるアプリケーションを作成するのに適しています。Dashアプリケーションは、シンプルなPythonスクリプトで構築され、Webブラウザで動作します。このため、データ分析者やエンジニアが、専門的なWeb開発のスキルを持っていなくても、高度なデータドリブンのWebアプリを作成できます。Dashでは、豊富なコンポーネントが用意されており、グラフやチャート、テーブルなどの表示要素や、ボタンやスライダー、ドロップダウンメニューなどのインタラクティブな操作要素を組み合わせて、独自のデータ分析アプリを構築できます。
import os
import base64
import io
from io import StringIO
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import dash
from dash import dcc, html, dash_table
from dash import Input, Output, State
from dash.exceptions import PreventUpdate
from chat import chat_tool_with_pandas_df
# set the environment variable
os.environ["OPENAI_API_KEY"] = "OPENAIのAPIキー"
# Define the layout for the Dash application
app = dash.Dash(__name__)
# Create the layout with upload, input, and display sections
app.layout = html.Div(
[
# Title
html.H1("Dataset Query Tool"),
# Upload section
dcc.Upload(
id="upload-data",
children=html.Div(["Drag and Drop or ", html.A("Select a CSV File")]),
style={
"width": "100%",
"height": "60px",
"lineHeight": "60px",
"borderWidth": "1px",
"borderStyle": "dashed",
"borderRadius": "5px",
"textAlign": "center",
"margin": "10px",
},
# Allow only one file to be uploaded
multiple=False,
),
dcc.Store(id="stored-dataframe"),
html.Div(id="dataframe-info"),
# Query input section
html.Label("クエリ入力"),
html.Div(
[
dcc.Input(
id="query-input",
placeholder="クエリを入力してください(例:男性で30歳以上40歳未満で生き残った人は?)",
style={"width": "100%", "height": "50px"},
),
html.Button("Submit", id="submit-button"),
]
),
# Query result section
html.Label("クエリ結果"),
dcc.Loading(html.Div(id="query-result")),
dcc.Store(id="query-result-dataframe"),
# EDA plots section
html.Label("EDAプロット"),
dcc.Loading(html.Div(id="eda-plots")),
]
)
# Callback function to display query results and store the result dataframe
@app.callback(
Output("query-result", "children"),
Output("query-result-dataframe", "data"),
Input("submit-button", "n_clicks"),
[
State("query-input", "value"),
State("stored-dataframe", "data"),
],
)
def update_output(n_clicks, query, df):
if n_clicks is None:
return "", None
if df is None:
return "csv ファイルをアップロードしてください", None
if query:
df = pd.DataFrame(df)
result_df = chat_tool_with_pandas_df(df, query).reset_index()
store_data = result_df.to_dict("records")
return [
dash_table.DataTable(
columns=[{"name": i, "id": i} for i in result_df.columns],
data=result_df.to_dict("records"),
),
], store_data
return "クエリが入力されていません。", None
# Callback function to store the uploaded dataframe and display its information
@app.callback(
Output("stored-dataframe", "data"),
Output("dataframe-info", "children"),
[Input("upload-data", "contents")],
[State("upload-data", "filename")],
)
def update_output(contents, filename):
layout = []
if contents:
content_type, content_string = contents.split(",")
decoded = base64.b64decode(content_string)
try:
if "csv" in filename:
# Assume that the user uploaded a CSV file
uploaded_df = pd.read_csv(
io.StringIO(decoded.decode("utf-8")), index_col=0
)
info_buffer = StringIO()
uploaded_df.info(buf=info_buffer)
info_str = info_buffer.getvalue()
layout = [
html.Label("データフレームの情報:"),
html.Pre(info_str),
html.Label("データフレームの要約統計:"),
dash_table.DataTable(
id="describe-table",
columns=[
{"name": i, "id": i}
for i in uploaded_df.reset_index().describe().columns
],
data=uploaded_df.describe().reset_index().to_dict("records"),
),
html.Label("データフレームの最初の10行:"),
dash_table.DataTable(
id="head-table",
columns=[
{"name": i, "id": i} for i in uploaded_df.head(10).columns
],
data=uploaded_df.head(10).to_dict("records"),
),
]
else:
return html.Div(["Invalid file type. Please upload a CSV file."])
except Exception as e:
return html.Div([f"Error processing file: {str(e)}"])
return uploaded_df.to_dict("records"), layout
return None, layout
# Callback function to generate plots from DataFrame
@app.callback(
Output("eda-plots", "children"),
Input("query-result-dataframe", "data"),
)
def generate_eda_plots(data):
if not data:
raise PreventUpdate
df = pd.DataFrame(data)
# Check for missing values
missing_values = df.isnull().sum()
missing_values_percent = (missing_values / len(df)) * 100
missing_df = pd.DataFrame(
{"Missing Values": missing_values, "Percentage": missing_values_percent}
).reset_index()
# Separate columns into numerical and categorical
numerical_columns = df.select_dtypes(include=["int64", "float64"]).columns
categorical_columns = df.select_dtypes(include=["object", "bool"]).columns
# Create a subplot for numerical distributions
fig_num_dist = make_subplots(
cols=len(numerical_columns), rows=1, subplot_titles=numerical_columns
)
# Plot histograms for numerical columns
for i, col in enumerate(numerical_columns, start=1):
fig_num_dist.add_trace(
go.Histogram(x=df[col], nbinsx=20, histnorm="probability"), col=i, row=1
)
fig_num_dist.update_layout(title_text="Numerical Distributions")
# Create a subplot for categorical distributions
fig_cat_dist = make_subplots(
cols=len(categorical_columns), rows=1, subplot_titles=categorical_columns
)
# Plot bar charts for categorical columns
for i, col in enumerate(categorical_columns, start=1):
fig_cat_dist.add_trace(
go.Histogram(x=df[col], histnorm="probability"), col=i, row=1
)
fig_cat_dist.update_layout(title_text="Categorical Distributions")
# Plot heatmap for numerical column correlations
corr_matrix = df[numerical_columns].corr()
fig_corr = go.Figure(
go.Heatmap(
z=corr_matrix,
x=numerical_columns,
y=numerical_columns,
colorscale="RdBu",
)
)
fig_corr.update_layout(title_text="Numerical Column Correlations")
layout = [
html.Label("欠損値の確認"),
dash_table.DataTable(
columns=[{"name": i, "id": i} for i in missing_df.columns],
data=missing_df.to_dict("records"),
),
html.Label("数値データの分布"),
dcc.Graph(figure=fig_num_dist),
html.Label("カテゴリデータの分布"),
dcc.Graph(figure=fig_cat_dist),
html.Label("数値データの相関"),
dcc.Graph(figure=fig_corr),
]
return layout
# Run the Dash app
if __name__ == "__main__":
app.run_server(debug=True)
これで、データ探索ツールが完成しました。アプリケーションを実行して、CSVファイルをアップロードし、自然言語で質問を行い、結果を表示させることができます。
課題
- どんなクエリでもエラーが出ないわけではない。
- クエリの出力結果サイズが大きい場合、langchainの出力が見切れてしまい、エラーが出てしまう
例:array([ 1, 2, 3, ..., 9997, 9998, 9999]
まとめ
この記事では、ChatGPTとLangChainを使って、自然言語でデータセットに対するクエリを行うことができるデータ探索ツールを作成しました。Dashを使ってインタラクティブなWebアプリケーションを構築し、データセットをアップロードして、自然言語で質問を行うことができるようにしました。このツールは、データ分析において、EDAを効率化することに役立つと考えています。
今後の展望
このデータ探索ツールは基本的な機能を持っていますが、さらなる改善や機能追加が可能と考えられます。以下は、今後の展望として考慮できる改善点や追加機能です。
-
複数のデータセットのサポート: 現在は1つのデータセットのみに対応していますが、複数のデータセットを同時に扱えるようにすることで、より幅広い分析が可能。
-
ユーザーアカウントの導入: ユーザーアカウントを導入することで、個々のユーザーが自分専用のスペースでデータセットを管理し、分析結果を保存。
-
データセットの共有機能: ユーザーがデータセットを共有し、他のユーザーと分析結果を交換できるようにすることで、コラボレーションを促進。
-
機械学習モデルの統合: さらに進んだ分析のために、機械学習モデルを統合し、自動的にデータセットに適用される機能を追加。
このように、今後の展望としてさまざまな改善や機能追加が可能です。このデータ探索ツールをベースに、データ分析の効率化やデータに対する理解を深めるためのさらなる開発が期待されます。
この記事が、ChatGPTとLangChainを用いたデータ分析の新たな可能性についての理解を深める手助けとなることを願っています。ぜひ試してみて、データ分析作業をよりスムーズで効率的に進めるためのツールとして活用してください。
最後に
ちなみにですが、ここの記事の内容やコードはほとんどをChatGPTで作成しました。やりたいことがあれば、本当になんでもできるようになった気がします。
参考
https://github.com/KRFH/natural-language-analysis-tool
https://python.langchain.com/en/latest/index.html
https://python.langchain.com/en/latest/modules/agents/toolkits/examples/pandas.html
https://dash.plotly.com/