LoginSignup
1
0

MicrosoftのLIDAをさわってみた

Posted at

LLMの活用方法を探していたらMicrosoftさんが「LIDA」というLLMを使ったデータ可視化ライブラリを公開していました。

とりあえずどのくらい使えるのか試してみたので、内容をシェアしたいと思います。
試したコードは下記に置いています。

公式にも以下のように注意書きがありますが、本ライブラリはLLMがコードを生成してそれを実行するという仕組みのため、サンドボックス環境での実行を推奨されています。(私はとりあえずGoogle Colabで試しました)

Note on Code Execution: To create visualizations, LIDA generates and executes code. Ensure that you run LIDA in a secure environment.

本記事の前提

  • 実行環境 -> Google Colaboratory
  • Python==3.10.12
  • lida==0.0.8

何ができるの?

LIDAはCSVやJSON形式などのデータに対して、可視化のためのコードを作成しそれを実行・表示することができるようで、その他にもコードを提案したり、修正したりもできるようです。

とりあえず試してみたところ、データセットを渡しただけで次のようなグラフが出てきました。
download.png
この提案に「凡例を消して」というと凡例を消してくれます。
download.png
「他に3つ提案して」というと次のようなグラフたちを出してきました。
download.png

掘り下げてみる

READMEによれば以下のことができるようです。

メソッド できること
summarize データの要約
goals 可視化する目標の設定 (e.g. ○○のヒストグラム)
visualize 可視化のためのコード生成
edit コードの編集
explain コードの説明
evaluate 目標に対するコードの評価
repair コードの修正
recommend 生成したコードに基づいた他の可視化コードの提案

事前にpipでライブラリをインストールしておきます。

!pip install lida

データセットにはscikit-learnのカリフォルニア住宅価格データセットを使用しました。

import pandas as pd
from sklearn.datasets import fetch_california_housing

california = fetch_california_housing()

explain_data = pd.DataFrame(california.data, columns=california.feature_names)
target_data = pd.DataFrame(california.target, columns=california.target_names)

data = pd.concat([explain_data, target_data], axis=1)

0. Managerクラス、TextGenerationConfigクラス

まずはManagerクラスをインスタンス化します。

from lida import Manager, TextGenerationConfig, llm

lida = Manager(text_gen = llm(provider="openai", api_key="YOUR_API_KEY!"))

text_genでは使用するLLMのプロバイダを指定します。
llm関数のprovider引数に渡せる文字列は以下のようになっています。

provider
openai (or default)
palm (or google)
cohere
hf (or huggingface)

今回はOpenAIモデルで試しているのでapi_keyを指定しています。HuggingFaceのモデルを使用する場合はモデル名や特殊トークン等の指定が必要となります。

Managerクラスが先ほどのメソッド群を持っているのですが、そのすべてのtextgen_config引数にTextGenerationConfigオブジェクトを渡すことができます。この設定で各メソッドの生成するコードや文章を調整できます。
TextGenerationConfigクラスは次のようなデータを持ちます。

# llmx/datamodel.py
@dataclass
class TextGenerationConfig:
    n: int = 1
    temperature: float = 0.1
    max_tokens: Union[int, None] = None
    top_p: float = 1.0
    top_k: int = 50
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    provider: Union[str, None] = None
    model: Optional[str] = None
    stop: Union[List[str], str, None] = None
    use_cache: bool = True

    def __post_init__(self):
        self._fields_dict = asdict(self)

    def __getitem__(self, key: Union[str, int]) -> Any:
        return self._fields_dict.get(key)

今回はOpenAIのGPT-3.5とGPT-4を使いますので、基本的にはtemperaturemodeluse_cacheを指定すれば事足りそうです。use_cacheについてはリポジトリ内のtutorial.ipynbで次のように説明されています。

  • Caching: Each manager method takes a textgen_config argument which is a dictionary that can be used to configure the text generation process (with parameters for model, temperature, max_tokens, topk etc). One of the keys in this dictionary is use_cache. If set to True, the manager will cache the generated text associated with that method. Use for speedup and to avoid hitting API limits.

