「(機械学習のモデルをS3に保存し)そのメタデータをDB(MySQL)で管理するクラス」の実装です。JOINなどの処理は行なっていません。
from __future__ import annotations
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy import Column, Integer, String, Date
from sqlalchemy.ext.declarative import declarative_base
from pydantic import BaseModel, constr
from datetime import date
class DbClient:
"""MySQLへの問い合わせを管理する。"""
def __init__(self, *, host: str, user: str, passwd: str, db: str):
"""データベースへの接続情報を保持する。
Args:
host (str): ホスト名
user (str): DBのユーザー名
passwd (str): DBのパスワード
db (str): 利用するデータベース
"""
engine = create_engine(
f"mysql://{user}:{passwd}@{host}/{db}?charset=utf8",
echo=True,
encoding="utf-8",
)
self.session = scoped_session(
sessionmaker(autocommit=False, autoflush=False, bind=engine)
)
def store_model_info(
self, *, data_create_date: date, entry_market: str, cv_market: str
) -> MlModelModel:
"""データベースにモデル情報を保存する。
Args:
data_create_date (date): データの作成日
entry_market (str): 対象マーケット
cv_market (str): 成果対象
Returns:
MlModelModel: 追加したモデル情報
"""
orm = _MlModelModelsOrm()
orm.data_create_date = data_create_date
orm.entry_market = entry_market
orm.cv_market = cv_market
self.session.add(orm)
self.session.commit()
# commit時点でauto incrementのidも参照できるようになっており、それをpydanticオブジェクトに変換する
return MlModelModel.from_orm(orm)
def fetch_model_info_by_id(self, model_id: int) -> MlModelModel:
"""idをキーにしてモデル情報を取得する。
Args:
model_id (int): モデル名
Raises:
ValueError: 対象のidのデータが存在しなかった場合
Returns:
MlModelModel: モデル情報
"""
m = (
self.session.query(_MlModelModelsOrm)
.filter(_MlModelModelsOrm.id == model_id)
.first()
)
if m is None:
raise ValueError(f"model not found: id={model_id}")
return PfsModelModel.from_orm(m)
class MlModelModel(BaseModel):
id: int
data_create_date: date
entry_market: constr(max_length=255)
cv_market: constr(max_length=255)
class Config:
orm_mode = True
Base = declarative_base()
class _MlModelsOrm(Base):
__tablename__ = "ml_model_models"
id = Column(Integer, primary_key=True, nullable=False)
data_create_date = Column(Date)
entry_market = Column(String(255))
cv_market = Column(String(255))