LoginSignup
13

More than 1 year has passed since last update.

PythonのFastAPIをLambdaで動かそうと思ったらSQLModelも使ってみたくなったので調べてみた(テストまあまあ盛り)

Last updated at Posted at 2022-01-16

はじめに

Pythonでデータ分析やスクレイピングや画像認識などをやっているのですが、Webアプリのバックエンドとして仕事で本格的に利用したいと思い、いろいろ調べてみると、AWSの
LambdaでFastAPIを動かすとの話題が多く目に留まり、ふむふむ、と進めていくと、SQLModelも一緒に利用したいと思い、半日ほど試行錯誤してみたのでまとめてみたいと思います。

ソースはgithubに登録しております。

SQLModelの概要

SQLModelは、直感的で使いやすく、互換性が高く、堅牢になるように設計されており、Pythonの型アノテーションに基づき、PydanticとSQLAlchemyを利用しています。
SQLModelの公式ページ

利用例

公式サイトのサンプルコードに出てくるHeroクラスとTeamクラスを対象とする利用例となります。

ファイル構成は以下のとおりです。

プロジェクトルート
├─sample
│  │─common_const.py
│  │─common_function.py
│  │─hero.py
│  │─operation_hero.py
│  │─operation_hero_team.py
│  │─team.py
├─tests
│  │─__init__.py
│  │─common_function_test.py
│  │─operation_hero_team_test.py
│  │─operation_hero_test.py
ファイル名 説明
common_const.py 共通定義的な情報を格納
common_function.py 共通処理的を含むモジュール
hero.py Heroテーブルに対応するHeroクラスのモジュール
operation_hero.py Heroテーブルの処理を含むモジュール
operation_hero_team.py HeroテーブルとTeamテーブル処理を含むモジュール
team.py Teamテーブルに対応するTeamクラスを含むモジュール
common_function.py common_function.pyのテストモジュール
operation_hero_team_test.py operation_hero.pyのテストモジュール
operation_hero_test.py operation_hero_team.pyのテストモジュール

エンティティクラス

hero.py
from typing import Optional
from sqlmodel import Field, SQLModel

