はじめに
FastAPIでGraphQLサーバを構築する際に、FastAPIのドキュメントではStrawberryというライブラリを使用することが推奨されています。
GraphQLの勉強も兼ねて、試しに使ってみました。
Strawberryとは
Python向けのGraphQLライブラリです。PythonだとGrapheneが有名だと思います。FastAPIの設計思想と似ており、FastAPIでGraphQLサーバを構築する際に、使用することが推奨されています(もちろん、Grapheneなどの他のライブラリでも可能です)。FastAPI以外にもDjangoやFlask、Chaliceなど様々なフレームワークをサポートしています。
実施すること
タスク管理を題材にして、CRUDに相当するものをStrawberryを用いて実装したいと思います。
ファイル構成
最終的なファイル構成は以下です。
.
├── main.py
└── tasks
├── __init__.py
├── inputs.py
├── models.py
├── repositories.py
├── resolvers.py
├── services.py
└── types.py
インストール
以下の2つのライブラリをインストールします。
pip install 'strawberry-graphql[debug-server]'
pip install 'uvicorn[standard]'
前準備
GraphQLとは、あまり関連しない部分を簡単に実装しておきます。
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Optional
class Status(Enum):
TODO = "todo"
DOING = "doing"
DONE = "done"
@dataclass
class Task:
title: str
id: uuid.UUID = field(default_factory=uuid.uuid4)
description: Optional[str] = None
status: Status = Status.TODO
created_at: datetime = field(default_factory=datetime.utcnow)
updated_at: datetime = field(default_factory=datetime.utcnow)
from copy import deepcopy
from typing import Any
from tasks.models import Task
class InMemoryRepository:
_store: dict[str, Any] = {}
@classmethod
def find_by_id(cls, id_: str) -> Task:
task = cls._store.get(id_)
if task is None:
raise Exception("Not Found")
return task
@classmethod
def find_all(cls) -> list[Task]:
tasks = list(cls._store.values())
return tasks
@classmethod
def save(cls, task: Task) -> None:
cls._store[str(task.id)] = deepcopy(task)
@classmethod
def delete(cls, task: Task) -> None:
del cls._store[str(task.id)]
from datetime import datetime
from typing import Optional
from tasks.models import Status, Task
from tasks.repositories import InMemoryRepository
class TaskService:
def __init__(self, repo: type[InMemoryRepository]) -> None:
self._repo = repo
@property
def repo(self) -> type[InMemoryRepository]:
return self._repo
def find(self, id: str) -> Task:
task = self.repo.find_by_id(id)
return task
def find_all(self) -> list[Task]:
tasks = self.repo.find_all()
return tasks
def create(self, *, title: str, description: Optional[str] = None) -> Task:
task = Task(title=title, description=description)
self.repo.save(task)
return task
def update(self, *, id: str, status: Status) -> Task:
task = self.repo.find_by_id(id)
task.status = status
task.updated_at = datetime.utcnow()
self.repo.save(task)
return task
def delete(self, id: str) -> Task:
task = self.repo.find_by_id(id)
self.repo.delete(task)
return task
Object Type
上記までで、前準備ができました。ここからGraphQL関連を実装していきます。
まずObject Tyeを定義します。GraphQLの型とPythonの型の対応の詳細はこちらをご覧ください。また、Enum
で定義したStatus
はstrawberry.enum
を用いて型を定義します。
strawberry.type
の引数にname="Task"
を指定して、Task
という名前で扱えるようにします。
from datetime import datetime
from typing import Optional
import strawberry
from tasks.models import Status, Task
StatusType = strawberry.enum(Status, name="Status")
@strawberry.type(name="Task")
class TaskType:
id: strawberry.ID
title: str
description: Optional[str]
status: StatusType
created_at: datetime
updated_at: datetime
@classmethod
def from_instance(cls, instance: Task) -> "TaskType":
data = instance.__dict__
return cls(**data)
Input Type
タスクの新規作成と更新用に、strawberry.input
を用いてInput Typeを定義します。
from typing import Optional
import strawberry
from tasks.types import StatusType
@strawberry.input
class AddTaskInput:
title: str
description: Optional[str] = None
@strawberry.input
class UpdateTaskInput:
id: strawberry.ID
status: StatusType
Resolver
Resolverを定義します。関数の引数がGraphQLに反映されます。
add_task(task_input: AddTaskInput)
の場合、GraphQLではtaskInput: AddTaskInput!
となり、スネークケースはキャメルケースに変換されます。
import strawberry
from tasks.inputs import AddTaskInput, UpdateTaskInput
from tasks.repositories import InMemoryRepository
from tasks.services import TaskService
from tasks.types import TaskType
def get_task(id: strawberry.ID) -> TaskType:
db = InMemoryRepository
service = TaskService(db)
task = service.find(id)
return TaskType.from_instance(task)
def get_tasks() -> list[TaskType]:
db = InMemoryRepository
service = TaskService(db)
tasks = service.find_all()
return [TaskType.from_instance(task) for task in tasks]
def add_task(task_input: AddTaskInput) -> TaskType:
db = InMemoryRepository
service = TaskService(db)
task = service.create(**task_input.__dict__)
return TaskType.from_instance(task)
def update_task(task_input: UpdateTaskInput) -> TaskType:
db = InMemoryRepository
service = TaskService(db)
task = service.update(**task_input.__dict__)
return TaskType.from_instance(task)
def delete_task(id: strawberry.ID) -> TaskType:
db = InMemoryRepository
service = TaskService(db)
task = service.delete(id)
return TaskType.from_instance(task)
QueryとMutation
最後に、Query、Mutationを定義します。こちらも定義したfieldがGraphQLに反映されます。スネークケースはキャメルケースに変換され、type hintが戻り値になります。
import strawberry
from tasks.resolvers import add_task, delete_task, get_task, get_tasks, update_task
from tasks.types import TaskType
@strawberry.type
class Query:
task: TaskType = strawberry.field(resolver=get_task)
tasks: list[TaskType] = strawberry.field(resolver=get_tasks)
@strawberry.type
class Mutation:
task_add: TaskType = strawberry.field(resolver=add_task)
task_update: TaskType = strawberry.field(resolver=update_task)
task_delete: TaskType = strawberry.field(resolver=delete_task)
schema = strawberry.Schema(query=Query, mutation=Mutation)
サーバの起動
以下のコマンドでサーバを起動します。http://localhost:8000/graphql
にアクセスすると、GraphiQLの画面が表示されます。
strawberry server main
終わりに
今回は、Strawberryを用いてGraphQLサーバを構築してみました。
Srawberryはまだ開発初期段階ですので、使用する際はドキュメントを確認してみてください。