一度生成されたテキストをキャッシュできて速度とAPI代の削減にもなるので、何かのプロダクトに組み込むときにはuse_cache=Trueにしておく方がよさそうです。
これより以下はこの設定で設定を保持しておきます。

textgen_config = TextGenerationConfig(temperature=0.5, model="gpt-3.5-turbo", use_cache=True)

1. summarize

summaryとgoalsとcodeは他のメソッドの引数に使うことになるので、LIDAの最初の動きとしては要約→目標設定→コード生成→実行→表示が基本になると思います。
以下に一連のコードブロックを示します。

from lida import Manager, TextGenerationConfig , llm
from lida.utils import plot_raster

lida = Manager(text_gen = llm("openai", api_key="YOUR_API_KEY!"))
textgen_config = TextGenerationConfig(temperature=0.5, model="gpt-3.5-turbo", use_cache=True)

# 要約
summary = lida.summarize(data, summary_method="default", textgen_config=textgen_config)

# 目標設定
goals = lida.goals(summary, n=1, textgen_config=textgen_config)

# 生成 & 実行
charts = lida.visualize(summary=summary, goal=goals[0], library="seaborn", textgen_config=textgen_config)

# 表示
plot_raster(charts[0].raster)

これで記事の最初のようなグラフが表示されます。
以下に詳細な説明をしていきます。

summary = lida.summarize(data, summary_method="default", textgen_config=textgen_config)
  • data: pandas.DataFrameかファイルパスを指定します。(対応する拡張子は .json、.csv、.xls、.xlsx、.parquet、.feather、.tsv)