class Hero(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str
    secret_name: str
    age: Optional[int] = None
    team_id: Optional[int] = Field(default=None, foreign_key="team.id")

idはORマッパーでよくある主キーで自動採番されるあれです。
team_idはteamテーブルのidをに関連する外部キーです。default=Noneなので必須(Not null制約付きではない)ではありません。

team.py
from typing import Optional
from sqlmodel import Field, SQLModel

class Team(SQLModel, table=True):
    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(index=True)
    headquarters: str

common_const.py

ヒーロー名など、各種処理で共通に利用する文字列を定義しています。

common_const.py
import sys


class _heroConst:
    class ConstError(TypeError):
        pass

    def __setattr__(self, name, value):
        if name in self.__dict__:
            raise self.ConstError("Can't rebind const (%s)" % name)
        self.__dict__[name] = value


sys.modules[__name__] = _heroConst()

_heroConst.HERO_NAME_DEADPOND = "Deadpond"
_heroConst.HERO_NAME_SPIDER_BOY = "Spider-Boy"
_heroConst.HERO_NAME_RUSTY_MAN = "Rusty-Man"


_heroConst.HERO_SECRET_NAME_DEADPOND = "Dive Wilson"
_heroConst.HERO_SECRET_NAME_SPIDER_BOY = "Pedro Parqueador"
_heroConst.HERO_SECRET_NAME_RUSTY_MAN = "Tommy Sharp"


_heroConst.TEAM_NAME_PREVENTERS = "Preventers"
_heroConst.TEAM_NAME_Z_FORCE = "Z-Force"

common_function.py

common_function.pyでは、delete_all:heroとteamの全削除、初期データの登録処理、セッションの作成処理を実装しています。
全体の内容は以下のとおりです。

common_const.py
from typing import List
from sample.hero import Hero
from sample.team import Team
import sample.common_const as HeroConst

from sqlmodel import SQLModel, Session, create_engine, delete
from logging import getLogger, StreamHandler, DEBUG
import traceback
import sqlalchemy
from sqlalchemy.future import Engine
from sqlalchemy.orm import sessionmaker

logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False


def delete_all() -> bool:
    engine = get_engine()

    with create_session(engine) as session:
        try:
            statement_hero = delete(Hero)
            delete_hero_result = session.exec(statement_hero)

            statement_team = delete(Team)
            delete_team_result = session.exec(statement_team)

            session.commit()
            logger.debug(f"delete_hero_result.rowcount={delete_hero_result.rowcount}")
            logger.debug(f"delete_team_result.rowcount={delete_team_result.rowcount}")
        except (sqlalchemy.exc.OperationalError, BaseException) as error:
            # テーブルが存在しないときにsqlalchemy.exc.OperationalErrorが発生するが問題問題なし
            if type(error) == sqlalchemy.exc.OperationalError:
                logger.error("Exception occurred")
                return True

            trace = traceback.format_exception_only(type(error), error)
            logger.debug(trace)
            return False

    return True


def init_data() -> List[Hero]:
    logger.debug("init_data start")

    engine = get_engine()

    SQLModel.metadata.create_all(engine)

    teams = generate_initdata_teams()
    heroes = generate_initdata_heroes()

    create_teams(teams, engine)

    # PreventersのidをDeadpondとSpider-Boyのteam_idセット
    heroes[0].team_id = teams[0].id
    heroes[1].team_id = teams[0].id
    # Z-ForceのidをRusty-Manのteam_idセット
    heroes[2].team_id = teams[1].id

    create_heroes(heroes, engine)

    logger.debug("init_data end")

    return heroes


def generate_initdata_teams() -> List[Team]:
    team_preventers = Team(
        name=HeroConst.TEAM_NAME_PREVENTERS, headquarters="Sharp Tower"
    )
    team_z_force = Team(
        name=HeroConst.TEAM_NAME_Z_FORCE, headquarters="Sister Margaret’s Bar"
    )
    return [team_preventers, team_z_force]


def generate_initdata_heroes() -> List[Hero]:
    hero_1 = Hero(
        name=HeroConst.HERO_NAME_DEADPOND,
        secret_name=HeroConst.HERO_SECRET_NAME_DEADPOND,
        age=30,
    )
    hero_2 = Hero(
        name=HeroConst.HERO_NAME_SPIDER_BOY,
        secret_name=HeroConst.HERO_SECRET_NAME_SPIDER_BOY,
    )
    hero_3 = Hero(
        name=HeroConst.HERO_NAME_RUSTY_MAN,
        secret_name=HeroConst.HERO_SECRET_NAME_RUSTY_MAN,
        age=48,
    )
    return [hero_1, hero_2, hero_3]


def create_teams(teams: List[Team], engine: Engine):
    logger.debug(f"create_teams start teams size={len(teams)}")

    with create_session(engine) as session:
        session.add_all(teams)
        session.commit()

        for team in teams:
            # refreshでなくても参照するだけで反映される。アクセスすると遅延ロードしてくれる
            session.refresh(team)
            logger.debug(f"team.id={team.id}")

    logger.debug("create_teams end")


def create_heroes(heros: List[Hero], engine: Engine):
    logger.debug(f"create_heroes start heros size={len(heros)}")

    with create_session(engine) as session:
        session.add_all(heros)
        session.commit()

    logger.debug("create_heroes end")


def get_engine() -> Engine:
    return create_engine("sqlite:///database.db", echo=True)


def create_session(engine: Engine) -> Engine:
    return Session(engine)

    """
    session_factory = sessionmaker(
        bind=engine, expire_on_commit=False, autocommit=False
    )

    return session_factory()
    """

セッションの作成処理

sqlite利用する部分はお決まりの処理ですので、説明は省略させていただきます。
セッションはsqlmodelのSessionのコンストラクタにEngineを渡せばOKです。

common_function.py

def get_engine() -> Engine:
    return create_engine("sqlite:///database.db", echo=True)


def create_session(engine: Engine) -> Engine:
    return Session(engine)

    """
    session_factory = sessionmaker(
        bind=engine, expire_on_commit=False, autocommit=False
    )

    return session_factory()
    """

コメント化しているのですが、sqlmodelはデフォルトだとcommitするとセッション内のインスタンスが全て期限切れになるので、関連付けられていたセッションから遅延ロードしようとして、sqlalchemy.orm.exc.DetachedInstanceErrorが発生して読めなくなります。

これを回避するために、sqlalchemyのsessionmakerでexpire_on_commit=Falseを指定してセッションを作成し、commitしてもインスタンスが利用できるようにする必要があるのですが、後程登場するSessionクラスの検索処理時のexecメソッドが呼び出せなくなりますので、create_sessionのテストでインスタンスの中身の検証を行うときのみコメント部分の方を有効にする必要があります。まあ、実際のプロダクションコードではセッションを使いまわすので、問題になる事はないです。

更新系のexecメソッドだと実行可能なのは謎ですが・・・

heroとteamの全削除

HeroテーブルとTeamテーブルの全データを削除しています。

common_function.py
def delete_all() -> bool:
    engine = get_engine()

    with create_session(engine) as session:
        try:
            statement_hero = delete(Hero)
            delete_hero_result = session.exec(statement_hero)

            statement_team = delete(Team)
            delete_team_result = session.exec(statement_team)

            session.commit()
            logger.debug(f"delete_hero_result.rowcount={delete_hero_result.rowcount}")
            logger.debug(f"delete_team_result.rowcount={delete_team_result.rowcount}")
        except (sqlalchemy.exc.OperationalError, BaseException) as error:
            # テーブルが存在しないときにsqlalchemy.exc.OperationalErrorが発生するが問題問題なし
            if type(error) == sqlalchemy.exc.OperationalError:
                logger.error("Exception occurred")
                return True

            trace = traceback.format_exception_only(type(error), error)
            logger.debug(trace)
            return False

    return True

delete(Hero)とエンティティを指定して得られた処理をsession.execに渡せばHeroテーブルの全データが削除されます。
statement_hero = delete(Hero).where(Hero.name == "HOGE")
のように条件を指定しての削除も可能です。whereで条件を指定する方法は、検索処理のところで説明させていただきます。

初期データの登録処理

common_function.py
def init_data() -> List[Hero]:
    logger.debug("init_data start")

    engine = get_engine()

    SQLModel.metadata.create_all(engine)

    teams = generate_initdata_teams()
    heroes = generate_initdata_heroes()

    create_teams(teams, engine)

    # PreventersのidをDeadpondとSpider-Boyのteam_idセット
    heroes[0].team_id = teams[0].id
    heroes[1].team_id = teams[0].id
    # Z-ForceのidをRusty-Manのteam_idセット
    heroes[2].team_id = teams[1].id

    create_heroes(heroes, engine)

    logger.debug("init_data end")

    return heroes

SQLModel.metadata.create_all(engine)を実行すれば、ロードしているエンティティに対応するテーブルが存在しない場合は、DBにテーブルを作成してくれます。

generate_initdata_teamsgenerate_initdata_heroesはテーブルに登録するエンティティのインスタンスを含むリストを返却する処理となります。

create_teamscreate_heroesは上述のエンティティのリストをDBに登録する処理となります。

create_teams実行後にteamsの各要素のidに値がセットされますので、
heroes[0].team_id = teams[0].idのように、Heroのteam_idにTeamのidをセットしています。

def generate_initdata_teams() -> List[Team]:
    team_preventers = Team(
        name=HeroConst.TEAM_NAME_PREVENTERS, headquarters="Sharp Tower"
    )
    team_z_force = Team(
        name=HeroConst.TEAM_NAME_Z_FORCE, headquarters="Sister Margaret’s Bar"
    )
    return [team_preventers, team_z_force]


def generate_initdata_heroes() -> List[Hero]:
    hero_1 = Hero(
        name=HeroConst.HERO_NAME_DEADPOND,
        secret_name=HeroConst.HERO_SECRET_NAME_DEADPOND,
        age=30,
    )
    hero_2 = Hero(
        name=HeroConst.HERO_NAME_SPIDER_BOY,
        secret_name=HeroConst.HERO_SECRET_NAME_SPIDER_BOY,
    )
    hero_3 = Hero(
        name=HeroConst.HERO_NAME_RUSTY_MAN,
        secret_name=HeroConst.HERO_SECRET_NAME_RUSTY_MAN,
        age=48,
    )
    return [hero_1, hero_2, hero_3]

ですので、hero_1とhero_2はteam_preventers(name=Preventers)、hero_3はteam_z_force(name=Z-Force)に属するようになります。

実際のデータ登録処理は以下のとおりです。

def create_teams(teams: List[Team], engine: Engine):
    logger.debug(f"create_teams start teams size={len(teams)}")

    with create_session(engine) as session:
        session.add_all(teams)
        session.commit()

        for team in teams:
            # refreshでなくても参照するだけで反映される。アクセスすると遅延ロードしてくれるので
            session.refresh(team)
            logger.debug(f"team.id={team.id}")

    logger.debug("create_teams end")


def create_heroes(heroes: List[Hero], engine: Engine):
    logger.debug(f"create_heroes start heros size={len(heroes)}")

    with create_session(engine) as session:
        session.add_all(heroes)
        session.commit()

    logger.debug("create_heroes end")

session.add_all(teams)のようにエンティティのインスタンスのリストを指定すると、一括でInsert可能です。

        for team in teams:
            session.add(teams)

のように単一インスタンスを指定してInsertすることも可能です。

common_function.pyのテスト

init_data_fixtureでデータを削除後に初期データ登録し、delete_allとinit_dataを呼び出すテストとなります。
検証ほぼないですが・・・

common_function_test.py
from multiprocessing.dummy.connection import Listener
import sys
import pytest
from logging import getLogger, StreamHandler, DEBUG

from sample.common_function import delete_all, init_data
import sample.common_const as HeroConst

logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False


@pytest.fixture
def init_data_fixture():
    # sqlite:///database.dbとローカルファイル指定なので全データ削除
    delete_all()
    # 前提データの作成
    init_data()


def test_delete_all(init_data_fixture):
    logger.debug(f"{sys._getframe().f_code.co_name} start")

    assert delete_all()

    logger.debug(f"{sys._getframe().f_code.co_name} end")


def test_init_data():
    logger.debug(f"{sys._getframe().f_code.co_name} start")

    heroes = init_data()
    assert len(heroes) == 3

    # 関連付けられていたセッションから遅延ロードしようとして、sqlalchemy.orm.exc.DetachedInstanceErrorが発生して読めない
    # デフォルトだとcommitするとセッション内のインスタンスが全て期限切れになるので
    # sqlalchemyのsessionmakerでexpire_on_commit=Falseを指定してセッションを作成しないと以下の検証は実行できません。

    """
    assert heroes[0].name == HeroConst.HERO_NAME_DEADPOND
    assert heroes[1].name == HeroConst.HERO_NAME_SPIDER_BOY
    assert heroes[2].name == HeroConst.HERO_NAME_RUSTY_MAN
    """

    logger.debug(f"{sys._getframe().f_code.co_name} end")

コメントにも記載しておりますが、init_dataが返却するList[Hero]の値の検証は、sessionmakerで生成したセッションでないと成功しません。lenだけは大丈夫なんですよね・・・

テストを実行したときのログ(一部)は以下のようになります。

--------------------------------------------------- Captured stdout setup ----------------------------------------------------
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine DELETE FROM team
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine [generated in 0.00010s] ()
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine COMMIT
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine PRAGMA main.table_info("hero")
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine [raw sql] ()
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine PRAGMA main.table_info("team")
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine [raw sql] ()
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine COMMIT
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine INSERT INTO team (name, headquarters) VALUES (?, ?)
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine [generated in 0.00012s] ('Preventers', 'Sharp Tower')
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine INSERT INTO team (name, headquarters) VALUES (?, ?)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [cached since 0.001709s ago] ('Z-Force', 'Sister Margaret’s Bar')
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine COMMIT
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine SELECT team.id, team.name, team.headquarters
FROM team
WHERE team.id = ?
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [generated in 0.00013s] (1,)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine SELECT team.id, team.name, team.headquarters
FROM team
WHERE team.id = ?
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [cached since 0.0009577s ago] (2,)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine ROLLBACK
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine INSERT INTO hero (name, secret_name, age, team_id) VALUES (?, ?, ?, ?)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [generated in 0.00012s] ('Deadpond', 'Dive Wilson', 30, 1)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine INSERT INTO hero (name, secret_name, age, team_id) VALUES (?, ?, ?, ?)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [cached since 0.001575s ago] ('Spider-Boy', 'Pedro Parqueador', None, 1)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine INSERT INTO hero (name, secret_name, age, team_id) VALUES (?, ?, ?, ?)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [cached since 0.001779s ago] ('Rusty-Man', 'Tommy Sharp', 48, 2)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine COMMIT
--------------------------------------------------- Captured stderr setup ----------------------------------------------------
delete_hero_result.rowcount=0
delete_team_result.rowcount=2
init_data start
create_teams start teams size=2
team.id=1
team.id=2
create_teams end
create_heroes start heros size=3
create_heroes end
init_data end

Heroテーブルの検索と更新処理

Heroテーブルの全件検索、name指定での検索、nameとsecret_name指定(AND)での検索、nameとsecret_name指定(OR)での検索、条件のageより大きい年齢のHeroの検索、nameを条件にageを更新
との処理を含みます。

operation_hero_test.py
from typing import List
from sample.hero import Hero
from sample.team import Team
from sample.common_function import get_engine, create_session

from sqlmodel import (
    select,
    or_,
)

from logging import exception, getLogger, StreamHandler, DEBUG

logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False


def select_all() -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        heroes = session.exec(select(Hero)).all()
        return heroes


def select_by_name(name: str) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero).where(Hero.name == name)
        return session.exec(statement).all()


