10
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

FastAPIで機械学習モデルを扱うためのAPI設計について

Last updated at Posted at 2024-10-07

はじめに

ネットサーフィン中に「How to Use FastAPI for Machine Learning(機械学習のためのFastAPIの使い方)」という興味深い記事を見つけました。このブログ記事では、FastAPIを使って機械学習モデルを扱うための基本的な設定や実践的な使い方が解説されています。

この記事に触発された形ではありますが、本記事ではFastAPIを用いて機械学習モデルに関するAPIを構築する際の設定や実装方法をまとめていきます。

モデル定義

FastAPIで機械学習モデルを扱うには、アプリケーションの起動時に機械学習モデルをロードし、そのモデルをアプリケーション全体で管理する必要があります。

FastAPIでは、アプリケーションのライフスパンイベントを活用して、起動時にモデルをロードし、終了時にリソースを解放することができます。

またモデルはapp.stateを利用して管理することで、アプリケーション全体で簡単にアクセスできます。

プログラム例は以下です。

main.py
from contextlib import asynccontextmanager
from fastapi import FastAPI

def model_load():
    # モデルロード関数
    pass
    
@asynccontextmanager
async def lifespan(app: FastAPI): 
    # モデルのロード
    app.state.model = model_load()
    yield
    # モデルのリソース解放

app = FastAPI(lifespan=lifespan)

コード内のasynccontextmanagerデコレータは、非同期ジェネレータ関数を非同期コンテキストマネージャに変換します。
yieldの前は__aenter__メソッドに相当し、コンテキストの開始時に実行され、yieldの後は__aexit__メソッドに相当し、コンテキストの終了時に実行されます。

またモデルが軽量の場合はpickleなどで管理し、自作したモデルロード関数で以下のように呼ぶと良いでしょう。

main.py
app.state.model = model_load('ファイルパス') #モデルロード関数

例えばLGBMモデルを使う場合は、モデルをpickle形式で保存し、以下のようなクラスを設定する方法も有効でしょう。

core/model.py
import pickle

class LGBMPredictor:
    def __init__(self, model_path: str) -> None:
        with open(model_path, "rb") as f:
            self.model = pickle.load(f)
    # predictなどの処理を後続で記載
main.py
from core.model import LGBMPredictor

app.state.model = LGBMPredictor("ファイルパス")

モデルを使いたい場合はmain.py内であればapp.state.modelにアクセスして使うことができます。
main.py外で使いたい場合は以下ルーターの項目でも説明しますが、FastAPIで用意されているRequestオブジェクトを使えば良いです。

またモデル等のリソースの解放について、以下の記事で簡単に解説しています。

モデル更新

機械学習モデルをオンライン学習する場合において、学習プロセスは時間がかかるため、バックグラウンドタスクを使用するのが適切です。

FastAPIではBackgroundTasksを使って、バックグラウンドで非同期に処理を実行できます。

以下は、バックグラウンドタスクを使ってモデルを再訓練する例です。

main.py
from fastapi import FastAPI, BackgroundTasks

def retrain_model(data, model):
    # 再訓練のロジックをここに実装
    pass

@app.post("/retrain")
async def retrain(background_tasks: BackgroundTasks, request):
    data = request.data #リクエストからデータが取れると仮定 
    background_tasks.add_task(retrain_model, data, app.state.model)
    return {"message": "Model retraining started"}

モデルを複数のスレッドで同時にapp.state.modelを更新しようとすると、予期しない動作やデータの破損等の可能性があるので、スレッドセーフティを確保する必要があります。

スレッドセーフティを考慮したプログラムは以下です。

main.py
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI
    
@asynccontextmanager
async def lifespan(app: FastAPI): 
    app.state.model = model_load()
    # スレッドセーフティ用のロックを追加
    app.state.model_lock = threading.Lock()
    yield

retrain_modelの引数はモデルのみではなくapp_stateを引数とし、更新のタイミングでロックを取得します。

