0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

MLflowラッパーAPIを作ってみた

Posted at

はじめに

  • 目的
    • このMLflowラッパーAPIの目的は, 数理最適化, 特にメタヒューリスティクスのランダム性を含めた解のトラッキングを行うことです
    • 元々のMLflowには無い, プロジェクト管理機能を盛り込み, 複数人での実験管理も行えるようにしています
  • 対象読者
    • MLflowを使ったことがある人
    • 複数人で同じ実験を行いたい人
    • プロジェクト管理機能を考えていた人

目次

なぜMLflowラッパーAPIが必要になったのか?

経緯

あくまで個人的な経緯ですが...

  • メタヒューリスティクスでTSPを解いてみたくなる
  • プロトタイプをPythonコードで書いてみる
  • 規模が大きい問題を解いてみたくなる
  • Pythonだと計算時間が掛かり, 高速化にも限界がある(技術足りない...)
  • C#での移植を考える
  • C#だとMLflowを直接扱えない
  • Python, C#の両方から使えるAPIを作りたくなる

以上の理由から, MLflowラッパーAPIを実装してみました.

MLflowラッパーAPIのメリット

  • API経由で操作できるため, 利用する言語や環境を選ばない
  • 複数人での実験管理が(理論上)可能
  • プロジェクトごとに実験管理可能
  • 疎結合 (Decoupling) の実現
  • ログ記録, 実験管理, 集計といった操作をシンプルに行える

APIの技術スタックと全体構成

技術スタック

  • 言語: Python
  • Web FW: FastAPI
  • バックエンド: MLflow
  • スキーマ定義: Pydantic

全体構成

root
├── app/                                # FastAPIアプリケーションのコア
│   ├── api/                            # APIインターフェース層
│   │   ├── common/                     # 共通ヘルパー機能
│   │   │   └── common_func.py          # MLflowクライアントの動的生成, バリデーション
│   │   ├── routers/                    # エンドポイントのルーティングとロジック
│   │   │   ├── aggregation_routers.py  # データ取得・集約・分析ロジック
│   │   │   ├── experiment_routers.py   # 実験のライフサイクル管理 (開始/終了)
│   │   │   └── log_routers.py          # ログ/アーティファクトのI/O
│   │   └── endpoints.py                # メインアプリでのルーター集約ファイル
│   ├── cors/                           # CORS (Cross-Origin Resource Sharing) 設定
│   │   └── cors.py                     # CORSミドルウェアの定義
│   ├── models/                         # Pydanticスキーマによるデータ構造定義
│   │   ├── aggregations/               # データ集約用のスキーマ
│   │   │   └── agg_schemas.py          # Experiment/Run情報, 集約応答スキーマ
│   │   ├── experiments/                # 実験制御用のスキーマ
│   │   │   └── experiment_schemas.py   # Run開始/終了の要求/応答スキーマ
│   │   └── logs/                       # ロギングデータ用のスキーマ
│   │       └── log_schemas.py          # メトリクス, パラメータ, アーティファクトのI/Oスキーマ
│   └── main.py                         # FastAPIのエントリポイント (アプリの起動/設定)
├── tests/                              # ユニットテストコード
├── Makefile                            # 実行, クリーニングなどの自動化コマンド
├── README.md                           # プロジェクト説明
└── pyproject.toml                      # プロジェクト依存関係と設定

実装について

モデルスキーマ

app/models/

実験制御用のスキーマ

app/models/experiments/experiment_schemas.py

from typing import Optional

from pydantic import BaseModel


class ExperimentRequest(BaseModel):
    run_name: Optional[str] = None


class ExperimentResponse(BaseModel):
    run_id: str
    experiment_id: str
    status: str


class TerminateRunRequest(BaseModel):
    status: str

ロギングデータ用のスキーマ

app/models/logs/log_schemas.py

from pydantic import BaseModel


class ParamPair(BaseModel):
    key: str
    value: str


class MetricPair(BaseModel):
    key: str
    value: float
    epoch: int = 0


class LogParamsRequest(BaseModel):
    param_list: list[ParamPair]


