はじめに
Pythonでデータ分析やスクレイピングや画像認識などをやっているのですが、Webアプリのバックエンドとして仕事で本格的に利用したいと思い、いろいろ調べてみると、AWSの
LambdaでFastAPIを動かすとの話題が多く目に留まり、ふむふむ、と進めていくと、SQLModelも一緒に利用したいと思い、半日ほど試行錯誤してみたのでまとめてみたいと思います。
ソースはgithubに登録しております。
SQLModelの概要
SQLModelは、直感的で使いやすく、互換性が高く、堅牢になるように設計されており、Pythonの型アノテーションに基づき、PydanticとSQLAlchemyを利用しています。
SQLModelの公式ページ
利用例
公式サイトのサンプルコードに出てくるHeroクラスとTeamクラスを対象とする利用例となります。
ファイル構成は以下のとおりです。
プロジェクトルート
├─sample
│ │─common_const.py
│ │─common_function.py
│ │─hero.py
│ │─operation_hero.py
│ │─operation_hero_team.py
│ │─team.py
├─tests
│ │─__init__.py
│ │─common_function_test.py
│ │─operation_hero_team_test.py
│ │─operation_hero_test.py
ファイル名 | 説明 |
---|---|
common_const.py | 共通定義的な情報を格納 |
common_function.py | 共通処理的を含むモジュール |
hero.py | Heroテーブルに対応するHeroクラスのモジュール |
operation_hero.py | Heroテーブルの処理を含むモジュール |
operation_hero_team.py | HeroテーブルとTeamテーブル処理を含むモジュール |
team.py | Teamテーブルに対応するTeamクラスを含むモジュール |
common_function.py | common_function.pyのテストモジュール |
operation_hero_team_test.py | operation_hero.pyのテストモジュール |
operation_hero_test.py | operation_hero_team.pyのテストモジュール |
エンティティクラス
from typing import Optional
from sqlmodel import Field, SQLModel
class Hero(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str
secret_name: str
age: Optional[int] = None
team_id: Optional[int] = Field(default=None, foreign_key="team.id")
id
はORマッパーでよくある主キーで自動採番されるあれです。
team_id
はteamテーブルのidをに関連する外部キーです。default=None
なので必須(Not null制約付きではない)ではありません。
from typing import Optional
from sqlmodel import Field, SQLModel
class Team(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(index=True)
headquarters: str
common_const.py
ヒーロー名など、各種処理で共通に利用する文字列を定義しています。
import sys
class _heroConst:
class ConstError(TypeError):
pass
def __setattr__(self, name, value):
if name in self.__dict__:
raise self.ConstError("Can't rebind const (%s)" % name)
self.__dict__[name] = value
sys.modules[__name__] = _heroConst()
_heroConst.HERO_NAME_DEADPOND = "Deadpond"
_heroConst.HERO_NAME_SPIDER_BOY = "Spider-Boy"
_heroConst.HERO_NAME_RUSTY_MAN = "Rusty-Man"
_heroConst.HERO_SECRET_NAME_DEADPOND = "Dive Wilson"
_heroConst.HERO_SECRET_NAME_SPIDER_BOY = "Pedro Parqueador"
_heroConst.HERO_SECRET_NAME_RUSTY_MAN = "Tommy Sharp"
_heroConst.TEAM_NAME_PREVENTERS = "Preventers"
_heroConst.TEAM_NAME_Z_FORCE = "Z-Force"
common_function.py
common_function.pyでは、delete_all:heroとteamの全削除、初期データの登録処理、セッションの作成処理を実装しています。
全体の内容は以下のとおりです。
from typing import List
from sample.hero import Hero
from sample.team import Team
import sample.common_const as HeroConst
from sqlmodel import SQLModel, Session, create_engine, delete
from logging import getLogger, StreamHandler, DEBUG
import traceback
import sqlalchemy
from sqlalchemy.future import Engine
from sqlalchemy.orm import sessionmaker
logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False
def delete_all() -> bool:
engine = get_engine()
with create_session(engine) as session:
try:
statement_hero = delete(Hero)
delete_hero_result = session.exec(statement_hero)
statement_team = delete(Team)
delete_team_result = session.exec(statement_team)
session.commit()
logger.debug(f"delete_hero_result.rowcount={delete_hero_result.rowcount}")
logger.debug(f"delete_team_result.rowcount={delete_team_result.rowcount}")
except (sqlalchemy.exc.OperationalError, BaseException) as error:
# テーブルが存在しないときにsqlalchemy.exc.OperationalErrorが発生するが問題問題なし
if type(error) == sqlalchemy.exc.OperationalError:
logger.error("Exception occurred")
return True
trace = traceback.format_exception_only(type(error), error)
logger.debug(trace)
return False
return True
def init_data() -> List[Hero]:
logger.debug("init_data start")
engine = get_engine()
SQLModel.metadata.create_all(engine)
teams = generate_initdata_teams()
heroes = generate_initdata_heroes()
create_teams(teams, engine)
# PreventersのidをDeadpondとSpider-Boyのteam_idセット
heroes[0].team_id = teams[0].id
heroes[1].team_id = teams[0].id
# Z-ForceのidをRusty-Manのteam_idセット
heroes[2].team_id = teams[1].id
create_heroes(heroes, engine)
logger.debug("init_data end")
return heroes
def generate_initdata_teams() -> List[Team]:
team_preventers = Team(
name=HeroConst.TEAM_NAME_PREVENTERS, headquarters="Sharp Tower"
)
team_z_force = Team(
name=HeroConst.TEAM_NAME_Z_FORCE, headquarters="Sister Margaret’s Bar"
)
return [team_preventers, team_z_force]
def generate_initdata_heroes() -> List[Hero]:
hero_1 = Hero(
name=HeroConst.HERO_NAME_DEADPOND,
secret_name=HeroConst.HERO_SECRET_NAME_DEADPOND,
age=30,
)
hero_2 = Hero(
name=HeroConst.HERO_NAME_SPIDER_BOY,
secret_name=HeroConst.HERO_SECRET_NAME_SPIDER_BOY,
)
hero_3 = Hero(
name=HeroConst.HERO_NAME_RUSTY_MAN,
secret_name=HeroConst.HERO_SECRET_NAME_RUSTY_MAN,
age=48,
)
return [hero_1, hero_2, hero_3]
def create_teams(teams: List[Team], engine: Engine):
logger.debug(f"create_teams start teams size={len(teams)}")
with create_session(engine) as session:
session.add_all(teams)
session.commit()
for team in teams:
# refreshでなくても参照するだけで反映される。アクセスすると遅延ロードしてくれる
session.refresh(team)
logger.debug(f"team.id={team.id}")
logger.debug("create_teams end")
def create_heroes(heros: List[Hero], engine: Engine):
logger.debug(f"create_heroes start heros size={len(heros)}")
with create_session(engine) as session:
session.add_all(heros)
session.commit()
logger.debug("create_heroes end")
def get_engine() -> Engine:
return create_engine("sqlite:///database.db", echo=True)
def create_session(engine: Engine) -> Engine:
return Session(engine)
"""
session_factory = sessionmaker(
bind=engine, expire_on_commit=False, autocommit=False
)
return session_factory()
"""
セッションの作成処理
sqlite利用する部分はお決まりの処理ですので、説明は省略させていただきます。
セッションはsqlmodelのSessionのコンストラクタにEngineを渡せばOKです。
def get_engine() -> Engine:
return create_engine("sqlite:///database.db", echo=True)
def create_session(engine: Engine) -> Engine:
return Session(engine)
"""
session_factory = sessionmaker(
bind=engine, expire_on_commit=False, autocommit=False
)
return session_factory()
"""
コメント化しているのですが、sqlmodelはデフォルトだとcommitするとセッション内のインスタンスが全て期限切れになるので、関連付けられていたセッションから遅延ロードしようとして、sqlalchemy.orm.exc.DetachedInstanceErrorが発生して読めなくなります。
これを回避するために、sqlalchemyのsessionmakerでexpire_on_commit=Falseを指定してセッションを作成し、commitしてもインスタンスが利用できるようにする必要があるのですが、後程登場するSessionクラスの検索処理時のexecメソッドが呼び出せなくなりますので、create_sessionのテストでインスタンスの中身の検証を行うときのみコメント部分の方を有効にする必要があります。まあ、実際のプロダクションコードではセッションを使いまわすので、問題になる事はないです。
更新系のexecメソッドだと実行可能なのは謎ですが・・・
heroとteamの全削除
HeroテーブルとTeamテーブルの全データを削除しています。
def delete_all() -> bool:
engine = get_engine()
with create_session(engine) as session:
try:
statement_hero = delete(Hero)
delete_hero_result = session.exec(statement_hero)
statement_team = delete(Team)
delete_team_result = session.exec(statement_team)
session.commit()
logger.debug(f"delete_hero_result.rowcount={delete_hero_result.rowcount}")
logger.debug(f"delete_team_result.rowcount={delete_team_result.rowcount}")
except (sqlalchemy.exc.OperationalError, BaseException) as error:
# テーブルが存在しないときにsqlalchemy.exc.OperationalErrorが発生するが問題問題なし
if type(error) == sqlalchemy.exc.OperationalError:
logger.error("Exception occurred")
return True
trace = traceback.format_exception_only(type(error), error)
logger.debug(trace)
return False
return True
delete(Hero)
とエンティティを指定して得られた処理をsession.exec
に渡せばHeroテーブルの全データが削除されます。
statement_hero = delete(Hero).where(Hero.name == "HOGE")
のように条件を指定しての削除も可能です。whereで条件を指定する方法は、検索処理のところで説明させていただきます。
初期データの登録処理
def init_data() -> List[Hero]:
logger.debug("init_data start")
engine = get_engine()
SQLModel.metadata.create_all(engine)
teams = generate_initdata_teams()
heroes = generate_initdata_heroes()
create_teams(teams, engine)
# PreventersのidをDeadpondとSpider-Boyのteam_idセット
heroes[0].team_id = teams[0].id
heroes[1].team_id = teams[0].id
# Z-ForceのidをRusty-Manのteam_idセット
heroes[2].team_id = teams[1].id
create_heroes(heroes, engine)
logger.debug("init_data end")
return heroes
SQLModel.metadata.create_all(engine)
を実行すれば、ロードしているエンティティに対応するテーブルが存在しない場合は、DBにテーブルを作成してくれます。
generate_initdata_teams
とgenerate_initdata_heroes
はテーブルに登録するエンティティのインスタンスを含むリストを返却する処理となります。
create_teams
とcreate_heroes
は上述のエンティティのリストをDBに登録する処理となります。
create_teams
実行後にteamsの各要素のidに値がセットされますので、
heroes[0].team_id = teams[0].id
のように、Heroのteam_idにTeamのidをセットしています。
def generate_initdata_teams() -> List[Team]:
team_preventers = Team(
name=HeroConst.TEAM_NAME_PREVENTERS, headquarters="Sharp Tower"
)
team_z_force = Team(
name=HeroConst.TEAM_NAME_Z_FORCE, headquarters="Sister Margaret’s Bar"
)
return [team_preventers, team_z_force]
def generate_initdata_heroes() -> List[Hero]:
hero_1 = Hero(
name=HeroConst.HERO_NAME_DEADPOND,
secret_name=HeroConst.HERO_SECRET_NAME_DEADPOND,
age=30,
)
hero_2 = Hero(
name=HeroConst.HERO_NAME_SPIDER_BOY,
secret_name=HeroConst.HERO_SECRET_NAME_SPIDER_BOY,
)
hero_3 = Hero(
name=HeroConst.HERO_NAME_RUSTY_MAN,
secret_name=HeroConst.HERO_SECRET_NAME_RUSTY_MAN,
age=48,
)
return [hero_1, hero_2, hero_3]
ですので、hero_1とhero_2はteam_preventers(name=Preventers)、hero_3はteam_z_force(name=Z-Force)に属するようになります。
実際のデータ登録処理は以下のとおりです。
def create_teams(teams: List[Team], engine: Engine):
logger.debug(f"create_teams start teams size={len(teams)}")
with create_session(engine) as session:
session.add_all(teams)
session.commit()
for team in teams:
# refreshでなくても参照するだけで反映される。アクセスすると遅延ロードしてくれるので
session.refresh(team)
logger.debug(f"team.id={team.id}")
logger.debug("create_teams end")
def create_heroes(heroes: List[Hero], engine: Engine):
logger.debug(f"create_heroes start heros size={len(heroes)}")
with create_session(engine) as session:
session.add_all(heroes)
session.commit()
logger.debug("create_heroes end")
session.add_all(teams)
のようにエンティティのインスタンスのリストを指定すると、一括でInsert可能です。
for team in teams:
session.add(teams)
のように単一インスタンスを指定してInsertすることも可能です。
common_function.pyのテスト
init_data_fixtureでデータを削除後に初期データ登録し、delete_allとinit_dataを呼び出すテストとなります。
検証ほぼないですが・・・
from multiprocessing.dummy.connection import Listener
import sys
import pytest
from logging import getLogger, StreamHandler, DEBUG
from sample.common_function import delete_all, init_data
import sample.common_const as HeroConst
logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False
@pytest.fixture
def init_data_fixture():
# sqlite:///database.dbとローカルファイル指定なので全データ削除
delete_all()
# 前提データの作成
init_data()
def test_delete_all(init_data_fixture):
logger.debug(f"{sys._getframe().f_code.co_name} start")
assert delete_all()
logger.debug(f"{sys._getframe().f_code.co_name} end")
def test_init_data():
logger.debug(f"{sys._getframe().f_code.co_name} start")
heroes = init_data()
assert len(heroes) == 3
# 関連付けられていたセッションから遅延ロードしようとして、sqlalchemy.orm.exc.DetachedInstanceErrorが発生して読めない
# デフォルトだとcommitするとセッション内のインスタンスが全て期限切れになるので
# sqlalchemyのsessionmakerでexpire_on_commit=Falseを指定してセッションを作成しないと以下の検証は実行できません。
"""
assert heroes[0].name == HeroConst.HERO_NAME_DEADPOND
assert heroes[1].name == HeroConst.HERO_NAME_SPIDER_BOY
assert heroes[2].name == HeroConst.HERO_NAME_RUSTY_MAN
"""
logger.debug(f"{sys._getframe().f_code.co_name} end")
コメントにも記載しておりますが、init_dataが返却するList[Hero]の値の検証は、sessionmakerで生成したセッションでないと成功しません。lenだけは大丈夫なんですよね・・・
テストを実行したときのログ(一部)は以下のようになります。
--------------------------------------------------- Captured stdout setup ----------------------------------------------------
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine DELETE FROM team
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine [generated in 0.00010s] ()
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine COMMIT
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine PRAGMA main.table_info("hero")
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine [raw sql] ()
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine PRAGMA main.table_info("team")
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine [raw sql] ()
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine COMMIT
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine INSERT INTO team (name, headquarters) VALUES (?, ?)
2022-01-16 17:40:13,595 INFO sqlalchemy.engine.Engine [generated in 0.00012s] ('Preventers', 'Sharp Tower')
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine INSERT INTO team (name, headquarters) VALUES (?, ?)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [cached since 0.001709s ago] ('Z-Force', 'Sister Margaret’s Bar')
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine COMMIT
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine SELECT team.id, team.name, team.headquarters
FROM team
WHERE team.id = ?
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [generated in 0.00013s] (1,)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine SELECT team.id, team.name, team.headquarters
FROM team
WHERE team.id = ?
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [cached since 0.0009577s ago] (2,)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine ROLLBACK
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine INSERT INTO hero (name, secret_name, age, team_id) VALUES (?, ?, ?, ?)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [generated in 0.00012s] ('Deadpond', 'Dive Wilson', 30, 1)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine INSERT INTO hero (name, secret_name, age, team_id) VALUES (?, ?, ?, ?)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [cached since 0.001575s ago] ('Spider-Boy', 'Pedro Parqueador', None, 1)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine INSERT INTO hero (name, secret_name, age, team_id) VALUES (?, ?, ?, ?)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine [cached since 0.001779s ago] ('Rusty-Man', 'Tommy Sharp', 48, 2)
2022-01-16 17:40:13,611 INFO sqlalchemy.engine.Engine COMMIT
--------------------------------------------------- Captured stderr setup ----------------------------------------------------
delete_hero_result.rowcount=0
delete_team_result.rowcount=2
init_data start
create_teams start teams size=2
team.id=1
team.id=2
create_teams end
create_heroes start heros size=3
create_heroes end
init_data end
Heroテーブルの検索と更新処理
Heroテーブルの全件検索、name指定での検索、nameとsecret_name指定(AND)での検索、nameとsecret_name指定(OR)での検索、条件のageより大きい年齢のHeroの検索、nameを条件にageを更新
との処理を含みます。
from typing import List
from sample.hero import Hero
from sample.team import Team
from sample.common_function import get_engine, create_session
from sqlmodel import (
select,
or_,
)
from logging import exception, getLogger, StreamHandler, DEBUG
logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False
def select_all() -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
heroes = session.exec(select(Hero)).all()
return heroes
def select_by_name(name: str) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero).where(Hero.name == name)
return session.exec(statement).all()
def select_by_name_and_secret_name(name: str, secret_name: str) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = (
select(Hero).where(Hero.name == name).where(Hero.secret_name == secret_name)
)
# 単一のwhereでも指定可能
# statement = select(Hero).where(Hero.name == name, Hero.secret_name == secret_name)
return session.exec(statement).all()
def select_by_name_or_secret_name(name: str, secret_name: str) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero).where(
or_(Hero.name == name, Hero.secret_name == secret_name)
)
return session.exec(statement).all()
def select_by_age_above(age: int) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero).where(Hero.age > age)
return session.exec(statement).all()
def update_age_by_name(name: str, age: int) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero).where(Hero.name == name)
# 条件にあう先頭のレコードを取得
hero = session.exec(statement).one()
# heroインスタンスのageを更新してsession.addとsession.commitで更新
hero.age = age
session.add(hero)
session.commit()
Heroテーブルの全件検索
select(Hero)
で検索処理に必要なオブジェクトが返却されますので、これをsession.execに渡し、結果に対してallを呼び出します。
def select_all() -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
heroes = session.exec(select(Hero)).all()
return heroes
name指定での検索
select(Hero).where(Hero.name == name)
でHero.nameの検索条件が指定されたオブジェクトが返却されますので、session.execに渡し、結果に対してallを呼び出します。
def select_by_name(name: str) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero).where(Hero.name == name)
return session.exec(statement).all()
nameとsecret_name指定(AND)での検索
whereを複数指定することでANDで条件が指定可能となります。
コメントでも記載済みですが、単一のwhere内で複数条件を指定することで同様の動作が実現可能です。
def select_by_name_and_secret_name(name: str, secret_name: str) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = (
select(Hero).where(Hero.name == name).where(Hero.secret_name == secret_name)
)
# 単一のwhereでも指定可能
# statement = select(Hero).where(Hero.name == name, Hero.secret_name == secret_name)
return session.exec(statement).all()
nameとsecret_name指定(OR)での検索
whereにor_(Hero.name == name, Hero.secret_name == secret_name)を指定することで、Hero.nameとHero.secret_nameのORの条件で検索しています。
def select_by_name_or_secret_name(name: str, secret_name: str) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero).where(
or_(Hero.name == name, Hero.secret_name == secret_name)
)
return session.exec(statement).all()
条件のageより大きい年齢のHeroの検索
素直な仕様ですね、説明は不要と思います。
def select_by_age_above(age: int) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero).where(Hero.age > age)
return session.exec(statement).all()
nameを条件にageを更新
条件にあうレコードに対応するインスタンスを取得し、値を更新、session.add後にcommitすると更新できます。
def update_age_by_name(name: str, age: int) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero).where(Hero.name == name)
# 条件にあう先頭のレコードを取得
hero = session.exec(statement).one()
# heroインスタンスのageを更新してsession.addとsession.commitで更新
hero.age = age
session.add(hero)
session.commit()
Heroテーブルの検索と更新処理のテスト
sqlalchemy.orm.exc.DetachedInstanceErrorの関係で検証はlenしかしておりません。
from multiprocessing.dummy.connection import Listener
import sys
import pytest
from logging import getLogger, StreamHandler, DEBUG
from typing import List, Tuple
import sample.common_const as HeroConst
from sample.operation_hero import (
select_all,
select_by_name,
select_by_name_and_secret_name,
select_by_name_or_secret_name,
select_by_age_above,
update_age_by_name,
)
from sample.hero import Hero
from sample.common_function import delete_all, init_data
logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False
@pytest.fixture
def init_data_fixture():
# sqlite:///database.dbとローカルファイル指定なので全データ削除
delete_all()
# 前提データの作成
init_data()
@pytest.fixture(params=[(HeroConst.HERO_NAME_DEADPOND, 1), ("Hoge", 0)])
def test_select_by_name_fixture(request) -> Tuple[str, int]:
request.getfixturevalue("init_data_fixture")
return (request.param[0], request.param[1])
def test_select_by_name(test_select_by_name_fixture):
logger.debug(f"{sys._getframe().f_code.co_name} start")
name, expected_result = test_select_by_name_fixture
result = select_by_name(name)
output_name(result)
assert len(result) == expected_result
logger.debug(f"{sys._getframe().f_code.co_name} start")
@pytest.fixture(
params=[
(HeroConst.HERO_NAME_DEADPOND, HeroConst.HERO_SECRET_NAME_DEADPOND, 1),
(HeroConst.HERO_NAME_DEADPOND, HeroConst.HERO_SECRET_NAME_SPIDER_BOY, 0),
(HeroConst.HERO_NAME_SPIDER_BOY, HeroConst.HERO_SECRET_NAME_DEADPOND, 0),
]
)
def test_select_by_name_and_secret_name_fixture(request) -> Tuple[str, str, int]:
request.getfixturevalue("init_data_fixture")
return (request.param[0], request.param[1], request.param[2])
def test_select_by_name_and_secret_name(test_select_by_name_and_secret_name_fixture):
logger.debug(f"{sys._getframe().f_code.co_name} start")
(
name,
secret_name,
expected_result,
) = test_select_by_name_and_secret_name_fixture
result = select_by_name_and_secret_name(name, secret_name)
output_name(result)
assert len(result) == expected_result
logger.debug(f"{sys._getframe().f_code.co_name} start")
@pytest.fixture(
params=[
(HeroConst.HERO_NAME_DEADPOND, HeroConst.HERO_SECRET_NAME_DEADPOND, 1),
(HeroConst.HERO_NAME_DEADPOND, HeroConst.HERO_SECRET_NAME_SPIDER_BOY, 2),
(HeroConst.HERO_NAME_SPIDER_BOY, HeroConst.HERO_SECRET_NAME_DEADPOND, 2),
(HeroConst.HERO_NAME_SPIDER_BOY, "Hoge Hoge", 1),
("Hoge", HeroConst.HERO_SECRET_NAME_SPIDER_BOY, 1),
]
)
def test_select_by_name_or_secret_name_fixture(request) -> Tuple[str, str, int]:
request.getfixturevalue("init_data_fixture")
return (request.param[0], request.param[1], request.param[2])
def test_select_by_name_or_secret_name(test_select_by_name_or_secret_name_fixture):
logger.debug(f"{sys._getframe().f_code.co_name} start")
(
name,
secret_name,
expected_result,
) = test_select_by_name_or_secret_name_fixture
result = select_by_name_or_secret_name(name, secret_name)
output_name(result)
error_message = (
f"name={name} secret_name={secret_name}で検索した結果の要素数の期待値は{expected_result}です。"
)
assert len(result) == expected_result, error_message
logger.debug(f"{sys._getframe().f_code.co_name} start")
@pytest.fixture(params=[(29, 2), (30, 1), (47, 1), (48, 0)])
def test_select_by_age_above_fixture(request) -> Tuple[int, int]:
request.getfixturevalue("init_data_fixture")
return (request.param[0], request.param[1])
def test_select_by_age_above(test_select_by_age_above_fixture):
logger.debug(f"{sys._getframe().f_code.co_name} start")
(
age,
expected_result,
) = test_select_by_age_above_fixture
result = select_by_age_above(age)
output_name(result)
assert len(result) == expected_result
logger.debug(f"{sys._getframe().f_code.co_name} start")
def test_select_all(init_data_fixture):
logger.debug("test_select_all start")
result = select_all()
output_name(result)
assert len(result) == 3
logger.debug("test_select_all end")
def test_update_age_by_name(init_data_fixture):
logger.debug(f"{sys._getframe().f_code.co_name} start")
target_name = "Rusty-Man"
update_age_by_name(target_name, 50)
heroes = select_by_name(target_name)
assert len(heroes) == 1
assert heroes[0].age == 50
logger.debug(f"{sys._getframe().f_code.co_name} end")
def output_name(target_list: List[Hero]):
for index, hero in enumerate(target_list):
logger.debug(f"index={index} hero.name={hero.name}")
HeroテーブルとTeamテーブルをjoinする処理
Team.nameを条件にHeroを検索する処理を実装しています。
from typing import List
from sample.hero import Hero
from sample.team import Team
from sample.common_function import get_engine, create_session
from sqlmodel import Field, Session, SQLModel, create_engine, select, delete, or_
from logging import getLogger, StreamHandler, DEBUG
from sqlalchemy.orm.session import sessionmaker
logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False
def select_heroes_by_team_name(team_name: str) -> List[Hero]:
engine = get_engine()
with create_session(engine) as session:
statement = select(Hero, Team).where(
Hero.team_id == Team.id, Team.name == team_name
)
return session.exec(statement).all()
select(Hero, Team)
で対象テーブルがHeroとTeamであることを指定、whereでHero.team_id == Team.id
と指定された``team_name`でTeam.nameを絞り込むとの処理となります。
テスト実行時のログ(抜粋)は以下のようになります。
INFO sqlalchemy.engine.Engine:log.py:117 BEGIN (implicit)
INFO sqlalchemy.engine.Engine:log.py:117 SELECT hero.id, hero.name, hero.secret_name, hero.age, hero.team_id, team.id AS id_1, team.name AS name_1, team.headquarters
FROM hero, team
WHERE hero.team_id = team.id AND team.name = ?
INFO sqlalchemy.engine.Engine:log.py:117 [no key 0.00015s] ('Preventers',)
INFO sqlalchemy.engine.Engine:log.py:117 ROLLBACK
join句は利用せず、fromに複数テーブルを指定し、where句でHeroテーブルとTeamテーブルの結合条件とTeamのnameの条件指定を行っているSQLになっていることが分かります。
HeroテーブルとTeamテーブルをjoinする処理のテスト
from multiprocessing.dummy.connection import Listener
import sys
import pytest
from logging import getLogger, StreamHandler, DEBUG
import sample.common_const as HeroConst
from sample.operation_hero_team import (
select_heroes_by_team_name,
)
from sample.hero import Hero
from sample.common_function import delete_all, init_data
logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
logger.addHandler(handler)
logger.propagate = False
@pytest.fixture
def init_data_fixture():
delete_all()
init_data()
def test_select_heroes_by_team_name(init_data_fixture):
logger.debug(f"{sys._getframe().f_code.co_name} start")
result = select_heroes_by_team_name(HeroConst.TEAM_NAME_PREVENTERS)
assert len(result) == 2
result = select_heroes_by_team_name(HeroConst.TEAM_NAME_Z_FORCE)
assert len(result) == 1
logger.debug(f"{sys._getframe().f_code.co_name} end")
まとめ
ざっくりと動作確認をした感想ですが、良い意味で、既視感の強い作りですので、多くの方がストレスなく利用できるプロダクトだと感じます。
今回は触りませんでしたが、async sessionにも対応しており、FastAPIと一緒に利用すると作業が捗りそうですね。
懸念材料ですが、バージョンが0.0.6であり、まだまだ足りない部分も多いですし、「Breaking Changes」が行われる可能性が高いと感じます。これからのプロダクトですので、より良い物になっていくことを期待してSQLModelを選択する方も多いと思いますので、今後も注目していきたいです。
ソースはgithubに登録しております。