0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Databricksモデルサービングエンドポイントへの画像分類モデルのデプロイ

Posted at

こちらのノートブックを動かしてみました。

実際に動かしたら、モデルサービングエンドポイントとの連携部分でエラーになったので修正しています。

# 必要なPythonパッケージをインストールします。
%pip install transformers torch

# 上記パッケージを有効にするためにPythonプロセスを再起動します。
dbutils.library.restartPython()
import io, base64, torch, mlflow.pyfunc
import pandas as pd
from mlflow.types.schema import ColSpec
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification

class MicrosoftResnetFifty(mlflow.pyfunc.PythonModel):
   # モデルの初期化。画像プロセッサと分類モデルを設定します。
   def __init__(self):
       super(MicrosoftResnetFifty, self).__init__()
       self.processor = AutoImageProcessor.from_pretrained('microsoft/resnet-50')
       self.model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50')

   # 推論用に画像を準備します。既知の画像を透過なしPNGに変換します。
   def prepare_image_for_inference(self, image_data):
      # Base64文字列の場合はデコード
      if isinstance(image_data, str):
        image_bytes = base64.b64decode(image_data)
      else:
        image_bytes = image_data

      img = Image.open(io.BytesIO(image_bytes))
      buf = io.BytesIO()
      img.convert("RGB").save(buf, format='PNG')
      img_bytes = buf.getvalue()
      image = Image.open(io.BytesIO(img_bytes))
      return image

   # 画像を前処理し、推論を実行して予測ラベルを返します。
   def classify(self, image):
       inputs = self.processor(image, return_tensors="pt")
       with torch.no_grad():
           logits = self.model(**inputs).logits
       predicted_label = logits.argmax(-1).item()
       return self.model.config.id2label[predicted_label]
  
   # 任意の画像タイプをPNGに変換し、モデルに渡して推論します。
   def classify_image(self, image_data):
       image = self.prepare_image_for_inference(image_data)
       prediction = self.classify(image)
       return prediction

   # Databricksが外部から入力を渡す際のインターフェースです。
   def predict(self, model_input: pd.DataFrame):
       # Base64デコードは prepare_image_for_inference 内で行われるため、ここでは不要
       return [self.classify_image(image) for image in model_input['image']]

以下のセルでは元記事になかったロジックを追加しています。

バイナリを渡す場合には変更は不要なのですが、モデルサービングエンドポイントに上のモデル実装をデプロイした際、文字列で渡す必要があるのでbase64のデコード処理を加えています。

      # Base64文字列の場合はデコード
      if isinstance(image_data, str):
        image_bytes = base64.b64decode(image_data)
      else:
        image_bytes = image_data

一方で画像データを渡す際にはbase64エンコードしています。

# クラスの動作確認:正しい値が返るかテストします。
with open('cat.jpg', 'rb') as file:
   file_content = base64.b64encode(file.read())

input_example = pd.DataFrame({"image": [file_content]})

MicrosoftResnetFifty().predict(input_example)
['Egyptian cat']

Egyptianかどうかはさておき、猫として認識されています。以下の画像です。
cat.jpg

import pandas as pd
from mlflow.types import Schema, ColSpec
from mlflow.models.signature import ModelSignature

# 入力スキーマと出力スキーマを定義します
input_schema = Schema([ColSpec("binary", "image")])
output_schema = Schema([ColSpec("string")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

# モデルをUnity Catalogに保存するためのレジストリURIを設定します
mlflow.set_registry_uri("databricks-uc")

# MLflowランを開始し、モデルを登録します
with mlflow.start_run() as run:
   mlflow.pyfunc.log_model(
       python_model=MicrosoftResnetFifty(),
       artifact_path="infer_model",
       input_example=input_example,
       signature=signature
   )
   model_uri = f"runs:/{run.info.run_id}/infer_model"
   registered_model_name = "takaakiyayoi_catalog.image_ai_query.resnet50"
   mlflow.register_model(model_uri, registered_model_name)

Unity Catalog配下に登録されます。このモデルをサービングをクリックしてサービングエンドポイントにデプロイします。
Screenshot 2025-11-10 at 12.56.17.png

エンドポイント名を指定してデプロイします。
Screenshot 2025-11-10 at 12.56.39.png

十数分でエンドポイントが稼働します。
Screenshot 2025-11-10 at 12.56.58.png
Screenshot 2025-11-10 at 13.12.18.png

モデルサービングエンドポイントのURLを用いて画像判別を行います。画像をデータフレームにまとめ上げてバッチ処理を可能にしています。

import os
import requests
import numpy as np
import pandas as pd
import json
import base64
from PIL import Image
from io import BytesIO

# Pandas DataFrameをモデルサービングエンドポイントが受け付けるJSON文字列に変換します
def to_dataframe_split_json(df):
   obj = {'dataframe_split': json.loads(df.to_json(orient='split'))}
   return json.dumps(obj, allow_nan=True)

# モデルサービングエンドポイント呼び出しに必要な認証ヘッダーを取得します
def get_headers():
   token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
   headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
   return headers

# モデルサービングエンドポイントを呼び出し、結果を返します
def infer(dataset):
   endpoint_url = 'https://xxxx.cloud.databricks.com/serving-endpoints/taka-resnet50/invocations' 
  
   headers = get_headers()
   post_data = to_dataframe_split_json(dataset)

   response = requests.request(
       method='POST',
       headers=headers,
       url=endpoint_url,
       data=post_data)
  
   if response.status_code != 200:
       raise Exception(f'Request failed with status {response.status_code}, {response.text}')
   return response.json()

# 画像をbase64エンコードするヘルパー関数
def get_image_base64(image_path):
    with open(image_path, 'rb') as f:
        image_bytes = f.read()
        # Base64エンコードして文字列として返す
        return base64.b64encode(image_bytes).decode('utf-8')

# 複数画像をDataFrameに格納し、バッチ推論をテストします
image_df = pd.DataFrame({"image":
   [
       get_image_base64('bird.jpg'),
       get_image_base64('cat.jpg'),
       get_image_base64('davidmorton.png')
   ]})

# infer関数を呼び出します
results = infer(image_df)

# 結果をコンソールに出力します
print(results['predictions'])
['barbershop', 'Egyptian cat', 'flat-coated retriever']

はじめてのDatabricks

はじめてのDatabricks

Databricks無料トライアル

Databricks無料トライアル

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?