class LogMetricsRequest(BaseModel):
    metric_list: list[MetricPair]


class LogResponse(BaseModel):
    status: str


class TerminateRunRequest(BaseModel):
    status: str


class ArtifactsResponse(BaseModel):
    artifact_name_list: list[str]


class CalcResultResponse(BaseModel):
    coordinates: dict[str, dict[str, float]]
    tour: list[int]

データ集約用のスキーマ

app/models/aggregations/agg_schemas.py

from pydantic import BaseModel


class ProjectsInfo(BaseModel):
    project_list: list[str]


class ExperimentInfo(BaseModel):
    project_name: str
    experiment_name: str
    experiment_id: str
    run_id_list: list[str]


class ExperimentResponse(BaseModel):
    project_name: str
    # experiment_id -> ExperimentInfo
    experiment_dict: dict[str, ExperimentInfo]


class RunInfo(BaseModel):
    experiment_name: str
    experiment_id: str
    run_name: str
    run_id: str
    params_dict: dict
    metrics_dict: dict


class RunResponse(BaseModel):
    experiment_name: str
    experiment_id: str
    # run_id -> RunInfo
    run_info_dict: dict[str, RunInfo]

ルータ周り

app/api/

共通処理

app/api/common/common_func.py

MLflowクライアントの動的生成, バリデーションを行う.

インジェクション攻撃などを受けたくないので, プロジェクト名の制約として,

  • 1文字目はアルファベット大文字/小文字固定
  • 2文字目以降は
    • アルファベット大文字/小文字
    • 数字0~9
  • 全部で5文字以上

としました.


import re

from mlflow.tracking import MlflowClient


def check_project_name(project_name: str) -> bool:
    return bool(re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_-]{4,}$", project_name))


def get_mlflow_client(
    project_name: str
) -> MlflowClient:
    """project_nameに基づいてMlflowClientを動的に生成する"""
    if check_project_name(project_name):
        tracking_uri: str = f"./projects/{project_name}/mlruns"
        return MlflowClient(tracking_uri=tracking_uri)
    else:
        raise ValueError(f"Invalid project name")

各ルータ

app/api/routers/

FastAPIをモジュール化.

実験制御

app/api/routers/experiment_routers.py

実験, 特にRunのスタートとターミネートを行います. ターミネートにおいては完了ステータスを指定可能です.

import mlflow
from fastapi import APIRouter, Depends, HTTPException
from mlflow.tracking import MlflowClient

from app.api.common.common_func import get_mlflow_client
from app.models.experiments.experiment_schemas import (ExperimentRequest,
                                                       ExperimentResponse,
                                                       TerminateRunRequest)
from app.models.logs.log_schemas import LogResponse

router = APIRouter()


@router.post("/api/v1/projects/{project_name}/experiments/{experiment_name}/start", response_model=ExperimentResponse)
def start_experiment_run(
    project_name: str,
    experiment_name: str,
    req: ExperimentRequest,
    client: MlflowClient = Depends(get_mlflow_client)
):
    try:
        exp = client.get_experiment_by_name(experiment_name)
        if exp is None:
            exp_id = client.create_experiment(experiment_name)
        else:
            exp_id = exp.experiment_id

        run = client.create_run(
            experiment_id=exp_id,
            run_name=req.run_name
        )

        return ExperimentResponse(
            run_id=run.info.run_id,
            experiment_id=run.info.experiment_id,
            status=run.info.status
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start experiment in project {project_name}: {str(e)}"
        )


@router.put("/api/v1/projects/{project_name}/experiments/{experiment_name}/runs/{run_id}/terminate", response_model=LogResponse)
def terminate_run(
    project_name: str,
    experiment_name: str,
    run_id: str,
    req: TerminateRunRequest,
    client: MlflowClient = Depends(get_mlflow_client)
):
    try:
        client.set_terminated(run_id=run_id, status=req.status)
        return LogResponse(status="success")
    except mlflow.exceptions.RestException as e:
        raise HTTPException(
            status_code=404,
            detail=f"Run {run_id} not found in project {project_name}/{experiment_name}. MLflow error: {e}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal error in project {project_name}/{experiment_name}: {str(e)}"
        )

