0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

DatabricksのFree Editionで、Databricks Appsを使って分析手法提案アプリを作ってみた

Posted at

DatabricksのFree Editionが公開されたので早速アカウントを作って戯れているのですが、まだ東京リージョンで使えなかったDatabricks Apps(今見たら使えるようになってる・・・)が使えたり簡単なChatモデルなら普通に呼び出せたりとても嬉しいです。
今後も最新機能が東京リージョンより早く試せるかもしれません。

無料試用版というと機能制限がかなり強いイメージがあったのですが、DatabricksのFree Editionはリソース上限が1日単位/アカウント全体ともにかなりシビアなものの機能自体の制限はほぼ無いのですごいものが出てきたなという感じです。

という事で今回は、Databricks AppsからUnity CatalogとMosaic AI Model Servingを叩きに行くアプリを作ってみたくて、「入力内容に応じた分析手法を提案してくれるアプリ」を作ってみました。
分析手法の提案はLLMに丸投げですが、Databricks AppsでDBとモデルを触りに行く基礎的な部分は参考になるのではないかと思います。

Databricks Appsのいいところは、「ヘッダーからアクセストークンを取得できること」で、これは「アプリを開いているユーザーの権限でテーブルにアクセスできること」を意味します。
このあたりの制御はModel Servingだと結構大変なので、Databricks AppsとModel Servingを併用する場合にはユーザーの権限を使いたい処理はApps側に寄せると良さそうです。

実装したコード全体

ひとまず実装したコード全体を置いておきます。
UI上から「アプリを作成」→「テンプレートからインストール」でstreamlitのData appを選択して出てきたものを魔改造して作成しています。

以下順に説明していきます。
特に説明が無い場所は「Data appまたはChatbotを選択してデフォルトでできるテンプレートのコード」か「処理の流れにそこまで関係ない部分」です。(怠惰)

app.py
import os
from databricks import sql
from databricks.sdk.core import Config
import streamlit as st
import pandas as pd
import logging
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import json
import requests
import time
from mlflow.deployments import get_deploy_client

def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
    """Calls a model serving endpoint."""
    res = get_deploy_client('databricks').predict(
        endpoint=endpoint_name,
        inputs={'messages': messages, "max_tokens": max_tokens},
    )
    if "messages" in res:
        return res["messages"]
    elif "choices" in res:
        return [res["choices"][0]["message"]]
    raise Exception("This app can only run against:"
                    "1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
                    "2) Databricks agent serving endpoints that implement the conversational agent schema documented "
                    "in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")

def query_endpoint(endpoint_name, messages, max_tokens):
    """
    Query a chat-completions or agent serving endpoint
    If querying an agent serving endpoint that returns multiple messages, this method
    returns the last message
    ."""
    return _query_endpoint(endpoint_name, messages, max_tokens)[-1]

# ログの設定
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_user_info():
    headers = st.context.headers
    return dict(
        user_name=headers.get("X-Forwarded-Preferred-Username"),
        user_email=headers.get("X-Forwarded-Email"),
        user_id=headers.get("X-Forwarded-User"),
    )

user_info = get_user_info()

# Databricks config
cfg = Config()