注意
カラム名にあたる部分に英数字とアンダースコア以外の文字がある場合、その文字はアンダースコアに置き換えられ、元のファイルも上書きされてしまうので注意してください。
また、行数が4500を超える場合は4500行が無作為にサンプリングされます。

  • summary_method:

    • default : 各列をn_samples個サンプリングしたものとpandasの統計情報からsummaryを機械的に作成します。(n_samples=3)
    defaultでのsummaryサンプル
    {
        "name": "",
        "file_name": "",
        "dataset_description": "",
        "fields": [
            {
                "column": "MedInc",
                "properties": {
                    "dtype": "number",
                    "std": 1.8998217179452688,
                    "min": 0.4999,
                    "max": 15.0001,
                    "samples": [
                        5.0286,
                        2.0433,
                        6.1228
                    ],
                    "num_unique_values": 12928,
                    "semantic_type": "",
                    "description": ""
                }
            },
            {
                "column": "HouseAge",
                "properties": {
                    "dtype": "number",
                    "std": 12.58555761211165,
                    "min": 1.0,
                    "max": 52.0,
                    "samples": [
                        35.0,
                        25.0,
                        7.0
                    ],
                    "num_unique_values": 52,
                    "semantic_type": "",
                    "description": ""
                }
            },
            {
                "column": "AveRooms",
                "properties": {
                    "dtype": "number",
                    "std": 2.4741731394243187,
                    "min": 0.8461538461538461,
                    "max": 141.9090909090909,
                    "samples": [
                        6.111269614835948,
                        5.912820512820513,
                        5.7924528301886795
                    ],
                    "num_unique_values": 19392,
                    "semantic_type": "",
                    "description": ""
                }
            },
            {
                "column": "AveBedrms",
                "properties": {
                    "dtype": "number",
                    "std": 0.473910856795466,
                    "min": 0.3333333333333333,
                    "max": 34.06666666666667,
                    "samples": [
                        0.9906542056074766,
                        1.112099644128114,
                        1.0398230088495575
                    ],
                    "num_unique_values": 14233,
                    "semantic_type": "",
                    "description": ""
                }
            },
            {
                "column": "Population",
                "properties": {
                    "dtype": "number",
                    "std": 1132.462121765341,
                    "min": 3.0,
                    "max": 35682.0,
                    "samples": [
                        4169.0,
                        636.0,
                        3367.0
                    ],
                    "num_unique_values": 3888,
                    "semantic_type": "",
                    "description": ""
                }
            },
            {
                "column": "AveOccup",
                "properties": {
                    "dtype": "number",
                    "std": 10.386049562213618,
                    "min": 0.6923076923076923,
                    "max": 1243.3333333333333,
                    "samples": [
                        2.6939799331103678,
                        3.559375,
                        3.297082228116711
                    ],
                    "num_unique_values": 18841,
                    "semantic_type": "",
                    "description": ""
                }
            },
            {
                "column": "Latitude",
                "properties": {
                    "dtype": "number",
                    "std": 2.1359523974571153,
                    "min": 32.54,
                    "max": 41.95,
                    "samples": [
                        33.7,
                        34.41,
                        38.24
                    ],
                    "num_unique_values": 862,
                    "semantic_type": "",
                    "description": ""
                }
            },
            {
                "column": "Longitude",
                "properties": {
                    "dtype": "number",
                    "std": 2.0035317235025882,
                    "min": -124.35,
                    "max": -114.31,
                    "samples": [
                        -118.63,
                        -119.86,
                        -121.26
                    ],
                    "num_unique_values": 844,
                    "semantic_type": "",
                    "description": ""
                }
            },
            {
                "column": "MedHouseVal",
                "properties": {
                    "dtype": "number",
                    "std": 1.1539561587441387,
                    "min": 0.14999,
                    "max": 5.00001,
                    "samples": [
                        1.943,
                        3.79,
                        2.301
                    ],
                    "num_unique_values": 3842,
                    "semantic_type": "",
                    "description": ""
                }
            }
        ],
        "field_names": [
            "MedInc",
            "HouseAge",
            "AveRooms",
            "AveBedrms",
            "Population",
            "AveOccup",
            "Latitude",
            "Longitude",
            "MedHouseVal"
        ]
    }
    
    • llm : summary["dataset_description"]にLLMによるデータセットの要約が入ります。
    llmでのsummaryサンプル
    {
        "name": "California Housing Dataset",
        "file_name": "",
        "dataset_description": "This dataset contains information about block groups in California, derived from the 1990 U.S. census. It includes various metrics like median income, house age, average rooms, average bedrooms, population, average occupation, latitude, longitude, and median house value.",
        "fields": [
            {
                "column": "MedInc",
                "properties": {
                    "dtype": "number",
                    "std": 1.8998217179452688,
                    "min": 0.4999,
                    "max": 15.0001,
                    "samples": [
                        5.0286,
                        2.0433,
                        6.1228
                    ],
                    "num_unique_values": 12928,
                    "semantic_type": "income",
                    "description": "Median income in block"
                }
            },
            {
                "column": "HouseAge",
                "properties": {
                    "dtype": "number",
                    "std": 12.58555761211165,
                    "min": 1.0,
                    "max": 52.0,
                    "samples": [
                        35.0,
                        25.0,
                        7.0
                    ],
                    "num_unique_values": 52,
                    "semantic_type": "age",
                    "description": "Median age of a house within a block"
                }
            },
            {
                "column": "AveRooms",
                "properties": {
                    "dtype": "number",
                    "std": 2.4741731394243187,
                    "min": 0.8461538461538461,
                    "max": 141.9090909090909,
                    "samples": [
                        6.111269614835948,
                        5.912820512820513,
                        5.7924528301886795
                    ],
                    "num_unique_values": 19392,
                    "semantic_type": "number",
                    "description": "Average number of rooms"
                }
            },
            {
                "column": "AveBedrms",
                "properties": {
                    "dtype": "number",
                    "std": 0.473910856795466,
                    "min": 0.3333333333333333,
                    "max": 34.06666666666667,
                    "samples": [
                        0.9906542056074766,
                        1.112099644128114,
                        1.0398230088495575
                    ],
                    "num_unique_values": 14233,
                    "semantic_type": "number",
                    "description": "Average number of bedrooms"
                }
            },
            {
                "column": "Population",
                "properties": {
                    "dtype": "number",
                    "std": 1132.462121765341,
                    "min": 3.0,
                    "max": 35682.0,
                    "samples": [
                        4169.0,
                        636.0,
                        3367.0
                    ],
                    "num_unique_values": 3888,
                    "semantic_type": "population",
                    "description": "Total population"
                }
            },
            {
                "column": "AveOccup",
                "properties": {
                    "dtype": "number",
                    "std": 10.386049562213618,
                    "min": 0.6923076923076923,
                    "max": 1243.3333333333333,
                    "samples": [
                        2.6939799331103678,
                        3.559375,
                        3.297082228116711
                    ],
                    "num_unique_values": 18841,
                    "semantic_type": "number",
                    "description": "Average house occupancy"
                }
            },
            {
                "column": "Latitude",
                "properties": {
                    "dtype": "number",
                    "std": 2.1359523974571153,
                    "min": 32.54,
                    "max": 41.95,
                    "samples": [
                        33.7,
                        34.41,
                        38.24
                    ],
                    "num_unique_values": 862,
                    "semantic_type": "latitude",
                    "description": "Latitude of the block"
                }
            },
            {
                "column": "Longitude",
                "properties": {
                    "dtype": "number",
                    "std": 2.0035317235025882,
                    "min": -124.35,
                    "max": -114.31,
                    "samples": [
                        -118.63,
                        -119.86,
                        -121.26
                    ],
                    "num_unique_values": 844,
                    "semantic_type": "longitude",
                    "description": "Longitude of the block"
                }
            },
            {
                "column": "MedHouseVal",
                "properties": {
                    "dtype": "number",
                    "std": 1.1539561587441387,
                    "min": 0.14999,
                    "max": 5.00001,
                    "samples": [
                        1.943,
                        3.79,
                        2.301
                    ],
                    "num_unique_values": 3842,
                    "semantic_type": "price",
                    "description": "Median house value for households within a block"
                }
            }
        ],
        "field_names": [
            "MedInc",
            "HouseAge",
            "AveRooms",
            "AveBedrms",
            "Population",
            "AveOccup",
            "Latitude",
            "Longitude",
            "MedHouseVal"
        ]
    }
    
    • columns : summaryはファイル名とカラム名だけのシンプルなものになります。

