46
46

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

生成AIを活用したテキスト分類/名寄せのアイデア【Databricks】

Posted at

はじめに

Databricksに入社したskatoと申します!社員としては初投稿です。
今回は以前から温めていた課題「テキスト分類(名寄せ)作業、生成AIでなんとか楽にできないか?」を解決するため、Databricksを使っていろいろ検証した過程を記事にしました。

そもそもDatabricksとは何か?については、こちらの記事が詳しいのでぜひご覧いただければと思います!

背景・目的

アナリスト(分析担当者)が自社の商品データを集計してレポートを作成する際、商品名が適切に分類されていないことに悩むことが多いかと思います。
例えば、以下のようなケースが考えられます。

<商品データの集計でよくある課題>

# 課題 詳細
1 分類の不備 商品カテゴリが、集計したい単位で整備されていない 家電メーカー
商品の売上を「テレビ」「冷蔵庫」「洗濯機」...で分けて集計したいが、適切なカテゴリー分けがなされていない
2 正規化の欠如 商品名に、集計に不要なワードが含まれている 食品メーカー
「○○ポテトチップス 規格サイズ」と「○○ポテチ ファミリーパック」を、同一の「◯◯ポテトチップス」として扱いたい
3 名寄せの不備 同一商品だが、表記が別名になっている アパレルメーカー
「○○ブランド ロゴ入りパーカー」と「○○パーカー ロゴ付き」は同一商品であるが、同じものとして扱われていない

大抵の場合は人力で商品マスタをなんとか整備する(Excelで名称を修正していく/SQLでCASE文をひたすら書く...)ことが多いと思います。一方で、新しい商品が出た際のメンテナンスが大変だったり、他の人も似たような作業をやっていたりと、何かと時間コストがかかる作業かと思います。

こうした商品名の分類作業を生成AIでなんとか楽にできないか?を考えるため、Databricksで検証を行います。

要件

どのように生成AIを活用するべきか考えるため、ビジネス要件と技術要件を設定します。

業務要件

  • 自社ECを持つ家電量販店を想定する。特に、EC出品時に付与される商品カテゴリとは別軸で、売上集計のために新たな基準による商品分類を行いたいと仮定する
  • 上記を解決するため、商品名をインプットに、「エアコン・照明」「キッチン家電」...といった商品カテゴリを出力する分類器が必要である

技術要件

安定的に使うために下記2件を満たすことを目標とします。

  1. 統一した形式で返却される: 通常LLMの出力は非構造の文字列形式で出力されるが、分類を機械的に行えるよう、APIのように返却値をJSONで固定するように実装する
  2. 多くの人が同じモデルを参照できる: 今後もすべての担当者が同じロジックで分類を付与できることが望ましい。ビジネスアナリストやデータパイプラインにおいても気軽に使用できる SQL関数 として実装することを目指す

実装

実装手順について解説します。結果だけ知りたい方は0.とりあえず結果のみご参照ください!

0. とりあえず結果

SQLで関数として使用できるテキスト分類器prodname_classifier_udf()を実装しました。
Databricks SQL上で使用できます。モデルはDBRX(DatabricksのオープンソースLLMモデル)を使用しています。RAGやプリトレーニングは行っていません。

SQL
-- 使用例
WITH tar AS (
  SELECT * FROM household_appliance_ec LIMIT 20
)
SELECT
  product_name
  , prodname_classifier_udf(product_name, secret('skato-scope', 'token')) as res -- 生成AIによる分類器で商品名を分類
  , res.label as label -- 推定された商品名
  , res.confidence as confidence -- 推定値の確信度(どれだけ分類に自信があるか)
FROM
  tar

image.png

入力に対して、label(分類結果)とconfidence(分類に対する確信度)を出力します。なかなかの精度かと思います!

1. 準備

実際にSQL関数を作成します。この章では1-1のみ必須作業となります。

1-1. パラメータの設定

アクセストークン、カタログ、スキーマの指定を行います。
シークレットの作成方法の詳細はシークレット管理を参照しながら事前に登録しておきます。

python
# パラメータの設定
# 事前にDatabricksのシークレットのスコープにDatabricksアクセストークンを登録しておく
import os
os.environ['DATABRICKS_TOKEN']  = dbutils.secrets.get('skato-scope','token')
DATABRICKS_TOKEN = os.environ.get('DATABRICKS_TOKEN')

