9
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

FastAPI と Strawberry で作る GraphQL Server ~SQLAlchemy を添えて~

Last updated at Posted at 2024-10-16

※ 注: こちらの記事は PyCon APAC 2024 に向けて、スライド作成の下書き的な立ち位置で書いている内容です。現時点では、適宜修正している可能性があります。

はじめに

こんにちは。株式会社PORTAMENTのCTOをやっているzoetaka38(川添)です。
弊社の OpsGuide(オプスガイド)というサービスでは、Python で GraphQL サーバーを構築しています。

今回は、弊社のこのサービスで GraphQL サーバーの構築に使っている Strawberry というフレームワークについて、実践的な利用法と一緒に紹介させていただきたいと思います。

コード

本記事のコードは下記にあります。参考にしてください。
https://github.com/zoetaka38/poc-graphql-react-fastapi/tree/099ffa653149f8721dec3215645fa36aa8dae52a

Strawberryとは?

2024年10月16日時点での、公式での紹介文は以下です。

image.png

URL: https://strawberry.rocks/

型ヒントや、非同期処理のサポート、バッチクエリなど、GraphQLサーバーを構築する上で、特に FastAPI と組み合わせる上で非常に優れたライブラリです。

FastAPI の公式ページでも、オススメされています!

image.png

URL: https://fastapi.tiangolo.com/how-to/graphql/#graphql-libraries

それでは、具体的な紹介に入っていきたいと思います。

どうやって FastAPI に組み込むの?

非常に簡単に組み込めると思います。
公式ドキュメントでも記載はあるので、それに沿って行えば十分できるかと思います。

取り敢えず組み込むまでのやり方は、以下の 3 ステップです。

  1. I/O interface として、input と scalarを定義
  2. Mutation と Query schemaを定義
  3. GraphQLRouter を定義してFastAPI のrouter として登録

1. I/O interface として、input と scalarを定義

Mutation の interface となる input も、 Mutation や Query の Output スキーマになる scalar も、 標準の dataclasses のような書き方で定義が可能です。

Scalar のサンプル

from dataclasses import field

import strawberry

from app.graphql.scalars.stickynotes_scalar import StickyNoteScalar


@strawberry.type
class UserScalar:
    id: int
    name: str | None = ""
    stickynotes: list[StickyNoteScalar] = field(default_factory=list)

Input のサンプル

import strawberry


@strawberry.input
class UserInput:
    id: int | None = None
    name: str

2. Mutation と Query schemaを定義

次に、定義した I/O に沿ってレスポンスを返す Mutation と Query を作成します。中身の関数は後で作成するとして、以下のように作ります。

Query のサンプル

@strawberry.type
class Query:

    @strawberry.field
    async def users(self, info: Info) -> list[UserScalar]:
        """Get all users"""
        users_data_list = await get_users(info)
        return users_data_list

Mutation のサンプル

@strawberry.type
class Mutation:

    @strawberry.mutation
    async def add_user(self, name: str) -> AddUserResponse:
        """ Add user """
        add_user_resp = await add_user(name)
        return add_user_resp

3. GraphQLRouter を定義してFastAPI のrouter として登録

最後に、これらの Query と Mutation を GraphQLRouter として登録して、FastAPI の router に追加してあげればOKです。

app.include_router で追加したエンドポイントに、Strawberryで作成したQueryやMutationが登録されます。
管理画面用などで分ける場合は、別のエンドポイントとして登録すればOKです。

import strawberry
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import GraphQLRouter
from strawberry.schema.config import StrawberryConfig

from app.graphql.schemas.mutation_schema import Mutation
from app.graphql.schemas.query_schema import Query

schema = strawberry.Schema(
    query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=True)
)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
graphql_app = GraphQLRouter(schema)
app.include_router(graphql_app, prefix="/graphql")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8010)

REST API との共存

もちろん、通常の FastAPI の REST API 機能を生かしたまま、GraphQL サーバーを構築することができます。

import strawberry
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from strawberry.fastapi import GraphQLRouter
from strawberry.schema.config import StrawberryConfig

+ from app.api.main import router as api_router
from app.graphql.schemas.mutation_schema import Mutation
from app.graphql.schemas.query_schema import Query

schema = strawberry.Schema(query=Query, mutation=Mutation, config=StrawberryConfig(auto_camel_case=True))

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
+ app.include_router(api_router, prefix="/api")

graphql_app = GraphQLRouter(schema)
app.include_router(graphql_app, prefix="/graphql")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8010)

基本的に GraphQL のエンドポイントは一つだけなので、それとバッティングしなようにだけ気をつければ問題ないです。

なぜ共存?

改めて、なぜ Python の FastAPI で GraphQL サーバーを立てるといいかを考えてみますと、同じアプリケーション上に共存できることで、パフォーマンス改善に取り組みやすいことが考えられると思います。