データのロギング

app/api/routers/log_routers.py

プロジェクト名, 実験名, Run Idを指定し, 実験のパラメータやメトリクスを記録する.

アーティファクトのアップロード処理, 計算結果の出力も担う.

import json
import os
import shutil
import tempfile
from typing import Optional

import mlflow
from fastapi import (APIRouter, Depends, File, HTTPException, Response,
                     UploadFile)
from mlflow import MlflowClient
from mlflow.tracking import MlflowClient

from app.api.common.common_func import get_mlflow_client
from app.models.logs.log_schemas import (ArtifactsResponse, CalcResultResponse,
                                         LogMetricsRequest, LogParamsRequest,
                                         LogResponse)

router = APIRouter()


@router.post(
    "/api/v1/projects/{project_name}/experiments/{experiment_name}/runs/{run_id}/log/params",
    response_model=LogResponse
)
def log_params(
    project_name: str,
    experiment_name: str,
    run_id: str,
    req: LogParamsRequest,
    client: MlflowClient = Depends(get_mlflow_client)
):
    try:
        list_to_log = [
            mlflow.entities.Param(param.key, param.value)
            for param in req.param_list
        ]

        client.log_batch(run_id=run_id, params=list_to_log)

        return LogResponse(status="success")
    except mlflow.exceptions.RestException as e:
        raise HTTPException(
            status_code=404,
            detail=f"MLflow error in project {project_name}/{experiment_name}: {e}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal error in project {project_name}/{experiment_name}: {str(e)}"
        )


@router.post(
    "/api/v1/projects/{project_name}/experiments/{experiment_name}/runs/{run_id}/log/metrics",
    response_model=LogResponse
)
def log_metrics(
    project_name: str,
    experiment_name: str,
    run_id: str,
    req: LogMetricsRequest,
    client: MlflowClient = Depends(get_mlflow_client)
):
    try:
        list_to_log = [
            mlflow.entities.Metric(
                metric.key,
                metric.value,
                0,
                metric.epoch
            )
            for metric in req.metric_list
        ]

        client.log_batch(run_id, metrics=list_to_log)
        return LogResponse(status="success")
    except mlflow.exceptions.RestException as e:
        raise HTTPException(
            status_code=404,
            detail=f"MLflow error in project {project_name}/{experiment_name}: {e}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Internal error in project {project_name}/{experiment_name}: {str(e)}"
        )


@router.post(
    "/api/v1/projects/{project_name}/experiments/{experiment_name}/runs/{run_id}/artifacts",
    response_model=LogResponse
)
def upload_artifact(
    project_name: str,
    experiment_name: str,
    run_id: str,
    file: UploadFile = File(...),
    client: MlflowClient = Depends(get_mlflow_client)
):
    # 一時的なファイルパス
    tmp_path: Optional[str] = None
    try:
        # ローカルなファイルをshutilでコピーする
        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file.filename}") as tmp:
            tmp_path = tmp.name
            shutil.copyfileobj(file.file, tmp)

        client.log_artifact(run_id=run_id, local_path=tmp_path)

        return LogResponse(status="success")
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to upload artifact for run {run_id} in project {project_name}/{experiment_name}: {e}"
        )
    finally:
        # mlflow サーバ上の一時ファイルを削除する
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)


@router.get(
    "/api/v1/projects/{project_name}/experiments/{experiment_name}/runs/{run_id}/artifacts",
    response_model=ArtifactsResponse
)
def get_artifact_info(
    project_name: str,
    experiment_name: str,
    run_id: str,
    client: MlflowClient = Depends(get_mlflow_client)
):
    try:
        file_info = client.list_artifacts(run_id, path="")
        artifact_names = [each_info.path for each_info in file_info]
        return ArtifactsResponse(artifact_name_list=artifact_names)
    except Exception as e:
        # エラー処理
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get artifact paths for run {run_id} in project {project_name}/{experiment_name}: {e}"
        )