2. goals

goals = lida.goals(summary, n=1, textgen_config=textgen_config)
  • summary: summarizeの返り値を渡します。
  • n: Goalオブジェクトをいくつ生成するかの数値です。Goalクラスは次のようなデータを持ちます。
# lida/datamodel.py
@dataclass
class Goal:
    """A visualization goal"""

    index: int # 0
    question: str # "MedIncとMedHouseValの相関関係は?"
    visualization: str # "MedIncとMedHouseValの散布図'"
    rationale: str # "これは所得中央値と住宅価格中央値の間に関係があるかどうかを理解するのに役立つだろう"

    def _repr_markdown_(self):
        return f"""

goalsの返り値はList[Goal]になります。

3. visualize

charts = lida.visualize(summary=summary, goal=goals[0], library="seaborn", textgen_config=textgen_config)
  • summary: summarizeの返り値を渡します。
  • goal: goalsの返り値の要素を1つ渡します。
  • library: 可視化に用いるライブラリ名を指定します。内部のSystemPromptで「これ以外を使用するな」と指示されます。デフォルトはaltairで、他にmatplotlibseabornggplotが選択できます。

visualizeの中ではsummarygoalをもとにlibraryを使ったコードを生成して、それぞれのコードをexec関数で実行、plotしたグラフをPNG形式で保存し、それをBase64エンコードしたものをrasterに入れてList[ChartExecutorResponse]を返します。ChartExecutorResponseクラスは以下のようになっています。

# lida/datamodel.py
@dataclass
class ChartExecutorResponse:
    """Response from a visualization execution"""

    spec: Optional[Union[str, Dict]]  # interactive specification e.g. vegalite
    status: bool  # True if successful
    raster: Optional[str]  # base64 encoded image
    code: str  # code used to generate the visualization
    library: str  # library used to generate the visualization
    error: Optional[Dict] = None  # error message if status is False

    def _repr_mimebundle_(self, include=None, exclude=None):
        bundle = {"text/plain": self.code}
        if self.raster is not None:
            bundle["image/png"] = self.raster
        if self.spec is not None:
            bundle["application/vnd.vegalite.v5+json"] = self.spec

        return bundle

これを実際に表示するためにはutilsのplot_raster関数にrasterを渡します。単一のrasterでもいいですし、rasterのListを渡すこともできます。

from lida.utils import plot_raster

plot_raster(charts[0].raster)

4. edit

editを使うと既存の可視化コードを自然言語で編集できます。
冒頭のように「凡例を消して」というと凡例をなくしたコードを生成して、visualizeと同様にList[ChartExecutorResponse]を返してくれます。

instructions = "delete legend"
edited_charts = lida.edit(code=charts[0].code,  summary=summary, instructions=instructions, library="seaborn", textgen_config=textgen_config)
plot_raster(edited_charts[0].raster)
  • code: visualizeで生成したコードを渡します。このコードを元に編集されます。
  • summary: summarizeの返り値を渡します。
  • instruction: 自然言語で編集の指示を渡します。複数の指示を出したい場合はinstructions = ["delete legend", "change color of bar"]という風にもできます。
  • library: visualizeと同様です。

5. explain

explainでは生成された可視化コードを説明してもらえます。

explanations = lida.explain(code=edited_charts[0].code, textgen_config=textgen_config)
for row in explanations[0]:
    print(row["section"]," ** ", row["explanation"])
  • code: 説明してほしいコードを渡します。

例えばcodeに次のようなコードを渡したとします。

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

def plot(data: pd.DataFrame):
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.histplot(data=data, x='MedInc', kde=True, ax=ax)
    plt.axvline(data['MedInc'].mean(), color='red', linestyle='--')
    plt.title('', wrap=True)
    return plt;

chart = plot(data)

すると、explainでは3つのセクションに分かれて説明が返ってきます。(※文章は日本語訳しています)

accessibility
プロットはヒストグラムであり、カーネル密度推定(KDE)ラインも含まれています。ヒストグラムの棒は青色で、KDEラインも青色です。赤い破線で縦線が描かれ、これは'MedInc'列の平均値を示しています。x軸は「中央値の所得(Median Income)」とラベル付けされ、y軸は「頻度(Frequency)」とラベル付けされています。プロットのタイトルは「Median Incomeの分布はどうなっているか」です。
transformation
'plot'という関数はpandasのDataFrameを入力として受け取ります。コード内に明示的なデータ変換はありません。しかし、seabornの'histplot'関数は暗黙のうちに'MedInc'データをビン分けしてヒストグラムを作成します。'kde=True'引数はさらにデータを変換し、カーネル密度推定を計算します。
visualization
コードはまず、図のサイズを10x6に設定します。次に、seabornの'histplot'関数を使用して'MedInc'データのヒストグラムを作成し、KDEラインを追加します。ヒストグラムとKDEラインの色は青に設定されています。x軸とy軸にはそれぞれ「Median Income」、「Frequency」とラベルが付けられています。プロットにはタイトルが追加されます。最後に、'MedInc'データの平均値に赤い破線の縦線が追加されます。

6. evaluate

evaluateを使うと、生成されたコードが品質や目標に対してどうなのかを評価できます。

evaluations = lida.evaluate(code=edited_charts[0].code,  goal=goals[0], library="seaborn", textgen_config=textgen_config)
for eval in evaluations[0]:
    print(eval["dimension"], "Score" ,eval["score"], "/ 10")
    print("\t", eval["rationale"][:200])
    print("\t**********************************")
  • code: visualizeで生成したコードを渡します。このコードを評価します。
  • goal: 評価目的となるgoalを渡します。
  • library: visualizeと同様です。

評価対象のコードは以下で、

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# solution plan
# i.  ..
def plot(data: pd.DataFrame):
    plt.figure(figsize=(10, 6))
    sns.histplot(data=data, x='MedInc', kde=True, color='blue')
    plt.xlabel('Median Income')
    plt.ylabel('Frequency')
    plt.title('What is the distribution of Median Income?')
    plt.axvline(x=data['MedInc'].mean(), color='red', linestyle='--')
    return plt

chart = plot(data)

Goalオブジェクトは次のようになっているとすると、

Goal 0
Question: What is the distribution of Median Income?
// Median Income(中央値の所得)の分布はどうなっている?
Visualization: histogram of Median Income
// Median Income(中央値の所得)のヒストグラム
Rationale: This tells us about the distribution of income in the dataset.
// これにより、データセット内の所得分布について知ることができます。

以下のような感じで評価が返ってきます。(※文章は日本語訳しています)

bugs Score 10 / 10
バグや構文エラーは見つかりませんでした。
transformation Score 10 / 10
この可視化にはデータ変換は必要ありません。
compliance Score 10 / 10
コードは、「Median Income(中央値の所得)」の分布を表示するという指定された可視化の目標を満たしています。
type Score 10 / 10
ヒストグラムは、Median Income(中央値の所得)のような連続変数の分布を示すのに適した可視化のタイプです。
encoding Score 10 / 10
データは適切にエンコードされており、x軸にはMedian Income(中央値の所得)、y軸にはFrequency(頻度)が表示されています。
aesthetics Score 10 / 10
可視化の美学(外観)は適切で、明確なタイトル、軸のラベル、そして平均的なMedian Income(中央値の所得)を示す赤い破線があります。

7. repair

※未検証です。

8. recommend

n_recommendations = 3

recommended_charts = lida.recommend(code=edited_charts[0].code, summary=summary, n=n_recommendations, library="seaborn", textgen_config=textgen_config)
plot_raster([recommended_charts[i].raster for i in range(n_recommendations)], figsize=(20, 20))
  • code: visualizeで生成したコードを渡します。このコードがSystemPrompsに例として入ります。
  • summary: summarizeの返り値を渡します。
  • n: 可視化コードをいくつ提案してもらうかを指定します。
  • library: visualizeと同様です。

これで冒頭の最後のグラフのようなものが表示されます。
返り値はvisualizeやeditと同様に扱えます。

まとめ

データセットを渡すだけでグラフを表示してくれるのは良さそうです!

ただ、ユーザーが直接介入できる部分がeditのみで、あとはLLMにおまかせという部分が「欲しいのはそういうグラフじゃないんだけど...」という印象を持ちました。
もしかすると私がこのライブラリの目的を勘違いしている可能性もあるので、この後もう少し踏み込んでみようと思います。

現時点での考えられる方法としては、visualizerecommendにはsummaryを渡しますので

summary = lida.summarize(data, summary_method="default", textgen_config=textgen_config)
summary["dataset_description"] += "I need to create a scatter plot with one appropriate independent variable and the dependent variable."

のように無理やりどんなグラフが欲しいのかを付け足すことで、欲しいグラフを出させることができました。

あとは費用についてですが、要約→目標設定→コード生成→実行→表示→編集→説明→評価→提案→表示という一連の流れを行った場合に gpt-3.5-turbo を使用して $0.012 でした。( 1~2円 くらい)
gpt-4 では 40~50円 くらいで、やはりすべての処理に使うにはもったいないなという感じです。
(今回試してみた感じだと表示されるグラフに大きな違いはありませんでした)

他にもブラウザUIやInfographic Generationというβ機能もあるみたいなので、検証次第書いてみたいと思います。
ローカルLLMでどこまで行けるのかも随時試してみます。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0