よくある構成として、REST API サーバーとフロントを中継するものとして、GraphQL サーバーを立てることがあるかと思いますが、その場合 N + 1 問題などのパフォーマンスに大きく影響する問題が起きやすいです。

一方、アプリケーション本体と共存しておくと、データベースのモデル定義などを共有でき、クエリもそれを通じて発行して直接データベースにクエリできるので、パフォーマンス改善を行いやすいです。

弊社のアプリケーションでは、フロントエンドとのやり取りは GraphQL エンドポイントを使い、サーバー間での通信においては REST API エンドポイントを使っています。

GraphQL パフォーマンスを上げるための工夫

それではここまでは導入として、ここから具体的な Strawberry を使った GraphQL のプラクティスを紹介していきたいと思います。

フロントエンドのクエリに応じた必要な項目の取得

GraphQL では、フロントエンドから必要な項目を選択して取得することができます。データベースから取得するデータに関しても、フロントエンドからのクエリに応じて取得する項目を切り替えることで、データベースアクセスの負荷を最小限にすることができます。

GraphQL のリゾルバーでは、info オブジェクトを通じてクエリ情報にアクセスできます。
以下のような関数で、info オブジェクトからデータベースに対するクエリの対象を取得することができます。

def convert_camel_case(name):
    pattern = re.compile(r"(?<!^)(?=[A-Z])")
    name = pattern.sub("_", name).lower()
    return name


def get_only_selected_fields(db_baseclass_name, info):
    db_relations_fields = inspect(db_baseclass_name).relationships.keys()
    result = {"base": [], "relations": {}}

    for field in info.selected_fields[0].selections:
        field_name = convert_camel_case(field.name)
        if field.name not in db_relations_fields:
            result["base"].append(field_name)
        else:
            result["relations"][field_name] = [convert_camel_case(sub_field.name) for sub_field in field.selections]

    return result

この関数を使うと、例えばこのようなクエリですと、

query MyQuery {
  users {
    id
    name
    stickynotes {
      id
      text
    }
  }
}

以下のようなデータが取得できます。

{"base": ["id", "name"], "relations": {"stickynotes": ["id", "text"]}}

これをターゲットのモデルや階層化されたモデルのScalarなどを渡して上げることで、必要な項目だけを抽出して Strawberry の Scalar に変換する関数を用意します。

scalar_mapping = {
    "stickynotes": StickyNoteScalar,
    # Add other mappings here if needed
}

async def fetch_data(scalar_mapping, selected_fields, model, data_scalar):
    async with get_session() as s:
        # load_onlyで必要な項目のみクエリを行う。
        sql = (
            select(model)
            .options(load_only(*[getattr(model, attr) for attr in selected_fields["base"]]))
            .order_by(model.id)
        )
        for key, value in selected_fields["relations"].items():
            relation_attr = getattr(model, key)
            sql = sql.options(
                selectinload(relation_attr).load_only(*[getattr(relation_attr.mapper.class_, attr) for attr in value])
            )
        db_data = (await s.execute(sql)).scalars().all()
        data_list = []
        for user in db_data:
            data_dict = get_valid_data(user, model)
            # リレーションデータが必要であれば取得する。
            for relation, fields in selected_fields["relations"].items():
                related_objects = getattr(user, relation)
                data_dict[relation] = [get_valid_data(obj, obj.__class__) for obj in related_objects]

            for relation, scalar in scalar_mapping.items():
                if relation in data_dict:
                    data_dict[relation] = [scalar(**note) for note in data_dict[relation]]
            data_list.append(data_scalar(**data_dict))
        return data_list

最終的に、リゾルバーでこちらの関数を使うことができます。

async def get_users(info):
    """Get all users resolver"""
    selected_fields = get_only_selected_fields(User, info)
    return await fetch_data(scalar_mapping, selected_fields, User, UserScalar)

補足として、単一のデータを取得する場合は、以下のような関数です。
違いは、取得対象のidを指定してクエリできるようにしていることです。

async def fetch_single_data(scalar_mapping, selected_fields, model, data_scalar, id):
    async with get_session() as s:
        sql = (
            select(model)
            .options(load_only(*[getattr(model, attr) for attr in selected_fields["base"]]))
            .filter_by(id=id)
        )
        for key, value in selected_fields["relations"].items():
            relation_attr = getattr(model, key)
            sql = sql.options(
                selectinload(relation_attr).load_only(*[getattr(relation_attr.mapper.class_, attr) for attr in value])
            )
        db_data = (await s.execute(sql)).scalars().first()
        if db_data is None:
            return None

        data_dict = get_valid_data(db_data, model)
        for relation, fields in selected_fields["relations"].items():
            related_objects = getattr(db_data, relation)
            data_dict[relation] = [get_valid_data(obj, obj.__class__) for obj in related_objects]

        for relation, scalar in scalar_mapping.items():
            if relation in data_dict:
                data_dict[relation] = [scalar(**note) for note in data_dict[relation]]

        return data_scalar(**data_dict)

