LLMの活用方法を探していたらMicrosoftさんが「LIDA」というLLMを使ったデータ可視化ライブラリを公開していました。
とりあえずどのくらい使えるのか試してみたので、内容をシェアしたいと思います。
試したコードは下記に置いています。
公式にも以下のように注意書きがありますが、本ライブラリはLLMがコードを生成してそれを実行するという仕組みのため、サンドボックス環境での実行を推奨されています。(私はとりあえずGoogle Colabで試しました)
Note on Code Execution: To create visualizations, LIDA generates and executes code. Ensure that you run LIDA in a secure environment.
本記事の前提
- 実行環境 -> Google Colaboratory
- Python==3.10.12
- lida==0.0.8
何ができるの?
LIDAはCSVやJSON形式などのデータに対して、可視化のためのコードを作成しそれを実行・表示することができるようで、その他にもコードを提案したり、修正したりもできるようです。
とりあえず試してみたところ、データセットを渡しただけで次のようなグラフが出てきました。
この提案に「凡例を消して」というと凡例を消してくれます。
「他に3つ提案して」というと次のようなグラフたちを出してきました。
掘り下げてみる
READMEによれば以下のことができるようです。
メソッド | できること |
---|---|
summarize | データの要約 |
goals | 可視化する目標の設定 (e.g. ○○のヒストグラム) |
visualize | 可視化のためのコード生成 |
edit | コードの編集 |
explain | コードの説明 |
evaluate | 目標に対するコードの評価 |
repair | コードの修正 |
recommend | 生成したコードに基づいた他の可視化コードの提案 |
事前にpipでライブラリをインストールしておきます。
!pip install lida
データセットにはscikit-learnのカリフォルニア住宅価格データセットを使用しました。
import pandas as pd
from sklearn.datasets import fetch_california_housing
california = fetch_california_housing()
explain_data = pd.DataFrame(california.data, columns=california.feature_names)
target_data = pd.DataFrame(california.target, columns=california.target_names)
data = pd.concat([explain_data, target_data], axis=1)
0. Managerクラス、TextGenerationConfigクラス
まずはManagerクラスをインスタンス化します。
from lida import Manager, TextGenerationConfig, llm
lida = Manager(text_gen = llm(provider="openai", api_key="YOUR_API_KEY!"))
text_gen
では使用するLLMのプロバイダを指定します。
llm
関数のprovider
引数に渡せる文字列は以下のようになっています。
provider |
---|
openai (or default) |
palm (or google) |
cohere |
hf (or huggingface) |
今回はOpenAIモデルで試しているのでapi_key
を指定しています。HuggingFaceのモデルを使用する場合はモデル名や特殊トークン等の指定が必要となります。
Manager
クラスが先ほどのメソッド群を持っているのですが、そのすべてのtextgen_config
引数にTextGenerationConfig
オブジェクトを渡すことができます。この設定で各メソッドの生成するコードや文章を調整できます。
TextGenerationConfig
クラスは次のようなデータを持ちます。
# llmx/datamodel.py
@dataclass
class TextGenerationConfig:
n: int = 1
temperature: float = 0.1
max_tokens: Union[int, None] = None
top_p: float = 1.0
top_k: int = 50
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
provider: Union[str, None] = None
model: Optional[str] = None
stop: Union[List[str], str, None] = None
use_cache: bool = True
def __post_init__(self):
self._fields_dict = asdict(self)
def __getitem__(self, key: Union[str, int]) -> Any:
return self._fields_dict.get(key)
今回はOpenAIのGPT-3.5とGPT-4を使いますので、基本的にはtemperature
、model
、use_cache
を指定すれば事足りそうです。use_cacheについてはリポジトリ内のtutorial.ipynbで次のように説明されています。
- Caching: Each manager method takes a
textgen_config
argument which is a dictionary that can be used to configure the text generation process (with parameters for model, temperature, max_tokens, topk etc). One of the keys in this dictionary isuse_cache
. If set toTrue
, the manager will cache the generated text associated with that method. Use for speedup and to avoid hitting API limits.
一度生成されたテキストをキャッシュできて速度とAPI代の削減にもなるので、何かのプロダクトに組み込むときにはuse_cache=True
にしておく方がよさそうです。
これより以下はこの設定で設定を保持しておきます。
textgen_config = TextGenerationConfig(temperature=0.5, model="gpt-3.5-turbo", use_cache=True)
1. summarize
summaryとgoalsとcodeは他のメソッドの引数に使うことになるので、LIDAの最初の動きとしては要約→目標設定→コード生成→実行→表示が基本になると思います。
以下に一連のコードブロックを示します。
from lida import Manager, TextGenerationConfig , llm
from lida.utils import plot_raster
lida = Manager(text_gen = llm("openai", api_key="YOUR_API_KEY!"))
textgen_config = TextGenerationConfig(temperature=0.5, model="gpt-3.5-turbo", use_cache=True)
# 要約
summary = lida.summarize(data, summary_method="default", textgen_config=textgen_config)
# 目標設定
goals = lida.goals(summary, n=1, textgen_config=textgen_config)
# 生成 & 実行
charts = lida.visualize(summary=summary, goal=goals[0], library="seaborn", textgen_config=textgen_config)
# 表示
plot_raster(charts[0].raster)
これで記事の最初のようなグラフが表示されます。
以下に詳細な説明をしていきます。
summary = lida.summarize(data, summary_method="default", textgen_config=textgen_config)
- data: pandas.DataFrameかファイルパスを指定します。(対応する拡張子は .json、.csv、.xls、.xlsx、.parquet、.feather、.tsv)
注意
カラム名にあたる部分に英数字とアンダースコア以外の文字がある場合、その文字はアンダースコアに置き換えられ、元のファイルも上書きされてしまうので注意してください。
また、行数が4500を超える場合は4500行が無作為にサンプリングされます。
-
summary_method:
- default : 各列をn_samples個サンプリングしたものとpandasの統計情報からsummaryを機械的に作成します。(n_samples=3)
defaultでのsummaryサンプル
{ "name": "", "file_name": "", "dataset_description": "", "fields": [ { "column": "MedInc", "properties": { "dtype": "number", "std": 1.8998217179452688, "min": 0.4999, "max": 15.0001, "samples": [ 5.0286, 2.0433, 6.1228 ], "num_unique_values": 12928, "semantic_type": "", "description": "" } }, { "column": "HouseAge", "properties": { "dtype": "number", "std": 12.58555761211165, "min": 1.0, "max": 52.0, "samples": [ 35.0, 25.0, 7.0 ], "num_unique_values": 52, "semantic_type": "", "description": "" } }, { "column": "AveRooms", "properties": { "dtype": "number", "std": 2.4741731394243187, "min": 0.8461538461538461, "max": 141.9090909090909, "samples": [ 6.111269614835948, 5.912820512820513, 5.7924528301886795 ], "num_unique_values": 19392, "semantic_type": "", "description": "" } }, { "column": "AveBedrms", "properties": { "dtype": "number", "std": 0.473910856795466, "min": 0.3333333333333333, "max": 34.06666666666667, "samples": [ 0.9906542056074766, 1.112099644128114, 1.0398230088495575 ], "num_unique_values": 14233, "semantic_type": "", "description": "" } }, { "column": "Population", "properties": { "dtype": "number", "std": 1132.462121765341, "min": 3.0, "max": 35682.0, "samples": [ 4169.0, 636.0, 3367.0 ], "num_unique_values": 3888, "semantic_type": "", "description": "" } }, { "column": "AveOccup", "properties": { "dtype": "number", "std": 10.386049562213618, "min": 0.6923076923076923, "max": 1243.3333333333333, "samples": [ 2.6939799331103678, 3.559375, 3.297082228116711 ], "num_unique_values": 18841, "semantic_type": "", "description": "" } }, { "column": "Latitude", "properties": { "dtype": "number", "std": 2.1359523974571153, "min": 32.54, "max": 41.95, "samples": [ 33.7, 34.41, 38.24 ], "num_unique_values": 862, "semantic_type": "", "description": "" } }, { "column": "Longitude", "properties": { "dtype": "number", "std": 2.0035317235025882, "min": -124.35, "max": -114.31, "samples": [ -118.63, -119.86, -121.26 ], "num_unique_values": 844, "semantic_type": "", "description": "" } }, { "column": "MedHouseVal", "properties": { "dtype": "number", "std": 1.1539561587441387, "min": 0.14999, "max": 5.00001, "samples": [ 1.943, 3.79, 2.301 ], "num_unique_values": 3842, "semantic_type": "", "description": "" } } ], "field_names": [ "MedInc", "HouseAge", "AveRooms", "AveBedrms", "Population", "AveOccup", "Latitude", "Longitude", "MedHouseVal" ] }
- llm :
summary["dataset_description"]
にLLMによるデータセットの要約が入ります。
llmでのsummaryサンプル
{ "name": "California Housing Dataset", "file_name": "", "dataset_description": "This dataset contains information about block groups in California, derived from the 1990 U.S. census. It includes various metrics like median income, house age, average rooms, average bedrooms, population, average occupation, latitude, longitude, and median house value.", "fields": [ { "column": "MedInc", "properties": { "dtype": "number", "std": 1.8998217179452688, "min": 0.4999, "max": 15.0001, "samples": [ 5.0286, 2.0433, 6.1228 ], "num_unique_values": 12928, "semantic_type": "income", "description": "Median income in block" } }, { "column": "HouseAge", "properties": { "dtype": "number", "std": 12.58555761211165, "min": 1.0, "max": 52.0, "samples": [ 35.0, 25.0, 7.0 ], "num_unique_values": 52, "semantic_type": "age", "description": "Median age of a house within a block" } }, { "column": "AveRooms", "properties": { "dtype": "number", "std": 2.4741731394243187, "min": 0.8461538461538461, "max": 141.9090909090909, "samples": [ 6.111269614835948, 5.912820512820513, 5.7924528301886795 ], "num_unique_values": 19392, "semantic_type": "number", "description": "Average number of rooms" } }, { "column": "AveBedrms", "properties": { "dtype": "number", "std": 0.473910856795466, "min": 0.3333333333333333, "max": 34.06666666666667, "samples": [ 0.9906542056074766, 1.112099644128114, 1.0398230088495575 ], "num_unique_values": 14233, "semantic_type": "number", "description": "Average number of bedrooms" } }, { "column": "Population", "properties": { "dtype": "number", "std": 1132.462121765341, "min": 3.0, "max": 35682.0, "samples": [ 4169.0, 636.0, 3367.0 ], "num_unique_values": 3888, "semantic_type": "population", "description": "Total population" } }, { "column": "AveOccup", "properties": { "dtype": "number", "std": 10.386049562213618, "min": 0.6923076923076923, "max": 1243.3333333333333, "samples": [ 2.6939799331103678, 3.559375, 3.297082228116711 ], "num_unique_values": 18841, "semantic_type": "number", "description": "Average house occupancy" } }, { "column": "Latitude", "properties": { "dtype": "number", "std": 2.1359523974571153, "min": 32.54, "max": 41.95, "samples": [ 33.7, 34.41, 38.24 ], "num_unique_values": 862, "semantic_type": "latitude", "description": "Latitude of the block" } }, { "column": "Longitude", "properties": { "dtype": "number", "std": 2.0035317235025882, "min": -124.35, "max": -114.31, "samples": [ -118.63, -119.86, -121.26 ], "num_unique_values": 844, "semantic_type": "longitude", "description": "Longitude of the block" } }, { "column": "MedHouseVal", "properties": { "dtype": "number", "std": 1.1539561587441387, "min": 0.14999, "max": 5.00001, "samples": [ 1.943, 3.79, 2.301 ], "num_unique_values": 3842, "semantic_type": "price", "description": "Median house value for households within a block" } } ], "field_names": [ "MedInc", "HouseAge", "AveRooms", "AveBedrms", "Population", "AveOccup", "Latitude", "Longitude", "MedHouseVal" ] }
- columns : summaryはファイル名とカラム名だけのシンプルなものになります。
2. goals
goals = lida.goals(summary, n=1, textgen_config=textgen_config)
- summary: summarizeの返り値を渡します。
- n: Goalオブジェクトをいくつ生成するかの数値です。Goalクラスは次のようなデータを持ちます。
# lida/datamodel.py
@dataclass
class Goal:
"""A visualization goal"""
index: int # 0
question: str # "MedIncとMedHouseValの相関関係は?"
visualization: str # "MedIncとMedHouseValの散布図'"
rationale: str # "これは所得中央値と住宅価格中央値の間に関係があるかどうかを理解するのに役立つだろう"
def _repr_markdown_(self):
return f"""
goalsの返り値はList[Goal]
になります。
3. visualize
charts = lida.visualize(summary=summary, goal=goals[0], library="seaborn", textgen_config=textgen_config)
- summary: summarizeの返り値を渡します。
- goal: goalsの返り値の要素を1つ渡します。
-
library: 可視化に用いるライブラリ名を指定します。内部のSystemPromptで「これ以外を使用するな」と指示されます。デフォルトは
altair
で、他にmatplotlib
、seaborn
、ggplot
が選択できます。
visualizeの中ではsummary
とgoal
をもとにlibrary
を使ったコードを生成して、それぞれのコードをexec
関数で実行、plotしたグラフをPNG形式で保存し、それをBase64エンコードしたものをrasterに入れてList[ChartExecutorResponse]
を返します。ChartExecutorResponse
クラスは以下のようになっています。
# lida/datamodel.py
@dataclass
class ChartExecutorResponse:
"""Response from a visualization execution"""
spec: Optional[Union[str, Dict]] # interactive specification e.g. vegalite
status: bool # True if successful
raster: Optional[str] # base64 encoded image
code: str # code used to generate the visualization
library: str # library used to generate the visualization
error: Optional[Dict] = None # error message if status is False
def _repr_mimebundle_(self, include=None, exclude=None):
bundle = {"text/plain": self.code}
if self.raster is not None:
bundle["image/png"] = self.raster
if self.spec is not None:
bundle["application/vnd.vegalite.v5+json"] = self.spec
return bundle
これを実際に表示するためにはutilsのplot_raster
関数にraster
を渡します。単一のrasterでもいいですし、rasterのListを渡すこともできます。
from lida.utils import plot_raster
plot_raster(charts[0].raster)
4. edit
editを使うと既存の可視化コードを自然言語で編集できます。
冒頭のように「凡例を消して」というと凡例をなくしたコードを生成して、visualizeと同様にList[ChartExecutorResponse]
を返してくれます。
instructions = "delete legend"
edited_charts = lida.edit(code=charts[0].code, summary=summary, instructions=instructions, library="seaborn", textgen_config=textgen_config)
plot_raster(edited_charts[0].raster)
- code: visualizeで生成したコードを渡します。このコードを元に編集されます。
- summary: summarizeの返り値を渡します。
-
instruction: 自然言語で編集の指示を渡します。複数の指示を出したい場合は
instructions = ["delete legend", "change color of bar"]
という風にもできます。 - library: visualizeと同様です。
5. explain
explainでは生成された可視化コードを説明してもらえます。
explanations = lida.explain(code=edited_charts[0].code, textgen_config=textgen_config)
for row in explanations[0]:
print(row["section"]," ** ", row["explanation"])
- code: 説明してほしいコードを渡します。
例えばcode
に次のようなコードを渡したとします。
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
def plot(data: pd.DataFrame):
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(data=data, x='MedInc', kde=True, ax=ax)
plt.axvline(data['MedInc'].mean(), color='red', linestyle='--')
plt.title('', wrap=True)
return plt;
chart = plot(data)
すると、explainでは3つのセクションに分かれて説明が返ってきます。(※文章は日本語訳しています)
- accessibility
- プロットはヒストグラムであり、カーネル密度推定(KDE)ラインも含まれています。ヒストグラムの棒は青色で、KDEラインも青色です。赤い破線で縦線が描かれ、これは'MedInc'列の平均値を示しています。x軸は「中央値の所得(Median Income)」とラベル付けされ、y軸は「頻度(Frequency)」とラベル付けされています。プロットのタイトルは「Median Incomeの分布はどうなっているか」です。
- transformation
- 'plot'という関数はpandasのDataFrameを入力として受け取ります。コード内に明示的なデータ変換はありません。しかし、seabornの'histplot'関数は暗黙のうちに'MedInc'データをビン分けしてヒストグラムを作成します。'kde=True'引数はさらにデータを変換し、カーネル密度推定を計算します。
- visualization
- コードはまず、図のサイズを10x6に設定します。次に、seabornの'histplot'関数を使用して'MedInc'データのヒストグラムを作成し、KDEラインを追加します。ヒストグラムとKDEラインの色は青に設定されています。x軸とy軸にはそれぞれ「Median Income」、「Frequency」とラベルが付けられています。プロットにはタイトルが追加されます。最後に、'MedInc'データの平均値に赤い破線の縦線が追加されます。
6. evaluate
evaluateを使うと、生成されたコードが品質や目標に対してどうなのかを評価できます。
evaluations = lida.evaluate(code=edited_charts[0].code, goal=goals[0], library="seaborn", textgen_config=textgen_config)
for eval in evaluations[0]:
print(eval["dimension"], "Score" ,eval["score"], "/ 10")
print("\t", eval["rationale"][:200])
print("\t**********************************")
- code: visualizeで生成したコードを渡します。このコードを評価します。
- goal: 評価目的となるgoalを渡します。
- library: visualizeと同様です。
評価対象のコードは以下で、
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
# solution plan
# i. ..
def plot(data: pd.DataFrame):
plt.figure(figsize=(10, 6))
sns.histplot(data=data, x='MedInc', kde=True, color='blue')
plt.xlabel('Median Income')
plt.ylabel('Frequency')
plt.title('What is the distribution of Median Income?')
plt.axvline(x=data['MedInc'].mean(), color='red', linestyle='--')
return plt
chart = plot(data)
Goalオブジェクトは次のようになっているとすると、
Goal 0
Question: What is the distribution of Median Income?
// Median Income(中央値の所得)の分布はどうなっている?
Visualization: histogram of Median Income
// Median Income(中央値の所得)のヒストグラム
Rationale: This tells us about the distribution of income in the dataset.
// これにより、データセット内の所得分布について知ることができます。
以下のような感じで評価が返ってきます。(※文章は日本語訳しています)
- bugs Score 10 / 10
- バグや構文エラーは見つかりませんでした。
- transformation Score 10 / 10
- この可視化にはデータ変換は必要ありません。
- compliance Score 10 / 10
- コードは、「Median Income(中央値の所得)」の分布を表示するという指定された可視化の目標を満たしています。
- type Score 10 / 10
- ヒストグラムは、Median Income(中央値の所得)のような連続変数の分布を示すのに適した可視化のタイプです。
- encoding Score 10 / 10
- データは適切にエンコードされており、x軸にはMedian Income(中央値の所得)、y軸にはFrequency(頻度)が表示されています。
- aesthetics Score 10 / 10
- 可視化の美学(外観)は適切で、明確なタイトル、軸のラベル、そして平均的なMedian Income(中央値の所得)を示す赤い破線があります。
7. repair
※未検証です。
8. recommend
n_recommendations = 3
recommended_charts = lida.recommend(code=edited_charts[0].code, summary=summary, n=n_recommendations, library="seaborn", textgen_config=textgen_config)
plot_raster([recommended_charts[i].raster for i in range(n_recommendations)], figsize=(20, 20))
- code: visualizeで生成したコードを渡します。このコードがSystemPrompsに例として入ります。
- summary: summarizeの返り値を渡します。
- n: 可視化コードをいくつ提案してもらうかを指定します。
- library: visualizeと同様です。
これで冒頭の最後のグラフのようなものが表示されます。
返り値はvisualizeやeditと同様に扱えます。
まとめ
データセットを渡すだけでグラフを表示してくれるのは良さそうです!
ただ、ユーザーが直接介入できる部分がedit
のみで、あとはLLMにおまかせという部分が「欲しいのはそういうグラフじゃないんだけど...」という印象を持ちました。
もしかすると私がこのライブラリの目的を勘違いしている可能性もあるので、この後もう少し踏み込んでみようと思います。
現時点での考えられる方法としては、visualize
やrecommend
にはsummary
を渡しますので
summary = lida.summarize(data, summary_method="default", textgen_config=textgen_config)
summary["dataset_description"] += "I need to create a scatter plot with one appropriate independent variable and the dependent variable."
のように無理やりどんなグラフが欲しいのかを付け足すことで、欲しいグラフを出させることができました。
あとは費用についてですが、要約→目標設定→コード生成→実行→表示→編集→説明→評価→提案→表示という一連の流れを行った場合に gpt-3.5-turbo を使用して $0.012 でした。( 1~2円 くらい)
gpt-4 では 40~50円 くらいで、やはりすべての処理に使うにはもったいないなという感じです。
(今回試してみた感じだと表示されるグラフに大きな違いはありませんでした)
他にもブラウザUIやInfographic Generation
というβ機能もあるみたいなので、検証次第書いてみたいと思います。
ローカルLLMでどこまで行けるのかも随時試してみます。