def retrain_model(data, app_state):
    # 再訓練のロジックをここに実装

    ## ロックを取得してモデルを更新
    with app_state.model_lock:
        app_state.model = new_model

共通設定

以下では機械学習の使用の有無に関わらない、FastAPIを用いる全プロジェクトに共通する項目について簡単にまとめていきます。

ルーター

全てのエンドポイントをmain.pyに書いてしまうと、コードが煩雑になりメンテナンスが難しくなります。そこでエンドポイントを別のルーター用のファイルに分割することで、コードの可読性や保守性を高めることができます。

上記で説明したretrainのエンドポイントをルーター用のファイルに移行する場合は、main.pyrouter/router.pyは以下のように書けます。

router/router.py
from fastapi import APIRouter, BackgroundTasks, Body, Request
from your_module import retrain_model #モデル更新の関数
    ,
router = APIRouter()

@router.post("/retrain/")
async def retrain(background_tasks: BackgroundTasks, request_obj: Request, retrain_request = Body(...)):
    data = retrain_request.data #リクエストからデータが取れると仮定 
    background_tasks.add_task(retrain_model, data, request_obj.app.state.model)
    return {"message": "Model retraining started"}
main.py
from router import router

app.include_router(router.router)

RequestはASGI(Asynchronous Server Gateway Interface)で定義される「スコープ」やアプリケーションのステートにアクセスすることができます。

リクエストとレスポンス

コードの可読性とメンテナンス性の向上やバリデーション等のためにリクエストとレスポンスを設定します。

ディレクトリ名はmodelsとしたいところですが、機械学習モデルと混同するのでpayloadsなどが良いでしょう。

payloads/request.py
from pydantic import BaseModel
from typing import Dict

class RetrainRequest(BaseModel):
    data: Dict[str, float]
payloads/response.py
from pydantic import BaseModel

class RetrainResponse(BaseModel):
    message: str
router/router.py
from fastapi import APIRouter, BackgroundTasks, Request
from your_module import retrain_model #モデル更新の関数
from payloads.request import RetrainRequest
from payloads.response import RetrainResponse
    ,
router = APIRouter()

@app.post("/retrain/", response_model=RetrainResponse)
async def retrain(
    background_tasks: BackgroundTasks, 
    request_obj: Request, 
    retrain_request: RetrainRequest
):
    data = retrain_request.data
    background_tasks.add_task(retrain_model, data, request_obj.app.state.model)
    return {"message": "Model retraining started"}

設定したリクエストクラスは、関数の引数の型アシストで記載し、レスポンスクラスは@app.post()response_modelで定義します。

ロガー

ロガーを使うことで、アプリケーションの実行時の挙動を追跡し、エラーやパフォーマンスの問題を把握するのに役立ちます。

以下のように簡単にロガーを適用できます。

from contextlib import asynccontextmanager
from fastapi import FastAPI
import logging
from logger import logger

@asynccontextmanager
async def lifespan(app: FastAPI): 
    logger: Logger = logging.getLogger(f"custom.{__name__}")
    logger.info("start.....")
    yield
    logger.info("shutdown.....")

ミドルウェア

ミドルウェアは、リクエストが処理される前後に何らかの処理を挿入するための仕組みです。例えば、全てのリクエストに対して共通の前処理や後処理を行いたい場合に役立ちます。

ミドルウェアを用いた例は以下です。

main.py
from fastapi import FastAPI, Request
import time
from logger import logger

app = FastAPI()

@app.middleware("http")
async def log_requests(request: Request, call_next):
    start_time = time.time()
    logger.info(f"Request: {request.method} {request.url}")
    response = await call_next(request)
    process_time = time.time() - start_time
    logger.info(f"Response status: {response.status_code}, Time taken: {process_time:.4f} seconds")
    return response

response = await call_next(request)の行でリクエストに対する処理を行い、結果をresponseに格納します。