Dataloader の活用

以上で、GraphQLクエリに対して常にデータベースから全部の項目を取得する必要はなくなり、データベースの負荷を下げることができました。
次に、GraphQL で考えなければいけない問題として、N + 1 問題があります。GraphQLにおいて多くの場合、負荷の観点から問題になります。

この問題への対応として、Dataloader という仕組みが GraphQL にはあります。 Strawberry にもこの Dataloader の仕組みが実装されているので、こちらを使うことでクエリの回数を最小限に抑えることができます。

DataLoaderとは、データソースからのデータ取得を Batch 処理したり結果を Cache したりするためにGraphQL上で用意されたAPIです。

具体的には、ユーザーに紐づくメモを取得する際にユーザーの数だけ走る可能性のあるメモ取得クエリを、全ユーザーのメモを取得するクエリ1回だけに留めることができます。

このことを実現するためには、fetch_data関数を修正する必要があります。
具体的には、以下の3つの機能を持つ関数に分解します。

  1. ベースモデルのデータをバッチでロードする関数
  2. リレーションのデータをバッチでロードする関数
  3. メインのfetch_data関数(上記の関数を利用してデータを組み立てます)

これらに分解するのは、それぞれの役割を明確にすることと、DataLoader インスタンスに適切に情報を渡せるようにするためです。

1. ベースモデルのデータをバッチでロードする関数

まずは、ベースモデルのデータをバッチで取得する関数を用意します。
指定された model_ids に対応するモデルのインスタンスを一度に取得します。
load_only を使用して必要な項目だけを取得する機能は残しています。

async def batch_load_models(model_ids, model, selected_fields):
    async with get_session() as s:
        sql = (
            select(model)
            .filter(model.id.in_(model_ids))
            .options(load_only(*[getattr(model, attr) for attr in selected_fields["base"]]))
        )
        db_data = (await s.execute(sql)).scalars().all()
        # IDをキーとする辞書を作成
        model_map = {obj.id: obj for obj in db_data}
        # 入力の順序に合わせて結果を返す
        return [model_map.get(model_id) for model_id in model_ids]

2. リレーションのデータをバッチでロードする関数

次に、リレーションのデータも一度のクエリで取得できるようにします。

def get_fk_attribute_name(related_model, parent_model):
    for column in related_model.__table__.columns:
        for fk in column.foreign_keys:
            if fk.column.table == parent_model.__table__:
                return column.name
    return None


async def batch_load_relations(parent_ids, parent_model, relation_name, selected_fields):
    relation_attr = getattr(parent_model, relation_name)
    related_model = relation_attr.property.mapper.class_

    fk_attribute_name = get_fk_attribute_name(related_model, parent_model)
    if fk_attribute_name is None:
        raise ValueError(f"{related_model.__name__}{parent_model.__name__}への外部キーが見つかりませんでした。")

    async with get_session() as s:
        sql = (
            select(related_model)
            .filter(getattr(related_model, fk_attribute_name).in_(parent_ids))
            .options(load_only(*[getattr(related_model, attr) for attr in selected_fields + [fk_attribute_name]]))
        )
        db_data = (await s.execute(sql)).scalars().all()

        # 親IDごとに関連オブジェクトをマッピング
        relation_map = {}
        for obj in db_data:
            parent_id = getattr(obj, fk_attribute_name)
            relation_map.setdefault(parent_id, []).append(obj)

        # 入力の順序に合わせて結果を返す
        return [relation_map.get(parent_id, []) for parent_id in parent_ids]

3. メインのfetch_data関数

Dataloaderの初期化しつつ、それを用いてデータの取得をしてそれらをStrawberryで返却できるようにScalarに組み立てていきます。

async def fetch_data(scalar_mapping, selected_fields, model, data_scalar, target_ids=[]):
    # Dataloaderのインスタンスを作成
    model_loader = DataLoader(load_fn=lambda ids: batch_load_models(ids, model, selected_fields))
    relation_loaders = {}
    for relation, fields in selected_fields["relations"].items():
        # クロージャの問題を回避するために、デフォルト引数を使用
        relation_loaders[relation] = DataLoader(
            load_fn=lambda parent_ids, r=relation, f=fields: batch_load_relations(parent_ids, model, r, f)
        )

    # 対象のIDリストを取得
    if target_ids:
        ids = target_ids
    else:
        # 全てのIDを取得
        async with get_session() as s:
            ids = (await s.execute(select(model.id).order_by(model.id))).scalars().all()

    # ベースモデルのデータをロード
    db_data = await model_loader.load_many(ids)
    data_list = []
    for obj in db_data:
        if obj is None:
            continue  # データが存在しない場合はスキップ
        data_dict = get_valid_data(obj, model)

        # リレーションのデータをロード
        for relation, loader in relation_loaders.items():
            related_objects = await loader.load(getattr(obj, 'id'))
            data_dict[relation] = [get_valid_data(rel_obj, rel_obj.__class__) for rel_obj in related_objects]

            # スカラーに変換
            if relation in scalar_mapping:
                scalar = scalar_mapping[relation]
                data_dict[relation] = [scalar(**item) for item in data_dict[relation]]

        data_list.append(data_scalar(**data_dict))
    return data_list

