はじめに
ネットサーフィン中に「How to Use FastAPI for Machine Learning(機械学習のためのFastAPIの使い方)」という興味深い記事を見つけました。このブログ記事では、FastAPIを使って機械学習モデルを扱うための基本的な設定や実践的な使い方が解説されています。
この記事に触発された形ではありますが、本記事ではFastAPIを用いて機械学習モデルに関するAPIを構築する際の設定や実装方法をまとめていきます。
モデル定義
FastAPIで機械学習モデルを扱うには、アプリケーションの起動時に機械学習モデルをロードし、そのモデルをアプリケーション全体で管理する必要があります。
FastAPIでは、アプリケーションのライフスパンイベントを活用して、起動時にモデルをロードし、終了時にリソースを解放することができます。
またモデルはapp.state
を利用して管理することで、アプリケーション全体で簡単にアクセスできます。
プログラム例は以下です。
from contextlib import asynccontextmanager
from fastapi import FastAPI
def model_load():
# モデルロード関数
pass
@asynccontextmanager
async def lifespan(app: FastAPI):
# モデルのロード
app.state.model = model_load()
yield
# モデルのリソース解放
app = FastAPI(lifespan=lifespan)
コード内のasynccontextmanager
デコレータは、非同期ジェネレータ関数を非同期コンテキストマネージャに変換します。
yieldの前は__aenter__
メソッドに相当し、コンテキストの開始時に実行され、yieldの後は__aexit__
メソッドに相当し、コンテキストの終了時に実行されます。
またモデルが軽量の場合はpickle
などで管理し、自作したモデルロード関数で以下のように呼ぶと良いでしょう。
app.state.model = model_load('ファイルパス') #モデルロード関数
例えばLGBM
モデルを使う場合は、モデルをpickle形式で保存し、以下のようなクラスを設定する方法も有効でしょう。
import pickle
class LGBMPredictor:
def __init__(self, model_path: str) -> None:
with open(model_path, "rb") as f:
self.model = pickle.load(f)
# predictなどの処理を後続で記載
from core.model import LGBMPredictor
app.state.model = LGBMPredictor("ファイルパス")
モデルを使いたい場合はmain.py
内であればapp.state.model
にアクセスして使うことができます。
main.py
外で使いたい場合は以下ルーターの項目でも説明しますが、FastAPIで用意されているRequest
オブジェクトを使えば良いです。
またモデル等のリソースの解放について、以下の記事で簡単に解説しています。
モデル更新
機械学習モデルをオンライン学習する場合において、学習プロセスは時間がかかるため、バックグラウンドタスクを使用するのが適切です。
FastAPIではBackgroundTasks
を使って、バックグラウンドで非同期に処理を実行できます。
以下は、バックグラウンドタスクを使ってモデルを再訓練する例です。
from fastapi import FastAPI, BackgroundTasks
def retrain_model(data, model):
# 再訓練のロジックをここに実装
pass
@app.post("/retrain")
async def retrain(background_tasks: BackgroundTasks, request):
data = request.data #リクエストからデータが取れると仮定
background_tasks.add_task(retrain_model, data, app.state.model)
return {"message": "Model retraining started"}
モデルを複数のスレッドで同時にapp.state.model
を更新しようとすると、予期しない動作やデータの破損等の可能性があるので、スレッドセーフティを確保する必要があります。
スレッドセーフティを考慮したプログラムは以下です。
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI
@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.model = model_load()
# スレッドセーフティ用のロックを追加
app.state.model_lock = threading.Lock()
yield
retrain_model
の引数はモデルのみではなくapp_state
を引数とし、更新のタイミングでロックを取得します。
def retrain_model(data, app_state):
# 再訓練のロジックをここに実装
## ロックを取得してモデルを更新
with app_state.model_lock:
app_state.model = new_model
共通設定
以下では機械学習の使用の有無に関わらない、FastAPIを用いる全プロジェクトに共通する項目について簡単にまとめていきます。
ルーター
全てのエンドポイントをmain.pyに書いてしまうと、コードが煩雑になりメンテナンスが難しくなります。そこでエンドポイントを別のルーター用のファイルに分割することで、コードの可読性や保守性を高めることができます。
上記で説明したretrain
のエンドポイントをルーター用のファイルに移行する場合は、main.py
とrouter/router.py
は以下のように書けます。
from fastapi import APIRouter, BackgroundTasks, Body, Request
from your_module import retrain_model #モデル更新の関数
,
router = APIRouter()
@router.post("/retrain/")
async def retrain(background_tasks: BackgroundTasks, request_obj: Request, retrain_request = Body(...)):
data = retrain_request.data #リクエストからデータが取れると仮定
background_tasks.add_task(retrain_model, data, request_obj.app.state.model)
return {"message": "Model retraining started"}
from router import router
app.include_router(router.router)
Request
はASGI(Asynchronous Server Gateway Interface)で定義される「スコープ」やアプリケーションのステートにアクセスすることができます。
リクエストとレスポンス
コードの可読性とメンテナンス性の向上やバリデーション等のためにリクエストとレスポンスを設定します。
ディレクトリ名はmodels
としたいところですが、機械学習モデルと混同するのでpayloads
などが良いでしょう。
from pydantic import BaseModel
from typing import Dict
class RetrainRequest(BaseModel):
data: Dict[str, float]
from pydantic import BaseModel
class RetrainResponse(BaseModel):
message: str
from fastapi import APIRouter, BackgroundTasks, Request
from your_module import retrain_model #モデル更新の関数
from payloads.request import RetrainRequest
from payloads.response import RetrainResponse
,
router = APIRouter()
@app.post("/retrain/", response_model=RetrainResponse)
async def retrain(
background_tasks: BackgroundTasks,
request_obj: Request,
retrain_request: RetrainRequest
):
data = retrain_request.data
background_tasks.add_task(retrain_model, data, request_obj.app.state.model)
return {"message": "Model retraining started"}
設定したリクエストクラスは、関数の引数の型アシストで記載し、レスポンスクラスは@app.post()
のresponse_model
で定義します。
ロガー
ロガーを使うことで、アプリケーションの実行時の挙動を追跡し、エラーやパフォーマンスの問題を把握するのに役立ちます。
以下のように簡単にロガーを適用できます。
from contextlib import asynccontextmanager
from fastapi import FastAPI
import logging
from logger import logger
@asynccontextmanager
async def lifespan(app: FastAPI):
logger: Logger = logging.getLogger(f"custom.{__name__}")
logger.info("start.....")
yield
logger.info("shutdown.....")
ミドルウェア
ミドルウェアは、リクエストが処理される前後に何らかの処理を挿入するための仕組みです。例えば、全てのリクエストに対して共通の前処理や後処理を行いたい場合に役立ちます。
ミドルウェアを用いた例は以下です。
from fastapi import FastAPI, Request
import time
from logger import logger
app = FastAPI()
@app.middleware("http")
async def log_requests(request: Request, call_next):
start_time = time.time()
logger.info(f"Request: {request.method} {request.url}")
response = await call_next(request)
process_time = time.time() - start_time
logger.info(f"Response status: {response.status_code}, Time taken: {process_time:.4f} seconds")
return response
response = await call_next(request)
の行でリクエストに対する処理を行い、結果をresponse
に格納します。
また特定のメソッドのみに特定の処理を加えたい場合は、以下のようにrequest
からリクエストメソッドを取得して条件分岐させると良いでしょう。
@app.middleware("http")
async def log_requests(request: Request, call_next):
request_dict: dict = vars(request)
post_method_flag: bool = request_dict["scope"]["method"] == "POST"
if post_method_flag:
# 特定の処理
pass
#後続の処理
またFastAPIでは、複数のミドルウェアを順番に適用することができます。例えば、リクエストのロギングと同時に、CORS(クロスオリジンリソース共有)の設定を行うことができます。
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 許可するオリジンを指定
allow_credentials=True,
allow_methods=["*"], # 許可するHTTPメソッドを指定
allow_headers=["*"], # 許可するヘッダーを指定
)
@app.middleware("http")
async def log_requests(request: Request, call_next):
...
エラーハンドリング
アプリケーションのエラーハンドリングは、エラーの原因を素早く把握するために非常に重要です。
まずエラーの内容についてはYAMLファイル等で記載すると良いでしょう。
InternalServerError:
common: "Internal Server Error"
model_load_error: "Model load error"
YAMLファイルからエラーメッセージを読み込み、必要に応じてメッセージを提供するクラスとカスタムエラークラスを定義します。これにより、特定のエラーに対してカスタムメッセージをログに残しつつ、FastAPIに適切なレスポンスを返すことができます。
import logging
from logging import Logger
import yaml
class Message:
"""
config/message.yamlからメッセージを取得するクラス
"""
def __new__(cls) -> "Message":
if not hasattr(cls, "__instance"):
cls.__instance = super(Message, cls).__new__(cls)
with open("config/message.yaml", "r", encoding="utf-8") as file:
cls.message: dict[str, dict[str, str]] = yaml.safe_load(file)
return cls.__instance
def code_to_message(self, http_error_status: str, error_code: str) -> str:
return self.message[http_error_status][error_code]
class InternalServerError(Exception):
STATUS_CODE: int = 500
def __init__(self, err_code: str):
message = Message()
self.ERROR_MESSAGE = message.code_to_message("InternalServerError", err_code)
logger: Logger = logging.getLogger(f"custom.{__name__}")
logger.error(
"Code:" + err_code + ", Message: " + self.ERROR_MESSAGE, exc_info=True
)
super().__init__(self.ERROR_MESSAGE)
logger.error
においてexc_info=True
とすると、通常のエラーメッセージに加えてスタックトレースを取得できます。
FastAPIでは、exception_handler
を使ってカスタムエラー処理を追加することができます。
上記で説明したlifespan
関数にエラーハンドリングを追加すると以下のようになります。
from contextlib import asynccontextmanager
from core.error import InternalServerError
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
app.state.model = model_load()
yield
except Exception:
raise InternalServerError("model_load_error") #エラーの種類を引数で記載
@app.exception_handler(InternalServerError)
async def internal_server_error_handler(request: Request, exc: InternalServerError):
return JSONResponse(
status_code=InternalServerError.STATUS_CODE,
content={"message": exc.ERROR_MESSAGE},
)
おわりに
今回の記事では、FastAPIを使って機械学習モデルを管理し、APIを効率的に構築するための設定や実装方法を紹介しました。ライフサイクルイベントを活用したモデルのロードやバックグラウンドタスクによる再訓練の実装、スレッドセーフティを考慮したプログラムなど、実践的な例を通じて説明しました。
気が向いたら以下の項目についても追記形式で記載したいと思います。
- キャッシング
- メトリクスの収集
- 詳細なバリデーションチェック
- 依存性注入(injectorライブラリ)