def select_by_name_and_secret_name(name: str, secret_name: str) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = (
            select(Hero).where(Hero.name == name).where(Hero.secret_name == secret_name)
        )

        # 単一のwhereでも指定可能
        # statement = select(Hero).where(Hero.name == name, Hero.secret_name == secret_name)

        return session.exec(statement).all()


def select_by_name_or_secret_name(name: str, secret_name: str) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero).where(
            or_(Hero.name == name, Hero.secret_name == secret_name)
        )
        return session.exec(statement).all()


def select_by_age_above(age: int) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero).where(Hero.age > age)
        return session.exec(statement).all()


def update_age_by_name(name: str, age: int) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero).where(Hero.name == name)
        # 条件にあう先頭のレコードを取得
        hero = session.exec(statement).one()

        # heroインスタンスのageを更新してsession.addとsession.commitで更新
        hero.age = age
        session.add(hero)
        session.commit()

Heroテーブルの全件検索

select(Hero)で検索処理に必要なオブジェクトが返却されますので、これをsession.execに渡し、結果に対してallを呼び出します。

def select_all() -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        heroes = session.exec(select(Hero)).all()
        return heroes

name指定での検索

select(Hero).where(Hero.name == name)でHero.nameの検索条件が指定されたオブジェクトが返却されますので、session.execに渡し、結果に対してallを呼び出します。