以上で、データベースからDataLoaderを使ってデータを取得することができるようになりました。
N+1問題の軽減のみではなく、同じクエリに対してはCacheも効くようになるので、その分の負荷低減も見込めます。

Other Tips

サブスクリプションの導入

次に、サブスクリプションの導入についてです。こちらは特段大きな工夫というほどでもないですが、redisに接続して Pub/Sub を実現します。

まずはRedisクライアントを作成します。

from typing import AsyncGenerator

import redis.asyncio as redis

from app.settings import settings


async def get_redis() -> AsyncGenerator[redis.Redis, None]:
    async with redis.from_url(settings.REDIS_URL) as redis_client:
        try:
            yield redis_client
        finally:
            await redis_client.close()

次に、Pub/Sub の Broker を作成します。

import asyncio
import dataclasses
import json
from datetime import datetime
from typing import AsyncGenerator, Optional

import redis.asyncio as redis
import strawberry
from strawberry.types import Info

from app.graphql.db.redis import get_redis
from app.graphql.resolvers.stickynote_resolver import get_stickynote, get_stickynotes
from app.graphql.scalars.stickynotes_scalar import StickyNoteScalar


class StickyNoteSubscriptionBroker:
    channel = "channel:StickyNote"

    async def publish(self, stickynote: StickyNoteScalar, redis: redis.Redis):
        def datetime_serializer(obj):
            if isinstance(obj, datetime):
                return obj.isoformat()
            raise TypeError(f"Type {type(obj)} not serializable")

        payload = json.dumps(dataclasses.asdict(stickynote), default=datetime_serializer)
        await redis.publish(self.channel, payload)

    async def subscribe(self, redis: redis.Redis) -> dict:  # type: ignore
        pubsub = redis.pubsub()
        await pubsub.subscribe(self.channel)
        async for message in pubsub.listen():
            if message["type"] != "message":
                continue
            yield message

次に、Subscription Schema の定義をします。
先程の Broker を読み込んでおきます。

import json
from dataclasses import dataclass, fields
from datetime import datetime
from typing import AsyncGenerator, Optional

import strawberry
from strawberry.types import Info

from app.graphql.brokers.stickynote_broker import StickyNoteSubscriptionBroker
from app.graphql.db.redis import get_redis
from app.graphql.scalars.stickynotes_scalar import StickyNoteScalar


@strawberry.type
class Subscription:
    @strawberry.subscription
    async def subscribe_stickynote(self, info: Info) -> AsyncGenerator[Optional[StickyNoteScalar], None]:
        async for redis_client in get_redis():
            async for message in stickynote_subscriptions.subscribe(redis_client):
                data = json.loads(message["data"])
                key_type_list = [{"key": field.name, "type": field.type} for field in fields(StickyNoteScalar)]
                for key_type in key_type_list:
                    if key_type["type"] == datetime or key_type["type"] == datetime | None:
                        data[key_type["key"]] = datetime.fromisoformat(data[key_type["key"]])
                stickynote = StickyNoteScalar(**data)
                yield stickynote


stickynote_subscriptions = StickyNoteSubscriptionBroker()

最後に、Mutation などで Subscription に対して Publish する処理を入れると、Subscription の完成です。
あとは、フロント側で適切に処理をしましょう。

@strawberry.type
class Mutation:

    @strawberry.mutation
    async def add_stickynotes(self, text: str, user_id: int, info: Info) -> AddStickyNotesResponse:
        """Add sticky note"""
        async for redis_client in get_redis():
            add_stickynotes_resp = await add_stickynotes(text, user_id)
            await stickynote_subscriptions.publish(add_stickynotes_resp, redis_client)
            return add_stickynotes_resp

最後に

いかがでしたでしょうか?
今回は、弊社サービスでも使っているGraphQLのプラクティスを紹介させていただきました。もっとこうした方がいいよ、というアドバイスやご指摘も大歓迎です!

弊社ではエンジニアを常に募集しておりますので、もしご興味があれば下記Twitterにてご連絡ください!
https://twitter.com/zoetaka38

9
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?