@router.get(
    "/api/v1/projects/{project_name}/experiments/{experiment_name}/runs/{run_id}/artifacts/{file_name}",
    response_model=CalcResultResponse
)
def get_calc_result(
    project_name: str,
    experiment_name: str,
    run_id: str,
    file_name: str,
    client: MlflowClient = Depends(get_mlflow_client)
):
    """
    指定されたRun IDとファイル名(artifactパス)に一致する計算結果を返す
    """

    # 一時ディレクトリを作成し, artifactをそこにダウンロードする
    with tempfile.TemporaryDirectory() as temp_dir:
        try:
            # MLflowクライアントを使用してartifactを一時ディレクトリにダウンロード
            local_path = client.download_artifacts(
                run_id=run_id,
                path=file_name,
                dst_path=temp_dir
            )

            if not os.path.exists(local_path):
                raise FileNotFoundError(
                    f"Artifact file not found: {file_name}"
                )

            # ファイルのコンテンツを読み込む

            content: str
            with open(local_path, 'r', encoding='utf-8') as f:
                content = f.read()

            json_obj: dict = json.loads(content)

            coordinates: dict = json_obj["problem"]["coordinates"]
            tour: list = json_obj["tour"]

            return CalcResultResponse(
                coordinates=coordinates,
                tour=tour
            )

        except FileNotFoundError as e:
            # Artifactが見つからない場合のHTTP 404
            raise HTTPException(
                status_code=404,
                detail=f"Artifact '{file_name}' not found for run {run_id}."
            )
        except Exception as e:
            # その他のMLflow関連のエラー
            print(f"MLflow Error: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to retrieve artifact '{file_name}' for {project_name}/{experiment_name}/{run_id}: {e}"
            )

データの集約

app/api/routers/aggregation_routers.py

以下の情報を返す

  • プロジェクト名の一覧
  • プロジェクト名直下の実験一覧
  • 実験名直下のRun情報一覧
# /app/api/routers/aggretation_routers.py
import os

from fastapi import APIRouter, Depends, HTTPException
from mlflow.tracking import MlflowClient

from app.api.common.common_func import get_mlflow_client
from app.models.aggregations.agg_schemas import (ExperimentInfo,
                                                 ExperimentResponse,
                                                 ProjectsInfo, RunInfo,
                                                 RunResponse)

router = APIRouter()


@router.get("/")
async def root():
    return {"message": "Welcome to MLflow API"}


@router.get("/api/v1/projects", response_model=ProjectsInfo)
def get_project_names():
    """プロジェクトの名前リストを返す (ファイルシステム上の有効なディレクトリをスキャン)"""
    project_root = "./projects"

    try:
        # プロジェクトルートディレクトリが存在しない場合は作成
        if not os.path.exists(project_root):
            os.makedirs(project_root, exist_ok=True)

        # プロジェクトルートディレクトリ内のエントリをリストアップ
        project_list: list[str] = []
        for entry in os.listdir(project_root):
            full_path = os.path.join(project_root, entry)
            # ディレクトリであれば追加
            if os.path.isdir(full_path):
                project_list.append(entry)

        return ProjectsInfo(
            project_list=project_list
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get project names: {str(e)}"
        )


@router.get("/api/v1/projects/{project_name}/experiments", response_model=ExperimentResponse)
def get_experiment_names(
    project_name: str,
    client: MlflowClient = Depends(get_mlflow_client)
):
    """プロジェクト名に対し, 紐づく実験の情報を返す"""
    project_dir = f"./projects/{project_name}"

    try:
        # プロジェクトルートディレクトリが存在しない場合は例外発生
        if not os.path.exists(project_dir):
            raise ValueError(f"Project name: {project_name} does not exsist.")

        experiment_dict: dict = {}
        experiments = client.search_experiments()
        for experiment in experiments:
            # Id, 名前を取得
            exp_id = experiment.experiment_id
            if exp_id == "0":
                continue
            exp_name = experiment.name
            # Run Idのリストを取得する
            run_id_list: list = [
                run.info.run_id for run in client.search_runs(experiment_ids=[exp_id])
            ]
            experiment_dict[exp_id] = ExperimentInfo(
                project_name=project_name,
                experiment_name=exp_name,
                experiment_id=exp_id,
                run_id_list=run_id_list
            )

        return ExperimentResponse(
            project_name=project_name,
            experiment_dict=experiment_dict
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get project names: {str(e)}"
        )