# モデル・関数を保存するカタログ名・スキーマ名を定義する
catalog_name = "skato" # 使用するカタログの名称を宣言
schema_name = "name_recognition" # 使用するスキーマの名称を宣言

1-2. 使用するLLMの確認

今回はベースとなるLLMにDBRX(Databricksが公開したオープンソースLLMモデル)を採用します。自身で使用したいモデルに差し替えることも可能です。
Preview機能ではありますが、DatabricksではDBRXを含むいくつかのサービング済のオープンソース基盤モデルを使用することが可能です。
下記コードの<my-workspace-domain>にはDatabricksのワークスペースのドメイン名を入力してください。

python
from openai import OpenAI

client = OpenAI(
  api_key=DATABRICKS_TOKEN,
  base_url="https://<my-workspace-domain>/serving-endpoints"
)

chat_completion = client.chat.completions.create(
  messages=[
    {
      "role": "system",
      "content": "あなたはAIアシスタントです。"
    },
    {
      "role": "user",
      "content": "大規模言語モデルについて日本語で解説してください。"
     },
  ],
  model="databricks-dbrx-instruct",
  max_tokens=256
)

print(chat_completion.choices[0].message.content)

# 出力例: 
# 大規模言語モデルとは、大量のテキストデータを学習した人工知能モデルです。このモデルは深層学習技術を用いて、言語のパターンや構造を理解し、生成することができます。大規模言語モデルは、様々な自然言語処理タスクに応用することができ、機械翻訳、文書要約、チャットボットなど、多くの分野で活用されています。

1-3. ダミーデータの確認

分類に使用するデータhousehold_appliance_ecをインポートします。今回は分類対象としてproduct_nameのみを参照します。

使用したデータは生成AIによって生成した架空のデータであり、価格・注文日等は実際の店舗のデータには基づいたものではありません

python
df = spark.read.table(f"{catalog_name}.{schema_name}.household_appliance_ec")
display(df)

image.png

2. モデル実装〜サービング

LLMモデルをDatabricks SQLで参照できるよう、いくつか過程を踏んで実装します。
今回はカスタムモデルをサービングした上で、APIでコールする処理をSQL Functionでラップする方式※をとりました。モデルサービングの概要についてはこちらのDocが詳しいです。

※関数のUDF化や、MLFlowから直接読み込んだモデルの参照等も検証しましたが、モデルのシリアライズエラーを解決できなかったためこちらを採用しています。

  • (2024/5時点)モデルサービングは一部クラウドサービス・リージョンでは未対応であり、以下で実装する一部の機能は特定リージョンで使用できない可能性があります。使用可能リージョンの最新情報はモデルサービングの制限と地域を参照してください。

2-1. ラベリングを行うLLMモデルの実装

後続処理でMLFlowにモデルを登録するため、Predictメソッドを持たせてLLMをラップしたクラスOriginDelegatingModelを実装しました。他、細かい点について解説します。

  • LLMのアウトプットをjsonフォーマットで固定するため、LangChainJsonOutputParserを採用しました。プロンプトエンジニアリングよりも出力が安定します。
  • Parserで指定する型は、別ファイルResponseFieldにクラスとして定義したものをimportします。後続のMLFlowへの記録処理でもファイルを指定することでシリアライズのエラーを回避します。
python
from langchain_openai import ChatOpenAI 
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from mlflow.pyfunc import PythonModel
# LLMのレスポンスの型を定義したクラスを別ファイルから読み込み
from response_field import ResponseField 

import json