また特定のメソッドのみに特定の処理を加えたい場合は、以下のようにrequestからリクエストメソッドを取得して条件分岐させると良いでしょう。

main.py
@app.middleware("http")
async def log_requests(request: Request, call_next):
    request_dict: dict = vars(request)
    post_method_flag: bool = request_dict["scope"]["method"] == "POST"
    if post_method_flag:
        # 特定の処理
        pass
    #後続の処理

またFastAPIでは、複数のミドルウェアを順番に適用することができます。例えば、リクエストのロギングと同時に、CORS(クロスオリジンリソース共有)の設定を行うことができます。

main.py
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # 許可するオリジンを指定
    allow_credentials=True,
    allow_methods=["*"],  # 許可するHTTPメソッドを指定
    allow_headers=["*"],  # 許可するヘッダーを指定
)

@app.middleware("http")
async def log_requests(request: Request, call_next):
    ...

エラーハンドリング

アプリケーションのエラーハンドリングは、エラーの原因を素早く把握するために非常に重要です。

まずエラーの内容についてはYAMLファイル等で記載すると良いでしょう。

config/message.yaml
InternalServerError:
  common: "Internal Server Error"
  model_load_error: "Model load error" 

YAMLファイルからエラーメッセージを読み込み、必要に応じてメッセージを提供するクラスとカスタムエラークラスを定義します。これにより、特定のエラーに対してカスタムメッセージをログに残しつつ、FastAPIに適切なレスポンスを返すことができます。

core/error.py
import logging
from logging import Logger
import yaml

class Message:
    """
    config/message.yamlからメッセージを取得するクラス
    """
    def __new__(cls) -> "Message":
        if not hasattr(cls, "__instance"):
            cls.__instance = super(Message, cls).__new__(cls)
            with open("config/message.yaml", "r", encoding="utf-8") as file:
                cls.message: dict[str, dict[str, str]] = yaml.safe_load(file)
        return cls.__instance

    def code_to_message(self, http_error_status: str, error_code: str) -> str:
        return self.message[http_error_status][error_code]

        
class InternalServerError(Exception):
    STATUS_CODE: int = 500

    def __init__(self, err_code: str):
        message = Message()
        self.ERROR_MESSAGE = message.code_to_message("InternalServerError", err_code)
        logger: Logger = logging.getLogger(f"custom.{__name__}")
        logger.error(
            "Code:" + err_code + ", Message: " + self.ERROR_MESSAGE, exc_info=True
        )
        super().__init__(self.ERROR_MESSAGE)

logger.errorにおいてexc_info=Trueとすると、通常のエラーメッセージに加えてスタックトレースを取得できます。

FastAPIでは、exception_handlerを使ってカスタムエラー処理を追加することができます。

上記で説明したlifespan関数にエラーハンドリングを追加すると以下のようになります。

main.py
from contextlib import asynccontextmanager
from core.error import InternalServerError
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        app.state.model = model_load()
        yield
    except Exception:
        raise InternalServerError("model_load_error") #エラーの種類を引数で記載    

@app.exception_handler(InternalServerError)
async def internal_server_error_handler(request: Request, exc: InternalServerError):
    return JSONResponse(
        status_code=InternalServerError.STATUS_CODE,
        content={"message": exc.ERROR_MESSAGE},
    )

おわりに

今回の記事では、FastAPIを使って機械学習モデルを管理し、APIを効率的に構築するための設定や実装方法を紹介しました。ライフサイクルイベントを活用したモデルのロードやバックグラウンドタスクによる再訓練の実装、スレッドセーフティを考慮したプログラムなど、実践的な例を通じて説明しました。

気が向いたら以下の項目についても追記形式で記載したいと思います。

  • キャッシング
  • メトリクスの収集
  • 詳細なバリデーションチェック
  • 依存性注入(injectorライブラリ)
10
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?