def select_by_name(name: str) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero).where(Hero.name == name)
        return session.exec(statement).all()

nameとsecret_name指定(AND)での検索

whereを複数指定することでANDで条件が指定可能となります。
コメントでも記載済みですが、単一のwhere内で複数条件を指定することで同様の動作が実現可能です。

def select_by_name_and_secret_name(name: str, secret_name: str) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = (
            select(Hero).where(Hero.name == name).where(Hero.secret_name == secret_name)
        )

        # 単一のwhereでも指定可能
        # statement = select(Hero).where(Hero.name == name, Hero.secret_name == secret_name)

        return session.exec(statement).all()

nameとsecret_name指定(OR)での検索

whereにor_(Hero.name == name, Hero.secret_name == secret_name)を指定することで、Hero.nameとHero.secret_nameのORの条件で検索しています。

def select_by_name_or_secret_name(name: str, secret_name: str) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero).where(
            or_(Hero.name == name, Hero.secret_name == secret_name)
        )
        return session.exec(statement).all()

条件のageより大きい年齢のHeroの検索

素直な仕様ですね、説明は不要と思います。

def select_by_age_above(age: int) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero).where(Hero.age > age)
        return session.exec(statement).all()

nameを条件にageを更新

条件にあうレコードに対応するインスタンスを取得し、値を更新、session.add後にcommitすると更新できます。