class OriginDelegatingModel(PythonModel):
    """OriginDelegatingModel
    大規模言語モデル(LLM)を使用して入力テキストを分類するためのモデルを定義したクラス。
    指定されたラベルの候補から最も適切なラベルを選択し、入力テキストへの割り当てを行う。
    
    Attributes:
        labels (list): 分類対象のラベルのリスト
        features (str): 分類対象の入力テキストが含まれるDataFrameの列名
    """
    def __init__(self, labels, features):
        self.labels = labels
        self.features = features
        
    def llm_classifier(self, row):
      """
      入力テキストを大規模言語モデルを使用して分類する関数
        Args:
            row (pandas.Series): 分類対象の入力テキストが含まれる DataFrame の行
        Returns:
            res (str): 入力テキストに割り当てるラベル
      """
      DATABRICKS_TOKEN = os.environ.get('DATABRICKS_TOKEN')

      model = ChatOpenAI(model="databricks-dbrx-instruct"
                         , openai_api_base="https://<my-workspace-domain>/serving-endpoints"
                         , openai_api_key=DATABRICKS_TOKEN
                         )
      
      parser = JsonOutputParser(pydantic_object=ResponseField)

      prompt = PromptTemplate(
        template="""
        入力されたワードに対して、ラベル候補に基づいてラベリングを行ってください。
        - 分類対象のワード: {query}
        - 分類を行うラベルの候補: {labels}
        - レスポンスのフォーマット: {format_instructions}
        """,
        input_variables=["query"],
        partial_variables={"format_instructions": parser.get_format_instructions(), "labels": self.labels}
      )

      chain = prompt | model | parser
      res = chain.invoke({"query": row[self.features]})
      return res

    
    def predict(self, context, model_input):
        """
        入力DataFrame に対してllm_classifier関数を適用し、各行のテキストに対するラベルを予測する
        
        Args:
            context (mlflow.pyfunc.PythonModel): モデルの実行コンテキスト
            model_input (pandas.DataFrame): 分類対象の入力テキストが含まれるDataFrame
            
        Returns:
            pandas.Series: 各行のテキストに割り当てるラベルのSeries
        """
      return model_input.apply(self.llm_classifier, axis=1)
resonse_field.py
from langchain_core.pydantic_v1 import BaseModel, Field
# 別のpyファイルで定義する、LLM出力フォーマットを定義したクラス
class ResponseField(BaseModel):
    label: str = Field(description="分類結果")
    confidence: float = Field(description="labelの分類結果に対する確信度。0から1の連続値で表現する")

実装したモデルが正しく動作するか確認します。とりあえず行ごとにラベルを付与できることを確認しました。

python
ec_labels = ["エアコン・照明", "キッチン家電", "パソコン・周辺機器・PCソフト", "カメラ・カメラレンズ・メモリーカード", "テレビ・レコーダー", "オーディオ・電子ピアノ・カー用品", "スマートフォン", "電子辞書・電話機・FAX・事務機器", "その他"]

test_model = OriginDelegatingModel(labels=ec_labels, features=["product_name"])
res_df = test_model.predict(None, df.toPandas().head(10))

display(res_df)

image.png

2-2. MLFlowにモデルを記録・Unity Catalogに登録

Databricks上でモデルサービングを行う前工程として、MLFlowにモデルを登録します。

  • mlflow.start_run()の引数でregistered_model_nameを指定することで、Unity Catalogに登録する際のモデル名称と保存場所を決定します
  • 同じくcode_pathresponse_field.pyを指定することで、同ファイルをモデルのアーティファクトとしてMLFlowに記録することができます
python
# MLFlowにモデルを記録
from mlflow.types import ColSpec, Schema
from mlflow.types.schema import DataType
from mlflow.models.signature import ModelSignature
from mlflow.pyfunc import PythonModel
from mlflow.tracking import MlflowClient
import mlflow
mlflow.set_registry_uri("databricks-uc")

ec_labels = ["エアコン・照明", "キッチン家電", "パソコン・周辺機器・PCソフト", "カメラ・カメラレンズ・メモリーカード", "テレビ・レコーダー", "オーディオ・電子ピアノ・カー用品", "スマートフォン", "電子辞書・電話機・FAX・事務機器", "その他"]

model_name = f"{catalog_name}.{schema_name}.prodname_classifier"

prodname_classifier = OriginDelegatingModel(labels=ec_labels, features=["product_name"])