@router.get("/api/v1/projects/{project_name}/experiments/{experiment_name}/runs", response_model=RunResponse)
def get_run_info(
    project_name: str,
    experiment_name: str,
    client: MlflowClient = Depends(get_mlflow_client)
):
    exp_dir = f"./projects/{project_name}"

    try:
        # 実験のルートディレクトリが存在しない場合は例外発生
        if not os.path.exists(exp_dir):
            raise ValueError(f"Experiment: {exp_dir} does not exsist.")

        experiment = client.get_experiment_by_name(experiment_name)
        exp_id: str = experiment.experiment_id
        runs = client.search_runs(experiment_ids=[exp_id])

        run_info_dict = {}
        for run in runs:
            run_name: str = run.info.run_name
            run_id: str = run.info.run_id
            params_dict: dict = run.data.params
            metrics_dict: dict = run.data.metrics

            run_info_dict[run_id] = RunInfo(
                experiment_name=experiment_name,
                experiment_id=exp_id,
                run_name=run_name,
                run_id=run_id,
                params_dict=params_dict,
                metrics_dict=metrics_dict
            )
        return RunResponse(
            experiment_name=experiment_name,
            experiment_id=exp_id,
            run_info_dict=run_info_dict
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get project names: {str(e)}"
        )

メインアプリでのルータ集約ファイル

app/api/endpoints.py

app/api/routers/の直下にあるルータを集約する

from fastapi import APIRouter

from app.api.routers.aggregation_routers import router as agg_router
from app.api.routers.experiment_routers import router as exp_router
from app.api.routers.log_routers import router as log_router

api_router = APIRouter()
api_router.include_router(agg_router)
api_router.include_router(exp_router)
api_router.include_router(log_router)

CORS (Cross-Origin Resource Sharing) 設定

app/api/cors/cors.py

APIとして受け入れ可能なリソースの設定

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware


def set_cors(app: FastAPI):
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[
            "http://localhost:8080",
            # ...
        ],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"]
    )

FastAPIのエントリポイント

app/main.py

from fastapi import FastAPI

from app.api.endpoints import api_router
from app.cors.cors import set_cors

app = FastAPI(title="MLflow API", docs_url="/docs")
set_cors(app)
app.include_router(api_router)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

実装のポイント

  • ルーティングの分離
    • *_routers.pyなど役割ごとにファイルを分割し, 今後の大規模化に対応
  • RESTfulなパス設定
    • experiment_routers.py, log_routers.pyなど, Pathの一貫性

まとめ

MLflowラッパーAPIが実現したこと

  • 言語・環境非依存のトラッキング
    • C#のようなPython以外の環境からも, RESTful API経由でMLflowの全機能を容易に利用可能
      • 実験開始/終了
      • パラメータ/メトリクス/アーティファクトのロギング
    • 特に, メタヒューリスティクスのランダム性を含む計算結果のトラッキングをプラットフォームの制約なく実現
  • 独自のプロジェクト管理機能
    • MLflowの「実験 (Experiment)」の上位概念として「プロジェクト (Project)」をファイルシステムベースで導入
    • 複数人・複数チームでの実験管理を可能にした
  • 堅牢性の向上と定型処理の抽象化
    • FastAPIの依存性注入 (Depends) とPydanticスキーマ (*_schemas.py) により, データの型チェックとバリデーションを徹底し
    • 堅牢で使いやすいAPIインターフェースを提供
    • 煩雑なMLflowクライアントの操作やエラーハンドリングをAPI側で一手に引き受け, 利用者はシンプルに実験データを記録/取得

本ラッパーAPIが, 皆様の機械学習および数理最適化の実験管理をよりシンプルかつ効率的にする一助となれば幸いです.

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?