def update_age_by_name(name: str, age: int) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero).where(Hero.name == name)
        # 条件にあう先頭のレコードを取得
        hero = session.exec(statement).one()

        # heroインスタンスのageを更新してsession.addとsession.commitで更新
        hero.age = age
        session.add(hero)
        session.commit()

Heroテーブルの検索と更新処理のテスト

sqlalchemy.orm.exc.DetachedInstanceErrorの関係で検証はlenしかしておりません。

operation_hero_test.py
from multiprocessing.dummy.connection import Listener
import sys
import pytest
from logging import getLogger, StreamHandler, DEBUG
from typing import List, Tuple
import sample.common_const as HeroConst

from sample.operation_hero import (
    select_all,
    select_by_name,
    select_by_name_and_secret_name,
    select_by_name_or_secret_name,
    select_by_age_above,
    update_age_by_name,
)
from sample.hero import Hero
from sample.common_function import delete_all, init_data

logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False


@pytest.fixture
def init_data_fixture():
    # sqlite:///database.dbとローカルファイル指定なので全データ削除
    delete_all()
    # 前提データの作成
    init_data()


@pytest.fixture(params=[(HeroConst.HERO_NAME_DEADPOND, 1), ("Hoge", 0)])
def test_select_by_name_fixture(request) -> Tuple[str, int]:
    request.getfixturevalue("init_data_fixture")
    return (request.param[0], request.param[1])


