LoginSignup
4
1

More than 1 year has passed since last update.

SQLAlchemyとpydanticを組み合わせる

Posted at

「(機械学習のモデルをS3に保存し)そのメタデータをDB(MySQL)で管理するクラス」の実装です。JOINなどの処理は行なっていません。

from __future__ import annotations
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy import Column, Integer, String, Date
from sqlalchemy.ext.declarative import declarative_base
from pydantic import BaseModel, constr
from datetime import date


class DbClient:
    """MySQLへの問い合わせを管理する。"""

    def __init__(self, *, host: str, user: str, passwd: str, db: str):
        """データベースへの接続情報を保持する。

        Args:
            host (str): ホスト名
            user (str): DBのユーザー名
            passwd (str): DBのパスワード
            db (str): 利用するデータベース
        """
        engine = create_engine(
            f"mysql://{user}:{passwd}@{host}/{db}?charset=utf8",
            echo=True,
            encoding="utf-8",
        )
        self.session = scoped_session(
            sessionmaker(autocommit=False, autoflush=False, bind=engine)
        )

    def store_model_info(
        self, *, data_create_date: date, entry_market: str, cv_market: str
    ) -> MlModelModel:
        """データベースにモデル情報を保存する。

        Args:
            data_create_date (date): データの作成日
            entry_market (str): 対象マーケット
            cv_market (str): 成果対象
        Returns:
            MlModelModel: 追加したモデル情報
        """
        orm = _MlModelModelsOrm()
        orm.data_create_date = data_create_date
        orm.entry_market = entry_market
        orm.cv_market = cv_market
        self.session.add(orm)
        self.session.commit()
        # commit時点でauto incrementのidも参照できるようになっており、それをpydanticオブジェクトに変換する
        return MlModelModel.from_orm(orm)

    def fetch_model_info_by_id(self, model_id: int) -> MlModelModel:
        """idをキーにしてモデル情報を取得する。

        Args:
            model_id (int): モデル名

        Raises:
            ValueError: 対象のidのデータが存在しなかった場合

        Returns:
            MlModelModel: モデル情報
        """
        m = (
            self.session.query(_MlModelModelsOrm)
            .filter(_MlModelModelsOrm.id == model_id)
            .first()
        )
        if m is None:
            raise ValueError(f"model not found: id={model_id}")
        return PfsModelModel.from_orm(m)


class MlModelModel(BaseModel):
    id: int
    data_create_date: date
    entry_market: constr(max_length=255)
    cv_market: constr(max_length=255)

    class Config:
        orm_mode = True


Base = declarative_base()


class _MlModelsOrm(Base):
    __tablename__ = "ml_model_models"
    id = Column(Integer, primary_key=True, nullable=False)
    data_create_date = Column(Date)
    entry_market = Column(String(255))
    cv_market = Column(String(255))

参考資料

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1