input_schema = Schema([ColSpec(DataType.string, "product_name")])
output_schema = Schema([ColSpec(DataType.string, "response")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

with mlflow.start_run():
  mlflow.pyfunc.log_model(
    artifact_path="prodname_classifier", 
    python_model=prodname_classifier,
    registered_model_name=model_name,
    signature=signature,
    code_paths=["response_field.py"]
    )

Databricksで実行すると、モデルがUnity Catalogに登録されます。

名称未設定.jpg

モデルが登録されたことを確認したら、モデルの最新バージョンにエイリアスchampionを付与します。

python
# 最新のバージョンのモデルのaliasをChampionに設定
def get_latest_model_version(model_name):
  client = MlflowClient()
  model_version_infos = client.search_model_versions("name = '%s'" % model_name)
  return max([int(model_version_info.version) for model_version_info in model_version_infos])

client = MlflowClient()
latest_version = get_latest_model_version(model_name)
client.set_registered_model_alias(model_name, "Champion", latest_version)

なお、この時点で実装したモデルをSpark UDFとして登録することで、ノートブック上であればSQLでもモデルを参照することが可能です。実装方法の詳細は下記記事を参照してください。

2-3. モデルサービング

こちらのDocに従い、MLFlowに登録したモデルをサーブするためのエンドポイントを作成します。サービングエンドポイントの名称はprodname_classifierとしました。

注意点として、サービングエンドポイントを作成する際に、詳細設定で環境変数としてAPIキーを登録する必要があります。下記記事を参照しながら環境変数を登録します。

  • キー:OPENAI_API_KEY
    • 注:名称は必ずOPENAI_API_KEYとして登録してください。
  • バリュー:{{secrets/skato/token}}
    • OPENAIのAPIキーを登録します。シークレットは{{secrets//}}の形式で登録する必要があります。今回使用しているモデルはDatabricks上でサービングされたDBRXのAPIなのでで、Databricksのアクセストークンが登録されたシークレットスコープとキーを指定します。

名称未設定.jpg

3. Databricks SQLのファンクション登録

2.までの実装で、Notebookからモデルを使用することが可能になりました。
ここからさらに、Databricks SQLでも使用できるようにします。サービングしたモデルをSQL Functionとして登録することでSQL Editorからモデルを参照できるようにします。

3-1. サービングしたモデルの検証

まずはサービングされたモデルの入出力を確認します。

python
import os
import requests
import numpy as np
import pandas as pd
import json

def create_tf_serving_json(data):
  return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}

def score_model(dataset):
  url = 'https://<my-workspace-domain>/serving-endpoints/prodname_classifier/invocations'
  headers = {'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}', 'Content-Type': 'application/json'}
  ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
  data_json = json.dumps(ds_dict, allow_nan=True)
  response = requests.request(method='POST', headers=headers, url=url, data=data_json)
  if response.status_code != 200:
    raise Exception(f'Request failed with status {response.status_code}, {response.text}')
  return response.json()
python
test_df = df.limit(3).toPandas()["product_name"]
score_model(test_df)["predictions"]

# 実行結果
# [{'0': {'label': 'テレビ・レコーダー', 'confidence': 0.95}},
#  {'0': {'label': 'キッチン家電', 'confidence': 0.9}},
#  {'0': {'label': 'キッチン家電', 'confidence': 0.9}}]

3-2. SQL Functionの登録

検証したコードを一部書き換えてSQL Functionを定義します。実装はSQLの文法で行いますが、部分的にPythonを使用することが可能です。
なお、モデルサービングAPIを参照するために必要なAPI KeyはSQL実行時に渡す設計としました。
詳細のSQL Functionの書き方はCREATE FUNCTION (SQL and Python)を参照しました。

sql
%sql 
-- FUNCTIONの作成(サービング済モデルのAPIコール処理をラップする)
CREATE OR REPLACE FUNCTION skato.name_recognition.prodname_classifier_udf(product_name STRING, secret STRING)
RETURNS STRUCT<label:STRING, confidence:DOUBLE> 
LANGUAGE PYTHON
AS $$
  import os
  import requests
  import numpy as np
  import pandas as pd
  import json

  def create_tf_serving_json(data):
    return {'inputs': {name: data[name].tolist() for name in data.keys()} if isinstance(data, dict) else data.tolist()}

  def score_model(dataset, secret):
    DATABRICKS_TOKEN = secret

    url = 'https://<my-workspace-domain>/prodname_classifier/invocations'
    headers = {'Authorization': f'Bearer {DATABRICKS_TOKEN}', 'Content-Type': 'application/json'}
    ds_dict = {'dataframe_split': dataset.to_dict(orient='split')} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
    data_json = json.dumps(ds_dict, allow_nan=True)
    response = requests.request(method='POST', headers=headers, url=url, data=data_json)
    if response.status_code == 200:
      return response.json()['predictions'][0]['0']
    else:
      return response

  input = pd.Series(product_name)
  res = score_model(input, secret)

  return res
$$

指定したカタログ・スキーマにファンクションが登録されていれば作業完了です!
あとはDatabricks SQLのSQLエディタで0.のコードを実行することで、ファンクションを使用することができます。

名称未設定.jpg

改善/発展のアクション

今回はテキスト分類を行うSQL Functionを実装しましたが、商品名の正規化や名寄せを行う際にはまた別の工夫が必要になると思います。いくつか改善/発展のアイデアを記載します。

1. 名寄せ/正規化のような高カーディナリティな分類タスクへの対応

今回は10種類程度のカテゴリ付与を想定ケースとして実装しました。一方で、正規化や名寄せを行う場合、ときには100種類以上にもなるような非常に多くのクラスへの分類を考慮しなければなりません。旧来の機械学習モデルはなかなか精度が出にくいタスクです。
LLMではどう対応すればよいでしょうか?例えばプロンプトエンジニアリングの領域なら、Few Shotプロンプティングで名寄せの例示をプロンプトに組み込むと良さそうです。また、分類種類が数千・数万件にも及ぶようなら、Retrieval Augmented Generationを導入することも効果的かもしれません。

2. モデルの汎化

今回、ラベルの候補(エアコン、テレビ、...)はモデルクラスをインスタンスする時に指定するよう設計しました。ラベル候補を入力をより後の工程、例えばSQL実行時に指定するように設計することで、よりユーザーが自由に分類を行えるような関数を作ることができると思います。
一方で、複数人で使用する場合は、人によって分類候補が異なってしまうようなことを避けるようなガバナンスの仕組みを作る必要がありそうです。

3. リクエストの時間/コストの短縮

今回作ったファンクションは、レコードが100件あれば生成モデルも100回コールする仕組みになっています。現在100件の処理に1-2分程かかっているため、数千件以上のデータを処理しようとするとレイテンシーの課題が発生し得ます。
例えば1回のコールで複数件のレコードをまとめてラベリングする作りにすることで、レイテンシーとコストを改善できる可能性があります。一方で、SQL Functionで参照する際は入力時にデータのもたせ方に工夫が必要だったり、あまりに多くのレコードを一度に処理しようとすると精度が出ない/エラーになってしまったり等、動作が不安定になる可能性も考慮する必要があります。

おまけ. AI Functions

日本のリージョンではまだ使用できませんが、Databricks SQLで使用できるDatabricks AI Functionsがプレビュー機能として公開されています。

AI Functionsは特定のユースケース(分類、要約、...)SQL関数です。たとえばai_classify()を使用すると、特にモデルをカスタマイズせずにSQLでテキスト分類をすることが可能です!! 最初からこれでも良かったのでは

sql
WITH tar AS (
  SELECT * FROM household_appliance_ec LIMIT 20
)
SELECT
  product_name
  -- ai_classify(): 引数1に列名、引数2にラベリングする候補の配列を指定し、テキスト分類を行う
  , ai_classify(product_name, array("エアコン・照明", "キッチン家電", "パソコン・周辺機器・PCソフト", "カメラ・カメラレンズ・メモリーカード", "テレビ・レコーダー", "オーディオ・電子ピアノ・カー用品", "スマートフォン", "電子辞書・電話機・FAX・事務機器", "その他")) as res
FROM
  tar

image.png

LLMを気楽にSQLでも使いたい場合はAI Function、モデルを自由に選択したい、出力をカスタマイズしたい場合については上記記載した方法で実際する等、状況に合わせて選択するのが良さそうです。

おわりに

なんとかSQLでLLMを参照できないかといろいろ試行錯誤しました。手前味噌ですが、やはりモデル実装、サービング、SQL実行のすべてが1プラットフォームで完結して実装できるのはかなり便利だと思います。

ご質問やご指摘があればコメントお願いします!

46
46
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
46
46

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?