def test_select_by_name(test_select_by_name_fixture):
    logger.debug(f"{sys._getframe().f_code.co_name} start")

    name, expected_result = test_select_by_name_fixture

    result = select_by_name(name)
    output_name(result)

    assert len(result) == expected_result

    logger.debug(f"{sys._getframe().f_code.co_name} start")


@pytest.fixture(
    params=[
        (HeroConst.HERO_NAME_DEADPOND, HeroConst.HERO_SECRET_NAME_DEADPOND, 1),
        (HeroConst.HERO_NAME_DEADPOND, HeroConst.HERO_SECRET_NAME_SPIDER_BOY, 0),
        (HeroConst.HERO_NAME_SPIDER_BOY, HeroConst.HERO_SECRET_NAME_DEADPOND, 0),
    ]
)
def test_select_by_name_and_secret_name_fixture(request) -> Tuple[str, str, int]:
    request.getfixturevalue("init_data_fixture")
    return (request.param[0], request.param[1], request.param[2])


def test_select_by_name_and_secret_name(test_select_by_name_and_secret_name_fixture):
    logger.debug(f"{sys._getframe().f_code.co_name} start")

    (
        name,
        secret_name,
        expected_result,
    ) = test_select_by_name_and_secret_name_fixture

    result = select_by_name_and_secret_name(name, secret_name)
    output_name(result)

    assert len(result) == expected_result

    logger.debug(f"{sys._getframe().f_code.co_name} start")


@pytest.fixture(
    params=[
        (HeroConst.HERO_NAME_DEADPOND, HeroConst.HERO_SECRET_NAME_DEADPOND, 1),
        (HeroConst.HERO_NAME_DEADPOND, HeroConst.HERO_SECRET_NAME_SPIDER_BOY, 2),
        (HeroConst.HERO_NAME_SPIDER_BOY, HeroConst.HERO_SECRET_NAME_DEADPOND, 2),
        (HeroConst.HERO_NAME_SPIDER_BOY, "Hoge Hoge", 1),
        ("Hoge", HeroConst.HERO_SECRET_NAME_SPIDER_BOY, 1),
    ]
)
def test_select_by_name_or_secret_name_fixture(request) -> Tuple[str, str, int]:
    request.getfixturevalue("init_data_fixture")
    return (request.param[0], request.param[1], request.param[2])


def test_select_by_name_or_secret_name(test_select_by_name_or_secret_name_fixture):
    logger.debug(f"{sys._getframe().f_code.co_name} start")

    (
        name,
        secret_name,
        expected_result,
    ) = test_select_by_name_or_secret_name_fixture

    result = select_by_name_or_secret_name(name, secret_name)
    output_name(result)

    error_message = (
        f"name={name} secret_name={secret_name}で検索した結果の要素数の期待値は{expected_result}です。"
    )

    assert len(result) == expected_result, error_message

    logger.debug(f"{sys._getframe().f_code.co_name} start")


@pytest.fixture(params=[(29, 2), (30, 1), (47, 1), (48, 0)])
def test_select_by_age_above_fixture(request) -> Tuple[int, int]:
    request.getfixturevalue("init_data_fixture")
    return (request.param[0], request.param[1])