# Query the SQL warehouse with the user credentials
def sql_query_with_user_token(query: str, user_token: str) -> pd.DataFrame:
    """Execute a SQL query and return the result as a pandas DataFrame."""
    with sql.connect(
        server_hostname=cfg.host,
        http_path=f"/sql/1.0/warehouses/{cfg.warehouse_id}",
        access_token=user_token  # Pass the user token into the SQL connect to query on behalf of user
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute(query)
            return cursor.fetchall_arrow().to_pandas()

st.set_page_config(layout="wide")

user_token = st.context.headers.get('X-Forwarded-Access-Token')

# Streamlit app
if "visibility" not in st.session_state:
    st.session_state.visibility = "visible"
    st.session_state.disabled = False

st.title("分析手法提案アシスタント")
st.markdown(
    "入力した内容に関連するテーブル、列、メトリックビューを検索し、その結果を用いて分析手法を提案します。"
)

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    if message["role"] != "system":
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

DATABRICKS_HOST = os.getenv("DATABRICKS_HOST")
DATABRICKS_INSTANCE = f"https://{DATABRICKS_HOST}"

def safe_encode(text):
    if text is None or pd.isna(text):
        return None
    return model.encode(text)

def safe_cosine_similarity(vec1, vec2):
    if vec1 is None or vec2 is None:
        return 0
    return float(cosine_similarity([vec1], [vec2])[0][0])

# 軽量な埋め込みモデルをロード(初回少し時間がかかる)
with st.spinner("類似検索モデルを読み込み中"):
    model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

@st.cache_data(show_spinner=False)
def embed_table():
    with st.spinner("テーブル一覧の読み込み中"):
        query = f"""
        SELECT table_catalog, table_schema, table_name, comment
        FROM system.information_schema.tables
        WHERE table_catalog != "system" AND table_schema != "information_schema"
        """
        tables_pdf = sql_query_with_user_token(query, user_token=user_token)
    # テーブル名、テーブルコメントのベクトル化
    with st.spinner("類似検索モデルをテーブル一覧に適用中"):
        tables_pdf["embedding_table_name"] = tables_pdf["table_name"].apply(safe_encode)
        tables_pdf["embedding_comment"] = tables_pdf["comment"].apply(safe_encode)
    return tables_pdf

tables_pdf = embed_table()

@st.cache_data(show_spinner=False)
def embed_column():
    with st.spinner("列一覧の読み込み中"):
        query = f"""
        SELECT table_catalog, table_schema, table_name, column_name, comment
        FROM system.information_schema.columns
        WHERE table_catalog != "system" AND table_schema != "information_schema"
        """
        columns_pdf = sql_query_with_user_token(query, user_token=user_token)
    with st.spinner("類似検索モデルを列一覧に適用中"):
        columns_pdf["embedding_column_name"] = columns_pdf["column_name"].apply(safe_encode)
        columns_pdf["embedding_comment"] = columns_pdf["comment"].apply(safe_encode)
    return columns_pdf

columns_pdf = embed_column()

PAT_for_API = "XXXXXX"

@st.cache_data(show_spinner=False)
def embed_metric_views():
    with st.spinner("メトリックビュー一覧の読み込み中"):
        # METRIC VIEWのみ抽出
        query = f"""
        SELECT table_catalog, table_schema, table_name
        FROM system.information_schema.tables
        WHERE table_type = "METRIC_VIEW"
        """
        metric_views_list = sql_query_with_user_token(query, user_token=user_token)

        # METRIC VIEWの中身を取得
        metric_views_pdf = pd.DataFrame(columns=["source", "name", "data_type", "metric_type", "metric_expr"])

        for index, row in metric_views_list.iterrows():
            table_catalog = row['table_catalog']
            table_schema = row['table_schema']
            table_name = row['table_name']
            metric_view = f"{table_catalog}.{table_schema}.{table_name}"
            url = f"{DATABRICKS_INSTANCE}/api/2.1/unity-catalog/tables/{metric_view}"
            response = requests.get(url, headers={"Authorization": f"Bearer {PAT_for_API}"})

            source = response.json().get("view_dependencies").get("dependencies")[0].get("table").get("table_full_name")
            rows = []
            for column in response.json().get("columns"):
                rows.append({
                    "source": source,
                    "name": column.get("name"),
                    "data_type": column.get("type_text"),
                    "metric_type": json.loads(column.get("type_json")).get("metadata").get("metric_view.type"),
                    "metric_expr": json.loads(column.get("type_json")).get("metadata").get("metric_view.expr")
                })
            if rows:
                metric_views_pdf = pd.concat([metric_views_pdf, pd.DataFrame(rows)], ignore_index=True)
        metric_views_pdf["table_catalog"] = metric_views_pdf["source"].str.split(".").str[0]
        metric_views_pdf["table_schema"] = metric_views_pdf["source"].str.split(".").str[1]
        metric_views_pdf["table_name"] = metric_views_pdf["source"].str.split(".").str[2]
    with st.spinner("類似検索モデルをメトリックビュー一覧に適用中"):
        metric_views_pdf["embedding_name"] = metric_views_pdf["name"].apply(safe_encode)
        metric_views_pdf["embedding_metric_expr"] = metric_views_pdf["metric_expr"].apply(safe_encode)
    return metric_views_pdf

metric_views_pdf = embed_metric_views()

if "initialize" not in st.session_state:
    st.success("検索を行う準備が完了しました。")
    st.session_state.initialize = "initialized"

def generate_similar_tables_json(input_tables_pdf):
    input_tables_pdf_subset = input_tables_pdf[["table_catalog", "table_schema", "table_name"]].drop_duplicates(subset=["table_catalog", "table_schema", "table_name"])
    merged_tables_pdf = tables_pdf.merge(input_tables_pdf_subset, on=["table_catalog", "table_schema", "table_name"], how="inner")
    similar_tables = []
    for _, table_row in merged_tables_pdf.iterrows():
        table_catalog = table_row["table_catalog"]
        table_schema = table_row["table_schema"]
        table_name = table_row["table_name"]
        table_comment = table_row["comment"]
        table_similarity = table_row["similarity"]
        columns = columns_pdf[(columns_pdf["table_catalog"] == table_catalog) & 
                              (columns_pdf["table_schema"] == table_schema) & 
                              (columns_pdf["table_name"] == table_name)][["column_name", "comment", "similarity"]]
        columns_list = [
            {
                "column_name": col["column_name"],
                "column_comment": col["comment"],
                "similarity": col["similarity"]
            }
            for _, col in columns.iterrows()
        ]
        metrics_rows = metric_views_pdf[
            (metric_views_pdf["table_catalog"] == table_catalog) &
            (metric_views_pdf["table_schema"] == table_schema) &
            (metric_views_pdf["table_name"] == table_name)
        ]
        if not metrics_rows.empty:
            metrics_list = [
                {
                    "metric_name": col["name"],
                    "data_type": col["data_type"],
                    "metric_type": col["metric_type"],
                    "metric_expr": col["metric_expr"],
                    "similarity": col["similarity"]
                }
                for _, col in metrics_rows.iterrows()
            ]
        else:
            metrics_list = []
        similar_tables.append({
            "table_name": f"{table_catalog}.{table_schema}.{table_name}",
            "table_comment": table_comment,
            "similarity": table_similarity,
            "columns": columns_list,
            "metrics": metrics_list
        })

    similar_tables_json = json.dumps(similar_tables, ensure_ascii=False, indent=2)
    return similar_tables_json

keyword = st.chat_input("分析したい内容を入力して下さい。")
if keyword:
    with st.spinner("類似度算出中"):
        keyword_embedding = model.encode(keyword)
        tables_pdf["similarity_table_name"] = tables_pdf["embedding_table_name"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        tables_pdf["similarity_comment"] = tables_pdf["embedding_comment"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        tables_pdf["similarity"] = tables_pdf[["similarity_table_name", "similarity_comment"]].max(axis=1)
        columns_pdf["similarity_column_name"] = columns_pdf["embedding_column_name"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        columns_pdf["similarity_comment"] = columns_pdf["embedding_comment"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        columns_pdf["similarity"] = columns_pdf[["similarity_column_name", "similarity_comment"]].max(axis=1)
        metric_views_pdf["similarity_name"] = metric_views_pdf["embedding_name"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        metric_views_pdf["similarity_metric_expr"] = metric_views_pdf["embedding_metric_expr"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        metric_views_pdf["similarity"] = metric_views_pdf[["similarity_name", "similarity_metric_expr"]].max(axis=1)

    # 1. similarity >= 0.7 を抽出
    filtered_tables_pdf = tables_pdf[tables_pdf["similarity"] >= 0.7]
    # 2. 5件に満たなければ上位5件を取得
    if len(filtered_tables_pdf) < 5:
        filtered_tables_pdf = tables_pdf.sort_values(by="similarity", ascending=False).head(5)

    filtered_tables_pdf = filtered_tables_pdf.drop("embedding_table_name", axis=1).drop("embedding_comment", axis=1).drop("similarity_table_name", axis=1).drop("similarity_comment", axis=1).sort_values("similarity", ascending=False)

    # 1. similarity >= 0.7 を抽出
    filtered_columns_pdf = columns_pdf[columns_pdf["similarity"] >= 0.7]
    # 2. 5件に満たなければ上位5件を取得
    if len(filtered_columns_pdf) < 5:
        filtered_columns_pdf = columns_pdf.sort_values(by="similarity", ascending=False).head(5)
    filtered_columns_pdf = filtered_columns_pdf.drop("embedding_column_name", axis=1).drop("embedding_comment", axis=1).drop("similarity_column_name", axis=1).drop("similarity_comment", axis=1).sort_values("similarity", ascending=False)

    # 1. similarity >= 0.7 を抽出
    filtered_metric_views_pdf = metric_views_pdf[metric_views_pdf["similarity"] >= 0.7]
    # 2. 5件に満たなければ上位5件を取得
    if len(filtered_metric_views_pdf) < 5:
        filtered_metric_views_pdf = metric_views_pdf.sort_values(by="similarity", ascending=False).head(5)
    filtered_metric_views_pdf = filtered_metric_views_pdf.drop("embedding_name", axis=1).drop("embedding_metric_expr", axis=1).drop("similarity_name", axis=1).drop("similarity_metric_expr", axis=1).sort_values("similarity", ascending=False)
    message = f"""
        "{keyword}"について分析したい。
    """
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": message})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(message)

    with st.spinner("回答生成中"):
        similar_tables_json = generate_similar_tables_json(filtered_tables_pdf)
        similar_columns_json = generate_similar_tables_json(filtered_columns_pdf)
        similar_metric_views_json = generate_similar_tables_json(filtered_metric_views_pdf)
        
        message = f"""
            ユーザーが右記のキーワードについて分析したがっています:{keyword}
            分析手法を提案して下さい。

            事前にテーブル、テーブルコメント、列、列コメント、メトリクスビュー定義について"{keyword}"でベクトル検索をしており、以下が検索結果です。
            それぞれ類似度0.7以上のものを抽出していますが、もし0.7以上だけで5件に満たなければ上位5件を抽出しています。そのため類似度が低いものも混ざり得ます。
            テーブル名、テーブルコメントが類似していた:
            {similar_tables_json}
            列名、列コメントが類似していた:
            {similar_columns_json}
            メトリクスビューのメジャー名、関数定義が類似していた:
            {similar_metric_views_json}

            類似度が低い場合、無理して提案に含める必要はありませんが、含めなかった場合には回答の冒頭でその理由を明示するようにして下さい。
        """

        st.session_state.messages.append({"role": "system", "content": message})

        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            # Query the Databricks serving endpoint
            assistant_response = query_endpoint(
                endpoint_name="databricks-llama-4-maverick",
                messages=st.session_state.messages,
                max_tokens=4000,
            )["content"]
            st.markdown(assistant_response)

        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": assistant_response})
requrements.txt
sentence-transformers
scikit-learn
tabulate

Model Servingへのアクセス関数定義

ここは「アプリを作成」→「テンプレートからインストール」でstreamlitのChatbotを選択した際に出てくる「model_serving_utils.py」の内容そのままを入れています。

app.py
def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
    """Calls a model serving endpoint."""
    res = get_deploy_client('databricks').predict(
        endpoint=endpoint_name,
        inputs={'messages': messages, "max_tokens": max_tokens},
    )
    if "messages" in res:
        return res["messages"]
    elif "choices" in res:
        return [res["choices"][0]["message"]]
    raise Exception("This app can only run against:"
                    "1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
                    "2) Databricks agent serving endpoints that implement the conversational agent schema documented "
                    "in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")

def query_endpoint(endpoint_name, messages, max_tokens):
    """
    Query a chat-completions or agent serving endpoint
    If querying an agent serving endpoint that returns multiple messages, this method
    returns the last message
    ."""
    return _query_endpoint(endpoint_name, messages, max_tokens)[-1]

エンコード用の関数がnullで落ちないようラッピング

後続でベクトル化をするのですが、null値だと落ちるのでnullの場合の挙動を定義した関数を準備しておきます。

app.py
def safe_encode(text):
    if text is None or pd.isna(text):
        return None
    return model.encode(text)

def safe_cosine_similarity(vec1, vec2):
    if vec1 is None or vec2 is None:
        return 0
    return float(cosine_similarity([vec1], [vec2])[0][0])

モデル読み込み

ベクトル化モデルを読み込みます。
今回は多言語に強くて軽いらしい「paraphrase-multilingual-MiniLM-L12-v2」を使ってみました。

app.py
# 軽量な埋め込みモデルをロード(初回少し時間がかかる)
with st.spinner("類似検索モデルを読み込み中"):
    model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')

アクセスできるテーブルおよび列の一覧を取得し、ベクトル化

system.information_schema.tablesおよびcolumnsを用いてテーブル名、テーブルコメント、列名、列コメントを取得しベクトル化します。
Databricks Appsを用いる利点として、冒頭でも述べましたが「ヘッダーからアプリ利用ユーザーのアクセストークンが取得できる」というものがあり、これにより「利用ユーザーがアクセスできるテーブルの情報」だけを簡単に持ってくることができます。これは助かる。

@st.cache_data(show_spinner=False)

によってキャッシュ化することで、チャットによるページ全体の読み込み直しが発生してもベクトル化処理は1回で済むようにしています。

app.py
@st.cache_data(show_spinner=False)
def embed_table():
    with st.spinner("テーブル一覧の読み込み中"):
        query = f"""
        SELECT table_catalog, table_schema, table_name, comment
        FROM system.information_schema.tables
        WHERE table_catalog != "system" AND table_schema != "information_schema"
        """
        tables_pdf = sql_query_with_user_token(query, user_token=user_token)
    # テーブル名、テーブルコメントのベクトル化
    with st.spinner("類似検索モデルをテーブル一覧に適用中"):
        tables_pdf["embedding_table_name"] = tables_pdf["table_name"].apply(safe_encode)
        tables_pdf["embedding_comment"] = tables_pdf["comment"].apply(safe_encode)
    return tables_pdf

tables_pdf = embed_table()

@st.cache_data(show_spinner=False)
def embed_column():
    with st.spinner("列一覧の読み込み中"):
        query = f"""
        SELECT table_catalog, table_schema, table_name, column_name, comment
        FROM system.information_schema.columns
        WHERE table_catalog != "system" AND table_schema != "information_schema"
        """
        columns_pdf = sql_query_with_user_token(query, user_token=user_token)
    with st.spinner("類似検索モデルを列一覧に適用中"):
        columns_pdf["embedding_column_name"] = columns_pdf["column_name"].apply(safe_encode)
        columns_pdf["embedding_comment"] = columns_pdf["comment"].apply(safe_encode)
    return columns_pdf

columns_pdf = embed_column()

Metric Viewのベクトル化

Databricksにおけるセマンティックレイヤー機能、Metric Viewについても取得して、メトリクス名と関数定義をベクトル化します。
Metric Viewの定義はinformation_schemaからは拾えなそうだったので、tables API経由で取得します。
Databricks Appsのヘッダーから取得できるアクセストークンは、設定から作成されるアクセストークンと違って権限が超限定されているため、API叩くのには使えないのでここはAPIを叩けるユーザーやサービスプリンシパルのトークンを発行して使います。
本当はsecretsを使うべきですが、横着して手書きです。

app.py
PAT_for_API = "XXXXXX" # ここは取得したトークンで埋める

@st.cache_data(show_spinner=False)
def embed_metric_views():
    with st.spinner("メトリックビュー一覧の読み込み中"):
        # METRIC VIEWのみ抽出
        query = f"""
        SELECT table_catalog, table_schema, table_name
        FROM system.information_schema.tables
        WHERE table_type = "METRIC_VIEW"
        """
        metric_views_list = sql_query_with_user_token(query, user_token=user_token)

        # METRIC VIEWの中身を取得
        metric_views_pdf = pd.DataFrame(columns=["source", "name", "data_type", "metric_type", "metric_expr"])

        for index, row in metric_views_list.iterrows():
            table_catalog = row['table_catalog']
            table_schema = row['table_schema']
            table_name = row['table_name']
            metric_view = f"{table_catalog}.{table_schema}.{table_name}"
            url = f"{DATABRICKS_INSTANCE}/api/2.1/unity-catalog/tables/{metric_view}"
            response = requests.get(url, headers={"Authorization": f"Bearer {PAT_for_API}"})

            source = response.json().get("view_dependencies").get("dependencies")[0].get("table").get("table_full_name")
            rows = []
            for column in response.json().get("columns"):
                rows.append({
                    "source": source,
                    "name": column.get("name"),
                    "data_type": column.get("type_text"),
                    "metric_type": json.loads(column.get("type_json")).get("metadata").get("metric_view.type"),
                    "metric_expr": json.loads(column.get("type_json")).get("metadata").get("metric_view.expr")
                })
            if rows:
                metric_views_pdf = pd.concat([metric_views_pdf, pd.DataFrame(rows)], ignore_index=True)
        metric_views_pdf["table_catalog"] = metric_views_pdf["source"].str.split(".").str[0]
        metric_views_pdf["table_schema"] = metric_views_pdf["source"].str.split(".").str[1]
        metric_views_pdf["table_name"] = metric_views_pdf["source"].str.split(".").str[2]
    with st.spinner("類似検索モデルをメトリックビュー一覧に適用中"):
        metric_views_pdf["embedding_name"] = metric_views_pdf["name"].apply(safe_encode)
        metric_views_pdf["embedding_metric_expr"] = metric_views_pdf["metric_expr"].apply(safe_encode)
    return metric_views_pdf

metric_views_pdf = embed_metric_views()

入力となるテーブルに近似値を付与し、jsonで返却する関数の定義

input_tables_pdf(pandas dataframe)のテーブルに対し、存在する列と定義されているメトリクス、および計算された類似度を付与します。
この関数が呼ばれる時点で、tables_pdfにテーブルと入力ワードの類似度が、schemas_pdfに列と入力ワードの類似度が、metric_views_pdfにメトリクスと入力ワードの類似度が既に計算されて入っています。

なぜjsonかというと、基盤モデルに投げるのにdataframeより都合が良いと聞いたからです。

app.py

def generate_similar_tables_json(input_tables_pdf):
    input_tables_pdf_subset = input_tables_pdf[["table_catalog", "table_schema", "table_name"]].drop_duplicates(subset=["table_catalog", "table_schema", "table_name"])
    merged_tables_pdf = tables_pdf.merge(input_tables_pdf_subset, on=["table_catalog", "table_schema", "table_name"], how="inner")
    similar_tables = []
    for _, table_row in merged_tables_pdf.iterrows():
        table_catalog = table_row["table_catalog"]
        table_schema = table_row["table_schema"]
        table_name = table_row["table_name"]
        table_comment = table_row["comment"]
        table_similarity = table_row["similarity"]
        columns = columns_pdf[(columns_pdf["table_catalog"] == table_catalog) & 
                              (columns_pdf["table_schema"] == table_schema) & 
                              (columns_pdf["table_name"] == table_name)][["column_name", "comment", "similarity"]]
        columns_list = [
            {
                "column_name": col["column_name"],
                "column_comment": col["comment"],
                "similarity": col["similarity"]
            }
            for _, col in columns.iterrows()
        ]
        metrics_rows = metric_views_pdf[
            (metric_views_pdf["table_catalog"] == table_catalog) &
            (metric_views_pdf["table_schema"] == table_schema) &
            (metric_views_pdf["table_name"] == table_name)
        ]
        if not metrics_rows.empty:
            metrics_list = [
                {
                    "metric_name": col["name"],
                    "data_type": col["data_type"],
                    "metric_type": col["metric_type"],
                    "metric_expr": col["metric_expr"],
                    "similarity": col["similarity"]
                }
                for _, col in metrics_rows.iterrows()
            ]
        else:
            metrics_list = []
        similar_tables.append({
            "table_name": f"{table_catalog}.{table_schema}.{table_name}",
            "table_comment": table_comment,
            "similarity": table_similarity,
            "columns": columns_list,
            "metrics": metrics_list
        })

    similar_tables_json = json.dumps(similar_tables, ensure_ascii=False, indent=2)
    return similar_tables_json

入力を受け付けて類似度を算出

ここまでで準備が整ったので、入力を受け付けます。
入力をベクトル化し、テーブル名orテーブルコメント、列名or列コメント、メトリクスorメトリクス定義それぞれの類似度を算出、0.7以上のもの全てか、上位5券を抽出します。

app.py
keyword = st.chat_input("分析したい内容を入力して下さい。")
if keyword:
    with st.spinner("類似度算出中"):
        keyword_embedding = model.encode(keyword)
        tables_pdf["similarity_table_name"] = tables_pdf["embedding_table_name"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        tables_pdf["similarity_comment"] = tables_pdf["embedding_comment"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        tables_pdf["similarity"] = tables_pdf[["similarity_table_name", "similarity_comment"]].max(axis=1)
        columns_pdf["similarity_column_name"] = columns_pdf["embedding_column_name"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        columns_pdf["similarity_comment"] = columns_pdf["embedding_comment"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        columns_pdf["similarity"] = columns_pdf[["similarity_column_name", "similarity_comment"]].max(axis=1)
        metric_views_pdf["similarity_name"] = metric_views_pdf["embedding_name"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        metric_views_pdf["similarity_metric_expr"] = metric_views_pdf["embedding_metric_expr"].apply(lambda x: safe_cosine_similarity(x, keyword_embedding))
        metric_views_pdf["similarity"] = metric_views_pdf[["similarity_name", "similarity_metric_expr"]].max(axis=1)

    # 1. similarity >= 0.7 を抽出
    filtered_tables_pdf = tables_pdf[tables_pdf["similarity"] >= 0.7]
    # 2. 5件に満たなければ上位5件を取得
    if len(filtered_tables_pdf) < 5:
        filtered_tables_pdf = tables_pdf.sort_values(by="similarity", ascending=False).head(5)

    filtered_tables_pdf = filtered_tables_pdf.drop("embedding_table_name", axis=1).drop("embedding_comment", axis=1).drop("similarity_table_name", axis=1).drop("similarity_comment", axis=1).sort_values("similarity", ascending=False)

    # 1. similarity >= 0.7 を抽出
    filtered_columns_pdf = columns_pdf[columns_pdf["similarity"] >= 0.7]
    # 2. 5件に満たなければ上位5件を取得
    if len(filtered_columns_pdf) < 5:
        filtered_columns_pdf = columns_pdf.sort_values(by="similarity", ascending=False).head(5)
    filtered_columns_pdf = filtered_columns_pdf.drop("embedding_column_name", axis=1).drop("embedding_comment", axis=1).drop("similarity_column_name", axis=1).drop("similarity_comment", axis=1).sort_values("similarity", ascending=False)

    # 1. similarity >= 0.7 を抽出
    filtered_metric_views_pdf = metric_views_pdf[metric_views_pdf["similarity"] >= 0.7]
    # 2. 5件に満たなければ上位5件を取得
    if len(filtered_metric_views_pdf) < 5:
        filtered_metric_views_pdf = metric_views_pdf.sort_values(by="similarity", ascending=False).head(5)
    filtered_metric_views_pdf = filtered_metric_views_pdf.drop("embedding_name", axis=1).drop("embedding_metric_expr", axis=1).drop("similarity_name", axis=1).drop("similarity_metric_expr", axis=1).sort_values("similarity", ascending=False)

類似度抽出結果をModel Servingの基盤モデルに投げる

ここまではUnityCatalogの世界でしたが、ここからはUnityCatalogから取ってきた情報をLLMに投げます。
もっとチューニングの余地はあると思いますが、今回はLLMに丸投げです。

app.py
    message = f"""
        "{keyword}"について分析したい。
    """
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": message})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(message)

    with st.spinner("回答生成中"):
        similar_tables_json = generate_similar_tables_json(filtered_tables_pdf)
        similar_columns_json = generate_similar_tables_json(filtered_columns_pdf)
        similar_metric_views_json = generate_similar_tables_json(filtered_metric_views_pdf)
        
        message = f"""
            ユーザーが右記のキーワードについて分析したがっています:{keyword}
            分析手法を提案して下さい。

            事前にテーブル、テーブルコメント、列、列コメント、メトリクスビュー定義について"{keyword}"でベクトル検索をしており、以下が検索結果です。
            それぞれ類似度0.7以上のものを抽出していますが、もし0.7以上だけで5件に満たなければ上位5件を抽出しています。そのため類似度が低いものも混ざり得ます。
            テーブル名、テーブルコメントが類似していた:
            {similar_tables_json}
            列名、列コメントが類似していた:
            {similar_columns_json}
            メトリクスビューのメジャー名、関数定義が類似していた:
            {similar_metric_views_json}

            類似度が低い場合、無理して提案に含める必要はありませんが、含めなかった場合には回答の冒頭でその理由を明示するようにして下さい。
        """

        st.session_state.messages.append({"role": "system", "content": message})

        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            # Query the Databricks serving endpoint
            assistant_response = query_endpoint(
                endpoint_name="databricks-llama-4-maverick",
                messages=st.session_state.messages,
                max_tokens=4000,
            )["content"]
            st.markdown(assistant_response)

        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": assistant_response})

実際に動かしてみる。

こんな感じの画面になります。
ロクなデータも無いのでこんな感じですが、テーブルコメントや列コメントといった論理名の整備、Metric Viewの整備をしていけばある程度は返してくれるうようになりそうです。

image.png

感想

アプリはほぼ初心者レベルでしたが、裏側の処理も書きやすく可能性を感じました。(小並感)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?