概要
PythonでSQLAlchemyを使用してデータベースに接続し、データを取得する方法を備忘も兼ねて紹介します。
流れとコード
①DB接続の準備
②CRUD操作の実装
③データの取得
①DB接続の準備
まずはデータベースエンジンとセッションを作成します。
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
engine = create_engine("mysql+pymysql://user:pw@host:port/dbname", pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
create_engine
関数は、データベースエンジンを作成するための関数、
sessionmaker
関数は、データベースセッションを作成するための関数です。
create_engine
の引数にあるpool_pre_ping=True
は、接続が有効かどうかを確認するための「ping
」操作を有効にするもの。これにより、接続が切れている場合に自動的に再接続を試みます。
If the ping / error check determines that the connection is not usable, the connection will be immediately recycled, and all other pooled connections older than the current time are invalidated, so that the next time they are checked out, they will also be recycled before use.
https://docs.sqlalchemy.org/en/20/core/pooling.html
コネクションプールについては、いくつかアプローチがあると思いますので、適宜要件に合わせて検討ください(pool_recycle
というのもある)。
次に、データベースセッションのコンテキストマネージャを使用して「セッションを呼び出し元に返す〜呼び出し元のコードが実行される〜呼び出し元のコードが終了した後にセッションを閉じる」を以下で実施します。
from typing import Generator
from contextlib import contextmanager
from .session import SessionLocal
@contextmanager
def get_db() -> Generator:
session = SessionLocal()
try:
yield session
finally:
session.close()
ここでyield
を使っている理由は以下に記載ある通り、リソースリークを防止するためです。yield
を使用することで、セッションを一時的に呼び出し元に返し、呼び出し元のコードが終了した後にセッションを確実に閉じることができるためです。
②モデルとCRUD操作の実装
次に、country
モデルからcountry_code
の値を利用してcountry_name
を取得する関数を実装します。CRUD操作のベースクラスを利用しています。
from pydantic import BaseModel
from sqlalchemy.orm import Session
from cruds.base import CRUDBase
from models.country import Country
class CountryTable(CRUDBase[Country, BaseModel, BaseModel]):
def get_country_name(self, db: Session, country_code):
country = db.query(self.model).filter(Country.country_code == country_code).first()
if country is None:
raise ValueError("Country not found")
return country.country_name
country = CountryTable(Country)
③データの取得
最後に、上記で実装した関数を呼び出してデータを取得するコードです。
from sqlalchemy.orm import Session
import cruds
from dependency import get_db
country_code = "100"
with get_db() as db:
country_name = cruds.country.get_country_name(db, country_code)
print(country_name)
with get_db() as db
でget_db
関数を使用して上述のデータベースセッションを取得しています。
このセッションを使用して、get_country_name
メソッドを呼び出し、指定されたcountry_code
に対応するcountry_name
を取得することができました。