def test_select_by_age_above(test_select_by_age_above_fixture):
    logger.debug(f"{sys._getframe().f_code.co_name} start")

    (
        age,
        expected_result,
    ) = test_select_by_age_above_fixture

    result = select_by_age_above(age)
    output_name(result)

    assert len(result) == expected_result

    logger.debug(f"{sys._getframe().f_code.co_name} start")


def test_select_all(init_data_fixture):
    logger.debug("test_select_all start")

    result = select_all()
    output_name(result)

    assert len(result) == 3

    logger.debug("test_select_all end")


def test_update_age_by_name(init_data_fixture):
    logger.debug(f"{sys._getframe().f_code.co_name} start")

    target_name = "Rusty-Man"
    update_age_by_name(target_name, 50)

    heroes = select_by_name(target_name)

    assert len(heroes) == 1
    assert heroes[0].age == 50

    logger.debug(f"{sys._getframe().f_code.co_name} end")


def output_name(target_list: List[Hero]):
    for index, hero in enumerate(target_list):
        logger.debug(f"index={index} hero.name={hero.name}")

HeroテーブルとTeamテーブルをjoinする処理

Team.nameを条件にHeroを検索する処理を実装しています。

operation_hero_team.py
from typing import List
from sample.hero import Hero
from sample.team import Team
from sample.common_function import get_engine, create_session

from sqlmodel import Field, Session, SQLModel, create_engine, select, delete, or_
from logging import getLogger, StreamHandler, DEBUG
from sqlalchemy.orm.session import sessionmaker

logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False


def select_heroes_by_team_name(team_name: str) -> List[Hero]:
    engine = get_engine()

    with create_session(engine) as session:
        statement = select(Hero, Team).where(
            Hero.team_id == Team.id, Team.name == team_name
        )
        return session.exec(statement).all()

select(Hero, Team)で対象テーブルがHeroとTeamであることを指定、whereでHero.team_id == Team.idと指定された`team_nameでTeam.nameを絞り込むとの処理となります。

テスト実行時のログ(抜粋)は以下のようになります。

INFO     sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO     sqlalchemy.engine.Engine:log.py:117 SELECT hero.id, hero.name, hero.secret_name, hero.age, hero.team_id, team.id AS id_1, team.name AS name_1, team.headquarters
FROM hero, team
WHERE hero.team_id = team.id AND team.name = ?
INFO     sqlalchemy.engine.Engine:log.py:117 [no key 0.00015s] ('Preventers',)
INFO     sqlalchemy.engine.Engine:log.py:117 ROLLBACK

join句は利用せず、fromに複数テーブルを指定し、where句でHeroテーブルとTeamテーブルの結合条件とTeamのnameの条件指定を行っているSQLになっていることが分かります。

HeroテーブルとTeamテーブルをjoinする処理のテスト

operation_hero_team_test.py
from multiprocessing.dummy.connection import Listener
import sys
import pytest
from logging import getLogger, StreamHandler, DEBUG
import sample.common_const as HeroConst

from sample.operation_hero_team import (
    select_heroes_by_team_name,
)
from sample.hero import Hero
from sample.common_function import delete_all, init_data

logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False


@pytest.fixture
def init_data_fixture():
    delete_all()
    init_data()


def test_select_heroes_by_team_name(init_data_fixture):
    logger.debug(f"{sys._getframe().f_code.co_name} start")

    result = select_heroes_by_team_name(HeroConst.TEAM_NAME_PREVENTERS)
    assert len(result) == 2

    result = select_heroes_by_team_name(HeroConst.TEAM_NAME_Z_FORCE)
    assert len(result) == 1

    logger.debug(f"{sys._getframe().f_code.co_name} end")

まとめ

ざっくりと動作確認をした感想ですが、良い意味で、既視感の強い作りですので、多くの方がストレスなく利用できるプロダクトだと感じます。
今回は触りませんでしたが、async sessionにも対応しており、FastAPIと一緒に利用すると作業が捗りそうですね。

懸念材料ですが、バージョンが0.0.6であり、まだまだ足りない部分も多いですし、「Breaking Changes」が行われる可能性が高いと感じます。これからのプロダクトですので、より良い物になっていくことを期待してSQLModelを選択する方も多いと思いますので、今後も注目していきたいです。

ソースはgithubに登録しております。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13