こちらのノートブックを動かしてみました。
実際に動かしたら、モデルサービングエンドポイントとの連携部分でエラーになったので修正しています。
# 必要なPythonパッケージをインストールします。
%pip install transformers torch
# 上記パッケージを有効にするためにPythonプロセスを再起動します。
dbutils.library.restartPython()
import io, base64, torch, mlflow.pyfunc
import pandas as pd
from mlflow.types.schema import ColSpec
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
class MicrosoftResnetFifty(mlflow.pyfunc.PythonModel):
# モデルの初期化。画像プロセッサと分類モデルを設定します。
def __init__(self):
super(MicrosoftResnetFifty, self).__init__()
self.processor = AutoImageProcessor.from_pretrained('microsoft/resnet-50')
self.model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50')
# 推論用に画像を準備します。既知の画像を透過なしPNGに変換します。
def prepare_image_for_inference(self, image_data):
# Base64文字列の場合はデコード
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
image_bytes = image_data
img = Image.open(io.BytesIO(image_bytes))
buf = io.BytesIO()
img.convert("RGB").save(buf, format='PNG')
img_bytes = buf.getvalue()
image = Image.open(io.BytesIO(img_bytes))
return image
# 画像を前処理し、推論を実行して予測ラベルを返します。
def classify(self, image):
inputs = self.processor(image, return_tensors="pt")
with torch.no_grad():
logits = self.model(**inputs).logits
predicted_label = logits.argmax(-1).item()
return self.model.config.id2label[predicted_label]
# 任意の画像タイプをPNGに変換し、モデルに渡して推論します。
def classify_image(self, image_data):
image = self.prepare_image_for_inference(image_data)
prediction = self.classify(image)
return prediction
# Databricksが外部から入力を渡す際のインターフェースです。
def predict(self, model_input: pd.DataFrame):
# Base64デコードは prepare_image_for_inference 内で行われるため、ここでは不要
return [self.classify_image(image) for image in model_input['image']]
以下のセルでは元記事になかったロジックを追加しています。
バイナリを渡す場合には変更は不要なのですが、モデルサービングエンドポイントに上のモデル実装をデプロイした際、文字列で渡す必要があるのでbase64のデコード処理を加えています。
# Base64文字列の場合はデコード
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
image_bytes = image_data
一方で画像データを渡す際にはbase64エンコードしています。
# クラスの動作確認:正しい値が返るかテストします。
with open('cat.jpg', 'rb') as file:
file_content = base64.b64encode(file.read())
input_example = pd.DataFrame({"image": [file_content]})
MicrosoftResnetFifty().predict(input_example)
['Egyptian cat']
Egyptianかどうかはさておき、猫として認識されています。以下の画像です。

import pandas as pd
from mlflow.types import Schema, ColSpec
from mlflow.models.signature import ModelSignature
# 入力スキーマと出力スキーマを定義します
input_schema = Schema([ColSpec("binary", "image")])
output_schema = Schema([ColSpec("string")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
# モデルをUnity Catalogに保存するためのレジストリURIを設定します
mlflow.set_registry_uri("databricks-uc")
# MLflowランを開始し、モデルを登録します
with mlflow.start_run() as run:
mlflow.pyfunc.log_model(
python_model=MicrosoftResnetFifty(),
artifact_path="infer_model",
input_example=input_example,
signature=signature
)
model_uri = f"runs:/{run.info.run_id}/infer_model"
registered_model_name = "takaakiyayoi_catalog.image_ai_query.resnet50"
mlflow.register_model(model_uri, registered_model_name)
Unity Catalog配下に登録されます。このモデルをサービングをクリックしてサービングエンドポイントにデプロイします。

モデルサービングエンドポイントのURLを用いて画像判別を行います。画像をデータフレームにまとめ上げてバッチ処理を可能にしています。
import os
import requests
import numpy as np
import pandas as pd
import json
import base64
from PIL import Image
from io import BytesIO
# Pandas DataFrameをモデルサービングエンドポイントが受け付けるJSON文字列に変換します
def to_dataframe_split_json(df):
obj = {'dataframe_split': json.loads(df.to_json(orient='split'))}
return json.dumps(obj, allow_nan=True)
# モデルサービングエンドポイント呼び出しに必要な認証ヘッダーを取得します
def get_headers():
token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
headers = {'Authorization': f'Bearer {token}', 'Content-Type': 'application/json'}
return headers
# モデルサービングエンドポイントを呼び出し、結果を返します
def infer(dataset):
endpoint_url = 'https://xxxx.cloud.databricks.com/serving-endpoints/taka-resnet50/invocations'
headers = get_headers()
post_data = to_dataframe_split_json(dataset)
response = requests.request(
method='POST',
headers=headers,
url=endpoint_url,
data=post_data)
if response.status_code != 200:
raise Exception(f'Request failed with status {response.status_code}, {response.text}')
return response.json()
# 画像をbase64エンコードするヘルパー関数
def get_image_base64(image_path):
with open(image_path, 'rb') as f:
image_bytes = f.read()
# Base64エンコードして文字列として返す
return base64.b64encode(image_bytes).decode('utf-8')
# 複数画像をDataFrameに格納し、バッチ推論をテストします
image_df = pd.DataFrame({"image":
[
get_image_base64('bird.jpg'),
get_image_base64('cat.jpg'),
get_image_base64('davidmorton.png')
]})
# infer関数を呼び出します
results = infer(image_df)
# 結果をコンソールに出力します
print(results['predictions'])
['barbershop', 'Egyptian cat', 'flat-coated retriever']


