41
32

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

FastAPI + SQLAlchemy(postgresql)によるCRUD API実装ハンズオン

Last updated at Posted at 2020-10-24

最終的な成果物

編集履歴

・2021/4/23
認証関連のコードを追加して、Qiitaも更新

・2021/4/18
全体的にリファクタリングし、Qiitaも更新

・2020/10/24
初版

このハンズオンで実装するもの

  • FastAPIのDocker環境(Nginxコンテナ, Applicationコンテナ, DBコンテナ)

  • alembic環境

    • DBマイグレーション用のツール
  • ユーザー情報のモデル

    • DBに作成するテーブルの元になるもの。マイグレーションツールで使用します。
  • データアクセスクラス

  • ユーザー情報のCRUDを行うAPI

    • 2020/10/27現在 パスワードのハッシュ化とかは未考慮
  • 各種ミドルウェア

    • API実行前の処理を行うミドルウェア
  • APIのテストコード

    • テスト用のDBを作成し、テストケース実行ごとにDBロールバック、テストが全て完了したらテスト用のDB削除
  • CORS問題の回避

  • カスタム例外(アプリケーション例外とシステム例外)

アーキテクチャ

python: v3.8.5
postgresql: v12.4
fastapi: v0.60.2
SQLAlchemy: v1.3.18

※ v1.3までのSQLAlchemyに依存しているものが多いので、SQLAlchemyはとりあえずv1.3系を使ったほうが良さそう

マイグレーションツール

alembic: v1.4.2

FastAPIとは

公式より

FastAPI は、Pythonの標準である型ヒントに基づいてPython 3.6 以降でAPI を構築するための、モダンで、高速(高パフォーマンス)な、Web フレームワークです。

ドキュメントがかなり豊富で、Swaggerと互換性あるのが素晴らしいです(Swaggerドキュメントをコードから自動生成)。
また、非同期処理を行うためのクラス(BackgroundTasks)も用意されているため、Celeryなどのインストールが不要です。
(パフォーマンスも良いらしいがベンチマーク測ったことないので断言できない、でも多分速い。)
 
しっかり触ったことのあるフレームワークはDjangoだけなのでミドルウェアやテストコード周りは苦労しました・・・。
(Djangoはフルスタックフレームワークで、勝手にいろいろやってくれるからミドルウェアとかテストの実装環境とかあまり気にしたことない)

Python3.8.5の仮装環境作成

pyenv, virtualenvが入ってない場合は下記を参考に入れてください。
Mac
Windows

$ pyenv virtualenv 3.8.5 env_fastapi_sample

仮装環境を適用

プロジェクトルートを作成

$ mkdir fastapi_sample

プロジェクトルートにcd

$ cd fastapi_sample

仮装環境を適用

$ pyenv local env_fastapi_sample

requirements.txt作成/インストール

$ touch requirements.txt

requirements.txtにfastapiを追記

fastapi_sample/requirements.txt
fastapi==0.60.2
uvicorn==0.11.8

仮装環境にインストール

$ pip3 install -r requirements.txt

エントリーポイント(main.py)を作成して、Swagger-UIを表示してみる

エントリーポイント(main.py)作成

$ touch main.py

main.pyを編集

fastapi_sample/main.py
from fastapi import FastAPI

app = FastAPI()


@app.get("/")
async def root():
    return {"message": "Hello World"}

Swagger-UI表示

$ uvicorn main:app --reload

ブラウザから「http://127.0.0.1:8000/docs」にアクセスしてSwaggerUIが表示され、そこに上述のエントリーポイントが表示されていれば完了です。

Docker化

Nginxのリバースプロキシによってアプリケーションをhttpsで公開するようにします。

ベースのdocker-compose.yml

fastapi_sample/docker-compose.yml
version: '3'
services:
# Nginxコンテナ
  nginx:
    container_name: nginx_fastapi_sample
    image: nginx:alpine
    depends_on:
      - app
      - db
    environment:
      TZ: "Asia/Tokyo"
    ports:
      - "80:80"
      - "443:443"
    volumes:
      - ./docker/nginx/conf.d:/etc/nginx/conf.d
      - ./docker/nginx/ssl:/etc/nginx/ssl

# アプリケーションコンテナ
  app:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: app_fastapi_sample
    volumes:
      - '.:/fastapi_sample/'
    environment:
      - LC_ALL=ja_JP.UTF-8
    expose:
      - 8000
    depends_on:
      - db
    entrypoint: /fastapi_sample/docker/wait-for-it.sh db 5432 postgres postgres db_fastapi_sample
    command: bash /fastapi_sample/docker/rundevserver.sh
    restart: always
    tty: true

# DBコンテナ
  db:
    image: postgres:12.4-alpine
    container_name: db_fastapi_sample
    environment:
      - POSTGRES_USER=postgres
      - POSTGRES_PASSWORD=postgres
      - POSTGRES_DB=db_fastapi_sample
      - POSTGRES_INITDB_ARGS=--encoding=UTF-8 --locale=C
    volumes:
      - db_data:/var/lib/postgresql/data
    ports:
      - '5432:5432'

volumes:
  db_data:
    driver: local

Nginxコンテナの設定

####confファイルを用意する

$ mkdir -p nginx/conf.d
$ touch nginx/conf.d/app.conf
fastapi_sample/docker/nginx/conf.d/app.conf
upstream backend {
    server app:8000;  # appはdocker-compose.ymlの「app」
}

# 80番ポートへのアクセスは443番ポートへのアクセスに強制する
server {
    listen 80;
    return 301 https://$host$request_uri;
}

server {
    listen 443 ssl;
    ssl_certificate     /etc/nginx/ssl/server.crt;
    ssl_certificate_key /etc/nginx/ssl/server.key;
    ssl_protocols        TLSv1.2 TLSv1.3;

    location / {
        proxy_set_header    Host    $http_host;
        proxy_set_header    X-Real-IP    $remote_addr;
        proxy_set_header    X-Forwarded-Host      $http_host;
        proxy_set_header    X-Forwarded-Server    $http_host;
        proxy_set_header    X-Forwarded-Server    $host;
        proxy_set_header    X-Forwarded-For    $proxy_add_x_forwarded_for;
        proxy_set_header    X-Forwarded-Proto  $scheme;
        proxy_redirect      http:// https://;

        proxy_pass http://backend;
    }

    # ログを出力したい場合はコメントアウト外してください
    # access_log /var/log/nginx/access.log;
    # error_log /var/log/nginx/error.log;
}

server_tokens off;

NginxのSSL化に必要なファイルを用意

$ mkdir -p nginx/conf.d nginx/ssl

OpenSSLを使って秘密鍵(server.key)を生成する

# ディレクトリ移動
$ cd nginx/conf.d

# 秘密鍵生成
$ openssl genrsa 2024 > server.key

# 確認
$ ls -ll
total 8
-rw-r--r--  1 user  staff  1647 10 24 13:23 server.key

証明書署名要求(server.csr)を生成

$ openssl req -new -key server.key > server.csr
$ openssl req -new -key server.key > server.csr
...
..
.
Country Name (2 letter code) [AU]:JP  # 国を示す2文字のISO略語
State or Province Name (full name) [Some-State]:Tokyo  # 会社が置かれている都道府県
Locality Name (eg, city) []:Chiyodaku  # 会社が置かれている市区町村
Organization Name (eg, company) [Internet Widgits Pty Ltd]:fabeee  # 会社名
Organizational Unit Name (eg, section) []:-  # 部署名(ハイフンにしました。)
Common Name (e.g. server FQDN or YOUR name) []:localhost(ウェブサーバのFQDN。一応localhostにした)
Email Address []:  # 未入力でエンター

Please enter the following 'extra' attributes
to be sent with your certificate request
A challenge password []:  # 未入力でエンター
An optional company name []:  # 未入力でエンター
...
..
.
$ ls -ll
total 16
-rw-r--r--  1 user  staff   980 10 24 13:27 server.csr  # 証明書署名要求ができた
-rw-r--r--  1 user  staff  1647 10 24 13:23 server.key

サーバ証明書(server.crt)を生成

openssl x509 -req -days 3650 -signkey server.key < server.csr > server.crt

% ls -ll
total 24
-rw-r--r--  1 tabata  staff  1115 10 24 13:40 server.crt  # サーバ証明書ができた
-rw-r--r--  1 tabata  staff   948 10 24 13:29 server.csr
-rw-r--r--  1 tabata  staff  1647 10 24 13:23 server.key

サーバ証明書を信頼する

server.crtをダブルクリック
image.png

「この証明書を使用するとき」のプルダウンから「常に信頼する」を選択する
image.png

最終的にこうなっていればNginxの設定は完了です

fastapi_sample
├── docker
│   └── nginx
│       ├── conf.d
│       │   └── app.conf
│       └── ssl
│           ├── server.crt
│           ├── server.csr
│           └── server.key
├── docker-compose.yml
├── main.py
└── requirements.txt

app(FastAPI)コンテナの設定

Dockerfile

イメージはubuntu20.04です。
・pyenvでコンテナ内のpythonバージョンをv3.8.5にしている
・apt installで必要なモジュールをインストール(不要なものは削ってもらって構いません)
pip3 install -r requirements.txtでpythonモジュールをコンテナ内にインストール

FROM ubuntu:20.04
ENV DEBIAN_FRONTEND=noninteractive
ENV HOME /root
ENV PYTHONPATH /fastapi_sample/
ENV PYTHON_VERSION 3.8.5
ENV PYTHON_ROOT $HOME/local/python-$PYTHON_VERSION
ENV PATH $PYTHON_ROOT/bin:$PATH
ENV PYENV_ROOT $HOME/.pyenv

RUN mkdir /fastapi_sample \
    && rm -rf /var/lib/apt/lists/*

RUN apt update && apt install -y git curl locales python3-pip python3-dev python3-passlib python3-jwt \
    libssl-dev libffi-dev zlib1g-dev libpq-dev postgresql

RUN echo "ja_JP UTF-8" > /etc/locale.gen \
    && locale-gen

RUN git clone https://github.com/pyenv/pyenv.git $PYENV_ROOT \
    && $PYENV_ROOT/plugins/python-build/install.sh \
    && /usr/local/bin/python-build -v $PYTHON_VERSION $PYTHON_ROOT

WORKDIR /fastapi_sample
ADD . /fastapi_sample/
RUN LC_ALL=ja_JP.UTF-8 \
    && pip3 install -r requirements.txt

wait-for-it.sh

データベースが立ち上がるのを待ってからappコンテナを起動するようにするためのシェルスクリプトです。
 
docker-componseに記載しているdepends_onコンテナの起動順は制御できますが、
データベースが立ち上がっていない場合、appコンテナでエラーが発生することがあります。
(appコンテナ起動時にDB操作をするようなコマンドを実行しようとした場合、データベースが立ち上がっていないためにエラー、とか)

fastapi_sample/docker/wait-for-it.sh
#!/bin/sh

set -e

# 引数はdocker-compose.ymlで指定している。
# app:
#   ...
#   ..
#   .
#   entrypoint: /fastapi_sample/docker/wait-for-it.sh db 5432 postgres postgres db_fastapi_sample  # ココ
#   ...
host="$1"
shift
port="$1"
shift
user="$1"
shift
password="$1"
shift
database="$1"
shift
cmd="$@"

echo "Waiting for postgresql"
until pg_isready -h"$host" -U"$user" -p"$port" -d"$database"
do
  echo -n "."
  sleep 1
done

>&2 echo "PostgreSQL is up - executing command"
exec $cmd

# 僕は優しいのでMySQL用も書いてあげるのだ
# #!/bin/sh
# 
# set -e
# 
# host="$1"
# shift
# user="$1"
# shift
# password="$1"
# shift
# cmd="$@"
# 
# echo "Waiting for mysql"
# until mysql -h"$host" -u"$user" -p"$password" &> /dev/null
# do
#     >$2 echo -n "."
#     sleep 1
# done
# 
# >&2 echo "MySQL is up - executing command"
# exec $cmd

rundevserver.sh

uvicornを起動してアプリケーションを起動するためのシェルスクリプトファイルです。
docker-compose up時に毎回requirements.txtのモジュールを読み込む
・uvicornでアプリケーションを起動する
・--reloadでホットリロード(pythonファイルを変更すると即反映してくれる)
・--portを8000から変える場合はnginxのapp.conf内の8000も変える必要がある

pip3 install -r requirements.txt

uvicorn main:app\
    --reload\
    --port 8000\
    --host 0.0.0.0\
    --log-level debug

最終的なディレクトリ構成

fastapi_sample
├── Dockerfile
├── docker
│   ├── nginx
│   │   ├── conf.d
│   │   │   └── app.conf
│   │   └── ssl
│   │       ├── server.crt
│   │       ├── server.csr
│   │       └── server.key
│   ├── rundevserver.sh
│   └── wait-for-it.sh
├── docker-compose.yml
├── main.py
└── requirements.txt

コンテナ起動

$ docker-compose up -d

こんなエラーが出た場合は、Dockerイメージがディスクを逼迫している可能性があるので、 docker image pruneで不要なイメージを削除してください。(僕はこれで40GB分ディスクに空きができました😩)

.
..
...
Get:1 http://security.debian.org/debian-security buster/updates InRelease [65.4 kB]
Get:2 http://deb.debian.org/debian buster InRelease [122 kB]
Get:3 http://deb.debian.org/debian buster-updates InRelease [49.3 kB]
Err:1 http://security.debian.org/debian-security buster/updates InRelease
  At least one invalid signature was encountered.
Err:2 http://deb.debian.org/debian buster InRelease
  At least one invalid signature was encountered.
Err:3 http://deb.debian.org/debian buster-updates InRelease
  At least one invalid signature was encountered.
Reading package lists...
W: GPG error: http://security.debian.org/debian-security buster/updates InRelease: At least one invalid signature was encountered.
E: The repository 'http://security.debian.org/debian-security buster/updates InRelease' is not signed.
W: GPG error: http://deb.debian.org/debian buster InRelease: At least one invalid signature was encountered.
E: The repository 'http://deb.debian.org/debian buster InRelease' is not signed.
W: GPG error: http://deb.debian.org/debian buster-updates InRelease: At least one invalid signature was encountered.
E: The repository 'http://deb.debian.org/debian buster-updates InRelease' is not signed.
...
..
.

ブラウザからアクセス

コンテナの起動が完了したか確認します。
STATUS列が"Up..."となっていればOKです。

$ docker ps
CONTAINER ID        IMAGE                  COMMAND                  CREATED              STATUS              PORTS                                      NAMES
b77fa2465fb1        nginx:alpine           "/docker-entrypoint.…"   About a minute ago   Up About a minute   0.0.0.0:80->80/tcp, 0.0.0.0:443->443/tcp   nginx_fastapi_sample
eb18438efd2c        fastapi_sample_app     "/fastapi_sample/doc…"   About a minute ago   Up About a minute   8000/tcp                                   app_fastapi_sample
383b0e46af68        postgres:12.4-alpine   "docker-entrypoint.s…"   About a minute ago   Up About a minute   0.0.0.0:5432->5432/tcp                     db_fastapi_sample

https://localhostにアクセスして、{"message":"Hello World"}が表示されれば完了です。

ライブラリ「pydantic」を使用して環境変数(.env)を扱う

.envファイル作成

$ touch .env
fastapi_sample/.env
DEBUG=True
DATABASE_URL=postgresql://postgres:postgres@db:5432/db_fastapi_sample

「pydatic」を使って環境変数を読み込む設定ファイルを作成

ライブラリ「pydantic」を使うと、.env から値を読み込む処理を簡単に実装できます。
また型に合わせてキャストしてくれたりするので超便利。
os.getenv('DEBUG') == 'True' みたいな気持ち悪い条件式を書かなくてよくなります。(distutils.util.strtobool使えって話ですが、、、)

pydanticをインストール

fastapi_sample/requirements.txt
pydantic[email]==1.6.1  # 追記したらpip3 install

フォルダ「core」を作成、その直下に「config.py」を作成

$ mkdir core
$ touch core/config.py
fastapi_sample/core/config.py
import os
from functools import lru_cache
from pydantic import BaseSettings

PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class Environment(BaseSettings):
    """ 環境変数を読み込むファイル
    """
    debug: bool  # .envから読み込んだ値をbool型にキャッシュ
    database_url: str

    class Config:
        env_file = os.path.join(PROJECT_ROOT, '.env')


@lru_cache
def get_env():
    """ 「@lru_cache」でディスクから読み込んだ.envの結果をキャッシュする
    """
    return Environment()

alembic(マイグレーションツール)環境の用意 と モデルを用意

何はともあれインストール

requirements.txtに alembic と sqlalchemy 追記してからpip3 install -r requirements.txtでインストール
(Dockerの手順を踏んでいる人はdocker-compose restart app または docker-compose exec app pip3 install -r requiements.txt

fastapi_samepl/requirements.txt
alembic==1.4.2  # 追記
.
..
...
psycopg2==2.8.6  # 追記
.
..
...
SQLAlchemy==1.3.18  # 追記
SQLAlchemy-Utils==0.36.8  # 追記

プロジェクトルート直下にalembic環境を作成する

$ alembic init migrations  # Dockerの手順を踏んだ人はdocker-compose exec app alembic init migrations

プロジェクトルートにalembicテンプレートが作成されます。(alembic.ini, migrationsフォルダ)

fastapi_sample(プロジェクトルート)
├── ...
├── ..
├── .
├── alembic.ini
├── migrations
│   ├── README
│   ├── env.py
│   ├── script.py.mako
│   └── versions
└── ...

ベースモデルを用意

まずはベースモデルを実装します。

touch models.py
fastapi_sample/migrations/models.py
from sqlalchemy import Column
from sqlalchemy.dialects.postgresql import INTEGER, TIMESTAMP
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.sql.functions import current_timestamp

Base = declarative_base()


class BaseModel(Base):
    """ ベースモデル
    """
    __abstract__ = True

    id = Column(
        INTEGER,
        primary_key=True,
        autoincrement=True,
    )

    created_at = Column(
        'created_at',
        TIMESTAMP(timezone=True),
        server_default=current_timestamp(),
        nullable=False,
        comment='登録日時',
    )

    updated_at = Column(
        'updated_at',
        TIMESTAMP(timezone=True),
        onupdate=current_timestamp(),
        comment='最終更新日時',
    )

    @declared_attr
    def __mapper_args__(cls):
        """ デフォルトのオーダリングは主キーの昇順
        
        降順にしたい場合
        from sqlalchemy import desc
        # return {'order_by': desc('id')}
        """
        return {'order_by': 'id'}

ユーザーモデルを用意

モデルはなんでもいいですが、認証編を書くときに使えそうなのでユーザーモデルを作成します。
モジュール分割はあえてしていません。(マイグレーションファイル生成時などで不都合発生して手間なので、しない方がいいかも?)

fastapi_sample/migrations/models.py
from sqlalchemy import (
    BOOLEAN,
    Column
    INTEGER,
    TEXT,
    TIMESTAMP,
    VARCHAR,
)
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.sql.functions import current_timestamp

Base = declarative_base()


class BaseModel(Base):
    ...
    ..
    .


class User(BaseModel):
    __tablename__ = 'users'

    username = Column(TEXT, unique=True, nullable=False)
    password = Column(VARCHAR(128), nullable=False)
    last_name = Column(VARCHAR(100), nullable=False)
    first_name = Column(VARCHAR(100), nullable=False)
    is_admin = Column(BOOLEAN, nullable=False, default=False)
    is_active = Column(BOOLEAN, nullable=False, default=True)

env.pyを編集する

fastapi_sample/migrations/env.py
from migrations.models import Base  # 追加
...
..
.
target_metadata = Base.metadata  # メタデータをセット
...
..
.

DBの接続先を変更

モデルを実装したところで、早速マイグレーションを行いたいところです。
が、alembicがプロジェクトテンプレートのままなのでalembic.iniのなかのデータベースURLもテンプレートのままです。

fastapi_sample/alembic.ini
.
..
...
sqlalchemy.url = driver://user:pass@localhost/dbname
...
..
.

なので勿論alembicのマイグレーションファイル生成コマンドなどは失敗してしまいます。

$ docker-compose exec app alembic revision --autogenerate
Traceback (most recent call last):
  File "/root/local/python-3.8.5/bin/alembic", line 8, in <module>
    sys.exit(main())
  ...
  (長いので省略)
  .
    cls = registry.load(name)
  File "/root/local/python-3.8.5/lib/python3.8/site-packages/sqlalchemy/util/langhelpers.py", line 267, in load
    raise exc.NoSuchModuleError(
sqlalchemy.exc.NoSuchModuleError: Can't load plugin: sqlalchemy.dialects:driver

alembic.iniを直接書き換えるのもいいですが、そうすると開発環境やステージング、本番環境でそれぞれ異なるalembic.iniファイルができてしまうため気持ち悪いです。
そうならないように、マイグレーションファイル生成時やマイグレート実行時は、alembic.iniのデータベースURLを一時的に.envDATABASE_URLで書き換わるようにしましょう。

python-dotenvをインストール

fastapi_sample/requirements.txt
python-dotenv==0.14.0  # 追記したらインストールする

env.pyを修正

fastapi_sample/migrations/env.py
import os
from core.config import PROJECT_ROOT
from dotenv import load_dotenv
.
..
...

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

# alembic.iniの'sqlalchemy.url'を.envのDATABASE_URLで書き換える
load_dotenv(dotenv_path=os.path.join(PROJECT_ROOT, '.env'))
config.set_main_option('sqlalchemy.url', os.getenv('DATABASE_URL'))


def run_migrations_offline():
    ...
    ..
    .

再度マイグレーションファイル生成コマンド実行

% docker-compose exec app alembic revision --autogenerate
None /fastapi_sample
INFO  [alembic.runtime.migration] Context impl PostgresqlImpl.
INFO  [alembic.runtime.migration] Will assume transactional DDL.
INFO  [alembic.autogenerate.compare] Detected added table 'users'
  Generating /fastapi_sample/migrations/versions/2bc0a23e563c_.py ...  done

マイグレーションファイルを生成できました。

migrations
├── ...
├── ...
├── models.py
└── versions
    └── 2bc0a23e563c_.py  # マイグレーションファイルができた。

ただ、今のままだとマイグレーションファイルの生成順がパッと見でわからないので、ファイル名に「日時(YYYYMMDD_HHMMSS)」がつくようにします。

alembic.iniを修正します。

fastapi_sample/alembic.ini
.
..
...
[alembic]
# path to migration scripts
script_location = migrations

# ファイルの接頭に「YYYYMMDD_HHMMSS_」がつくようにする
# template used to generate migration files
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d%%(second).2d_%%(rev)s_%%(slug)s

# タイムゾーンを日本時間に
# timezone to use when rendering the date
# within the migration file as well as the filename.
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
timezone = Asia/Tokyo
...
..
.

先ほど生成したマイグレーションファイルを消して、再度マイグレーションファイルを生成します。
無事ファイル名に日時がつきましたので、マイグレーションファイルが生成順に表示されるようになりました。

migrations
├── ...
├── ...
├── models.py
└── versions
    └── 20201024_200033_8a19c4c579bf_.py

マイグレート実行

$ alembic upgrade head  # Dockerの手順を踏んだ人はdocker-compose exec app alembic upgrade head

ちゃんとユーザーテーブルが作成されています。
(alembic_versionはalembicがバージョン管理に使用するテーブルで、自動生成されます)
image.png

コンテナ起動時にマイグレートが実行されるようにする

スクリプトファイルにマイグレーション実行コマンドを追記します。
これでdocker-compose up -d や docker-compose restartなどでコンテナを起動したときにマイグレーションも実行されるようになります。
バックエンドとフロントエンドで担当が別れている場合に、フロントの人にわざわざマイグレーションコマンドを打ち込んでもらう必要が無くなりますね。

fastapi_sample/docker/rundevserver.sh
pip3 install -r requirements.txt

alembic upgrade head  # 追記

uvicorn main:app\
    --reload\
    --port 8000\
    --host 0.0.0.0\
    --log-level debug

データアクセスクラス作成

ベースのデータアクセスクラスを作成

DBセッションを返す関数 や 単純な全件取得、1件取得、登録、更新、削除などを定義します。
これから実装するデータアクセスクラスは全てこのベースクラスを継承させます。

mkdir crud
touch crud/__init__.py
fastapi_sample/crud/__init__.py
from typing import List, TypeVar

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session, query

from core.config import get_env
from migrations.models import Base

ModelType = TypeVar("ModelType", bound=Base)

# プロダクト使用時はpool_sizeやmax_overflowも指定したほうがいいかも
connection = create_engine(
    get_env().database_url,
    echo=get_env().debug,
    encoding='utf-8',
    # pool_size=,
    # max_overflow=
)

Session = scoped_session(sessionmaker(connection))


def get_db_session() -> scoped_session:
    """ 新しいDBコネクションを返す
    """
    return scoped_session(sessionmaker(connection))


class BaseCRUD:
    """ データアクセスクラスのベース
    """
    model: ModelType = None

    def __init__(self, db_session: scoped_session) -> None:
        self.db_session = db_session
        self.model.query = self.db_session.query_property()

    def get_query(self) -> query.Query:
        return self.model.query

    def gets(self) -> List[ModelType]:
        """ 全件取得
        """
        return self.get_query().all()

    def get_by_id(self, id: int) -> ModelType:
        """ 主キーで取得
        """
        return self.get_query().filter_by(id=id).first()

    def create(self, data: dict = {}) -> ModelType:
        """ 新規登録
        """
        obj = self.model()
        for key, value in data.items():
            if hasattr(obj, key):
                setattr(obj, key, value)
        self.db_session.add(obj)
        self.db_session.flush()
        self.db_session.refresh(obj)
        return obj

    def update(self, obj: ModelType, data: dict = {}) -> ModelType:
        """ 更新
        """
        for key, value in data.items():
            if hasattr(obj, key):
                setattr(obj, key, value)
        self.db_session.flush()
        self.db_session.refresh(obj)
        return obj

    def delete_by_id(self, id: int) -> None:
        """ 主キーで削除
        """
        obj = self.get_by_id(id)
        if obj:
            obj.delete()
            self.db_session.flush()
        return None

ユーザーデータアクセスクラスを作成

touch crud/crud_user.py
fastapi_sample/crud/crud_user.py
from crud import BaseCRUD
from migrations.models import User


class CRUDUser(BaseCRUD):
    """ ユーザーデータアクセスクラスのベース
    """
    model = User

API実装

お待たせしました、ようやくAPI実装編です。

API実装前準備

リクエスト情報にDBセッションを格納するミドルウェアを定義します。

mkdir middlewares
touch middlewares/__init__.py
middlewares/__init__.py
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from crud import get_db_session


class DBSessionMiddleware(BaseHTTPMiddleware):
    """ リクエスト情報にDBセッションを設定するミドルウェア
    """
    async def dispatch(self, request: Request, call_next) -> Response:
        """ ミドルウェアの処理

        Args:
            request (Request): リクエスト情報
            call_next (method): 次の処理

        Returns:
            Response: レスポンス
        """
        request.state.db_session = get_db_session()
        return await call_next(request)

ミドルウェアを実装したら、main.pyのエントリポイントに追加してあげるのを忘れずに。

main.py
.
..
...
from middlewares import DBSessionMiddleware  # 追記
...
..
.
# ミドルウェアの設定
app.add_middleware(DBSessionMiddleware)
...
..
.

こうすることで、これから実装するAPIコントローラ内で、DBセッションに「request.state.db_session」でアクセスできるようになります。

APIを実装していくフォルダ作成

mkdir -p api/v1

ユーザーの一覧取得API実装

touch api/v1/user.py
fastapi_sample/api/v1/user.py
from crud.crud_user import CRUDUser
from fastapi import Request
from fastapi.encoders import jsonable_encoder


class UserAPI:
    """ ユーザーに関するAPI
    """
    @classmethod
    def gets(cls, request: Request):
        """ 一覧取得
        """
        # ミドルウェアでリクエスト情報にDBセッションをセットしたので、
        # 「request.state.db_session」でDBセッションにアクセスできる
        return jsonable_encoder(CRUDUser(request.state.db_session).gets())

APIルーター実装

mkdir -p api/endpoints/v1
touch api/endpoints/v1/user.py
fastapi_sample/api/endpoints/v1/user.py
from api.v1.user import UserAPI
from fastapi import APIRouter, Depends, Request

router = APIRouter()


@router.get('/', response_model=List[UserInDB])
async def gets(request: Request) -> List[UserInDB]:
    """ 一覧取得
    """
    return UserAPI.gets(request)

ルーターをエントリーポイントに登録する

APIルーターを作成

touch api/endpoints/v1/__init__.py
fastapi_sample/api/endpoints/v1/__init__.py
from fastapi import APIRouter
from api.endpoints.v1 import user

api_v1_router = APIRouter()
api_v1_router.include_router(
    user.router,
    prefix='users',
    tags=['users'])

main.pyを編集

fastapi_sample/main.py
from api.endpoints.v1 import api_v1_router  # 追記
...
..
.
app.include_router(api_v1_router, prefix='/api/v1')  # 追記
...
..
.

ブラウザからhttps://localhost/docsアクセスすると実装したAPIが表示されているはずです。
image.png

登録/更新時のリクエストパラメータとレスポンス用のスキーマクラスを定義する

mkdir api/schemas
touch api/schemas/user.py
fastapi_sample/api/schemas/user.py
from pydantic import BaseModel
from typing import Optional


class BaseUser(BaseModel):
    username: str
    last_name: str
    first_name: str
    is_admin: bool


class CreateUser(BaseUser):
    password: str


class UpdateUser(BaseUser):
    password: Optional[str]
    last_name: Optional[str]
    first_name: Optional[str]
    is_admin: bool


class UserInDB(BaseUser):
    class Config:
        orm_mode = True

ユーザーの登録/更新/削除のAPIを追加

fastapi_sample/api/v1/user.py
from typing import List

from fastapi import Request
# from fastapi.encoders import jsonable_encoder  # 削除

from api.schemas.user import CreateUser, UpdateUser, UserInDB
from crud.crud_user import CRUDUser


class UserAPI:
    """ ユーザーに関するAPI
    """
    @classmethod
    def gets(cls, request: Request) -> List[UserInDB]:
        """ 一覧取得
        """
        return CRUDUser(request.state.db_session).gets()  # jsonable_encoderは使わない

    @classmethod
    def create(
        cls,
        request: Request,
        schema: CreateUser
    ) -> UserInDB:
        """ 新規登録
        """
        return CRUDUser(request.state.db_session).create(schema.dict())

    @classmethod
    def update(
        cls,
        request: Request,
        id: int,
        schema: UpdateUser
    ) -> UserInDB:
        """ 更新
        """
        crud = CRUDUser(request.state.db_session)
        obj = crud.get_by_id(id)
        return CRUDUser(request.state.db_session).update(obj, schema.dict())

    @classmethod
    def delete(cls, request: Request, id: int) -> None:
        """ 削除
        """
        return CRUDUser(request.state.db_session).delete_by_id(id)

ユーザーのAPIルーターを編集

fastapi_sample/api/endpoints/v1/user.py
from typing import List

from fastapi import APIRouter, Request

from api.schemas.user import CreateUser, UpdateUser, UserInDB
from api.v1.user import UserAPI
from migrations.models import User

router = APIRouter()


@router.get('/', response_model=List[UserInDB])
async def gets(request: Request) -> List[User]:
    """ 一覧取得
    """
    return UserAPI.gets(request)


@router.post('/', response_model=UserInDB)
async def create(request: Request, schema: CreateUser) -> User:
    """ 新規登録
    """
    return UserAPI.create(request, schema)


@router.put('/{id}/', response_model=UserInDB)
async def update(request: Request, id: int, schema: UpdateUser) -> User:
    """ 更新
    """
    return UserAPI.update(request, id, schema)


@router.delete('/{id}/')
async def delete(request: Request, id: int) -> None:
    """ 削除
    """
    return UserAPI.delete(request, id)

ルータの引数にリクエストパラメータ用の関数を指定するだけでOKです。
async def update(request: Request, id: int, schema: UpdateUser) -> UserInDB:
Swaggerドキュメントの「Request Body」反映されます。
image.png

response_modelを指定すると、DBから取得したオブジェクトのJSONシリアライズを開発者が明示的に実装する必要がなくなります。
@router.get('/', response_model=List[UserInDB])
こちらももちろんSwaggerドキュメントの「Response」に反映されます。
image.png

ブラウザからアクセスして確認

CRUDのAPIを無事実装できました。
image.png

リクエストミドルウェアを実装する

さて、APIは実行できましたが、クエリの結果をDBに反映していない(commitしていない)ので登録/更新/削除を実行してもDBに反映されません。
変更をDBに反映させるため、ミドルウェアを実装してエントリポイントに追加します。

fastapi_sample/middlewares/__init__.py
.
..
...

# このミドルウェアクラスを追記
class HttpRequestMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next) -> Response:
        try:
            response = await call_next(request)
            request.state.db_session.commit()  # コミット
            return response

        # DBセッションの破棄は必ず行う
        # commit()が実行されていない場合はこのremove()でロールバックが実行される
        finally:
            request.state.db_session.remove()

エントリーポイントにミドルウェアを適用する

fastapi_sample/main.py
.
..
...
from middleware import (
    DBSessionMiddleware,
    HttpRequestMiddleware  # 追加
)
...
..
.
# ミドルウェアの設定
...
..
.
app.add_middleware(HttpRequestMiddleware)  # 追加
...
..
.

CRUD実行で確認

無事、変更がDBに反映されるようになりました。
image.png

テストコード実装編

CRUDを実装できたので、次はテストコードを実装していきます。
とはいえ、単純にテストコードからAPIを実行してしまうと、開発で使用しているDBに対して登録・更新・削除が実行されるため、テストを実行するたびに結果が変わってしまう可能性があります。
それを防ぐため、テスト実行時にテスト用のDBを作成し、それを使用するようにします。

pytestをインストール

fastapi_sample/requirements.txt
pytest==6.1.0  # 追記したらrequirements.txtをインストールし直すのを忘れずに

テスト用のDB接続情報を環境変数に追加

fastapi_sample/.env
DEBUG=True
DATABASE_URL=postgresql://postgres:postgres@db:5432/db_fastapi_sample
TEST_DATABASE_URL=postgresql://postgres:postgres@db:5432/test_db_fastapi_sample  # 追加
fastapi_sample/core/config.py
class Environment(BaseSettings):
    """ 環境変数を読み込むファイル
    """
    debug: bool
    database_url: str
    test_database_url: str  # 追加

テストコード実装用のフォルダ作成

mkdir tests

conftest.pyを作成

$ touch tests/conftest.py

pytestのセットアップコードを実装するファイルです。
ここでテスト用のエントリポイントを作成しています(main.pyのエントリポイントは使用しない)。
ミドルウェアなどの処理をテスト用に切り替えたい場合に、下記のような奇妙な条件分岐が発生しないようにするためです。

db_session = 開発用のDBセッション

# pytest実行時はテスト用のDBセッションを使う
if get_env().is_test:
    db_session = テスト用のDBセッション

 
pytest_sessionstart: pytest実行前に1回だけ呼ばれる関数(ここでテストDB作成)
pytest_sessionfinish: pytestで全てのテストクラス、テストケースの実行が終わった後に1回だけ呼ばれる関数(ここでテストDB削除)

fastapi_sample/tests/conftest.py
import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
# from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy_utils import database_exists, drop_database

from core.config import get_env
from migrations.models import Base
from tests.db_session import test_db_connection

# conftestで初期データを登録する場合はこのSessionを使用する
# Session = scoped_session(
#     sessionmaker(
#         bind=test_db_connection
#     )
# )


def create_test_database():
    # テストDBが削除されずに残ってしまっている場合は削除
    if database_exists(get_env().test_database_url):
        drop_database(get_env().test_database_url)

    # テストDB作成
    _con = psycopg2.connect('host=db user=postgres password=postgres')
    _con.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
    _cursor = _con.cursor()
    _cursor.execute('CREATE DATABASE test_db_fastapi_sample')

    # テストDBにテーブル追加
    Base.metadata.create_all(bind=test_db_connection)


def pytest_sessionstart(session):
    """ pytest実行時に一度だけ呼ばれる処理
    """
    # テストDB作成
    create_test_database()


def pytest_sessionfinish(session, exitstatus):
    """ pytest終了時に一度だけ呼ばれる処理
    """
    # テストDB削除
    if database_exists(get_env().test_database_url):
        drop_database(get_env().test_database_url)

テスト用のDBセッション管理クラスを作成

$ touch tests/db_session.py
fastapi_sample/tests/db_session.py
from threading import local as thread_local

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker, Session

from core.config import get_env

_thread_local = thread_local()


def set_current_test_db_session(db_session: scoped_session) -> None:
    """ スレッドローカルにテスト用のDBセッションをセット

    Args:
        db_session (scoped_session): テスト用のDBセッション
    """
    setattr(_thread_local, 'db_session', db_session)


def get_current_test_db_session() -> scoped_session:
    """ スレッドローカルからテスト用のDBセッションを取得

    Returns:
        scoped_session: テスト用のDBセッション
    """
    return getattr(_thread_local, 'db_session')


class TestingDBSession(Session):
    """ commit()の挙動を変えるため、Sessionクラスをオーバーライド
    """
    def commit(self):
        # データアクセスクラス(fastapi_sample/crud)やAPIの中でflush()は実行する想定なので、
        # ここでflush()はとりあえず不要
        # self.flush()
        self.expire_all()


class test_scoped_session(scoped_session):
    """ リクエストミドルウェアのremove()で何も実行されないように、scoped_sessionクラスをオーバーライド
    """
    def remove(self):
        pass

    def test_db_session_remove(self):
        """ テストDB用のremove()
        """
        if self.registry.has():
            self.registry().close()
        self.registry.clear()


test_db_connection = create_engine(
    get_env().test_database_url,
    encoding='utf8',
    pool_pre_ping=True,
)


def get_test_db_session():
    """ テストDBセッションを返す
    """
    return test_scoped_session(
        sessionmaker(
            bind=test_db_connection,
            class_=TestingDBSession,
            expire_on_commit=False,
        )
    )

テスト用のDBセッションミドルウェアを作成

fastapi_sample/tests/middleware.py
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from tests.db_session import get_current_test_db_session


class TestDBSessionMiddleware(BaseHTTPMiddleware):
    """ リクエスト情報にDBセッションを設定するミドルウェア
    """
    async def dispatch(self, request: Request, call_next) -> Response:
        """ ミドルウェアの処理

        Args:
            request (Request): リクエスト情報
            call_next (method): 次の処理

        Returns:
            Response: レスポンス
        """
        request.state.db_session = get_current_test_db_session()
        return await call_next(request)

conftest.pyにテスト用のエントリポイントを定義

tests/conftest.py
from fastapi import FastAPI  # 追記
.
..
...
from api.endpoints.v1 import api_v1_router  # 追記
from core.config import get_env
...
..
.
from middlewares import HttpRequestMiddleware  # 追記
from tests.db_session import test_db_connection
from tests.middleware import TestDBSessionMiddleware  # 追記

# テスト用のエントリポイント
test_app = FastAPI()
test_app.add_middleware(HttpRequestMiddleware)
test_app.add_middleware(TestDBSessionMiddleware)
test_app.include_router(api_v1_router, prefix='/api/v1')

APIのベーステストクラスを作成

touch tests/base.py
fastapi_sample/tests/base.py
from fastapi.testclient import TestClient

from tests.conftest import test_app
from tests.db_session import get_test_db_session, set_current_test_db_session


class BaseTestCase:
    """ ベーステストクラス

    Attributes:
        client (FastAPI): APIクライアント
    """
    client = TestClient(test_app)

    def setup_method(self, method) -> None:
        """ テストケースごとの前処理
        """
        self.db_session = get_test_db_session()
        set_current_test_db_session(self.db_session)

    def teardown_method(self, method) -> None:
        """ テストケースごとの後処理
        """
        self.db_session.test_db_session_remove()  # ロールバック

一覧取得のテストコードを実装してみる

requirements.txt
requests==2.24.0  # 追記したら必ずrequirements.txtをインストール
mkdir -p tests/api/v1
touch tests/api/v1/test_user.py  # "test_"を接頭に付けないとpytest実行時にスルーされてしまうので必ずつける
fastapi_sample/tests/api/v1/test_user.py
import json
from crud.crud_user import CRUDUser
from fastapi import status
from tests.base import BaseTestCase


class TestUserAPI(BaseTestCase):
    """ ユーザーAPIのテストクラス
    """
    TEST_URL = '/api/v1/users/'

    def test_gets(self):
        """ 一覧取得のテスト
        """
        # テストユーザー登録
        test_data = [
            {
                'email': 'test1@example.com',
                'password': 'password',
                'last_name': 'last_name',
                'first_name': 'first_name',
                'is_admin': False
            },
            {
                'email': 'test2@example.com',
                'password': 'password',
                'last_name': 'last_name',
                'first_name': 'first_name',
                'is_admin': True
            },
            {
                'email': 'test3@example.com',
                'password': 'password',
                'last_name': 'last_name',
                'first_name': 'first_name',
                'is_admin': False
            },
        ]
        for data in test_data:
            CRUDUser(self.db_session).create(data)
            self.db_session.commit()

        response = self.client.get(self.TEST_URL)

        # ステータスコードの検証
        assert response.status_code == status.HTTP_200_OK

        # 取得した件数の検証
        response_data = json.loads(response._content)
        assert len(response_data) == len(test_data)

        # レスポンスの内容を検証
        expected_data = [{
            'email': item['email'],
            'last_name': item['last_name'],
            'first_name': item['first_name'],
            'is_admin': item['is_admin'],
        } for i, item in enumerate(test_data, 1)]
        assert response_data == expected_data

    def test_confirm_rollback(self):
        """ 一覧取得のテスト(ロールバックされているか確認する用)
        """
        # テストユーザー登録
        test_data = [
            {
                'email': 'test1@example.com',
                'password': 'password',
                'last_name': 'last_name',
                'first_name': 'first_name',
                'is_admin': False
            },
            {
                'email': 'test2@example.com',
                'password': 'password',
                'last_name': 'last_name',
                'first_name': 'first_name',
                'is_admin': True
            },
            {
                'email': 'test3@example.com',
                'password': 'password',
                'last_name': 'last_name',
                'first_name': 'first_name',
                'is_admin': False
            },
        ]
        for data in test_data:
            CRUDUser(self.db_session).create(data)
            self.db_session.commit()

        response = self.client.get(self.TEST_URL)

        # ステータスコードの検証
        assert response.status_code == status.HTTP_200_OK

        # 取得した件数の検証
        response_data = json.loads(response._content)
        assert len(response_data) == len(test_data)

        # レスポンスの内容を検証
        expected_data = [{
            'email': item['email'],
            'last_name': item['last_name'],
            'first_name': item['first_name'],
            'is_admin': item['is_admin'],
        } for i, item in enumerate(test_data, 1)]
        assert response_data == expected_data

テスト実行

pytest  # Dockerの手順を踏んだ方は、docker-compose exec app pytest
# print()文の出力を確認したい場合は、pytest -v --capture=no
$ docker-compose exec app pytest
===================================================================================== test session starts =====================================================================================
platform linux -- Python 3.8.5, pytest-6.1.0, py-1.9.0, pluggy-0.13.1
rootdir: /fastapi_sample
collected 2 items

tests/api/v1/test_user.py ..                                                                                                                                                            [100%]

====================================================================================== warnings summary =======================================================================================
<string>:2
  <string>:2: SADeprecationWarning: The mapper.order_by parameter is deprecated, and will be removed in a future release. Use Query.order_by() to determine the ordering of a result set.

-- Docs: https://docs.pytest.org/en/stable/warnings.html

テストも無事成功しました。
データベースを確認してみましょう。
API実装時に試しに登録したデータしかありません。
テスト用のDBに切り替えてのテスト実行は成功したようです。

image.png

これで、開発用のDBを汚さず かつ テストケースごとのテストデータも干渉することなく、テストを実行することができるようになりました。

【オマケ①】CORS問題を回避

このままだと、フロントエンドからのAPI呼び出し時にCORSエラー発生してしまいます。
image.png

この問題を回避するため、CORSミドルウェアを実装します。

.env環境変数を扱うクラスを編集

fastapi_sample/.env
ALLOW_HEADERS='["*"]'  # 追加
ALLOW_ORIGINS='["*"]'  # 追加
DEBUG=True
DATABASE_URL=postgresql://postgres:postgres@db:5432/db_fastapi_sample
TEST_DATABASE_URL=postgresql://postgres:postgres@db:5432/test_db_fastapi_sample
fastapi_sample/core/config.py
class Environment(BaseSettings):
    """ 環境変数を読み込むファイル
    """
    allow_headers: list  # 追加
    allow_origins: list  # 追加
    ...
    ..
    .

CORSミドルウェアを実装・エントリポイントに適用

fastapi_sample/middlewares/__init__.py
from core.config import get_env  # 追加
...
..
.
from starlette.middleware import cors  # 追加
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp  # 追加
...
..
.
class CORSMiddleware(cors.CORSMiddleware):
    """ CORS問題を回避するためのミドルウェア
    """
    def __init__(self, app: ASGIApp) -> None:
        super().__init__(
            app,
            allow_origins=get_env().allow_origins,
            allow_methods=cors.ALL_METHODS,
            allow_headers=get_env().allow_headers,
            allow_credentials=True,
        )
...
..
.
fastapi_sample/main.py
...
..
.
from middleware import (
    DBSessionMiddleware,
    CORSMiddleware,  # 追記
    HttpRequestMiddleware
)
...
..
.

# ミドルウェアの設定
# ・ミドルウェアは 後 に追加したものが先に実行される
# ・CORSMiddlewareは必ず一番最後に追加すること
app.add_middleware(HttpRequestMiddleware)
app.add_middleware(DBSessionMiddleware)
app.add_middleware(CORSMiddleware)  # 追加
...
..
.

ポイントはミドルウェアを追加する順番です。
後に書いたミドルウェアが先に実行されるので、CORSミドルウェアは一番最後に書きましょう。
 
 
無事、フロントエンドのCORSエラーは回避することができるようになりました。
(文字列"ログイン成功"を返すだけのAPIを作成してフロントエンドから実行)
image.png

【オマケ②】カスタム例外

Exceptionを継承してカスタム例外を作成してみます。
エラーレスポンスはこんな感じで、エラーコードとエラーメッセージを複数返せるような形式します。

detail: [
  {
    error_code: 'エラーコード',
    error_message: 'エラーメッセージ'
  },
  {
    error_code: 'エラーコード',
    error_message: 'エラーメッセージ'
  },
  { ... },
  { ... },
]

メッセージクラスを作成

mkdir exceptions
touch exception/error_messages.py
fastapi_sample/exceptions/error_messages.py
class BaseMessage:
    """ メッセージクラスのベース
    """
    text: str

    def __str__(self) -> str:
        return self.__class__.__name__


class ErrorMessage:
    """ メッセージクラス
    """
    class INTERNAL_SERVER_ERROR(BaseMessage):
        text = 'システムエラーが発生しました、管理者に問い合わせてください'

    class FAILURE_LOGIN(BaseMessage):
        text = 'ログイン失敗'

カスタム例外を実装するフォルダ作成

touch exceptions/__init__.py
fastapi_sample/exceptions/__init__.py
import traceback

from fastapi import status, HTTPException

from exceptions.error_messages import ErrorMessage


class ApiException(HTTPException):
    """ API例外
    """
    default_status_code = status.HTTP_400_BAD_REQUEST

    def __init__(
        self,
        *errors,
        status_code: int = default_status_code
    ) -> None:
        self.status_code = status_code
        self.detail = [
            {
                'error_code': str(error['error_code']),
                'error_msg': error['error_code'].text.format(*error['msg_params']),
            } for error in list(errors)
        ]
        super().__init__(self.status_code, self.detail)


def create_error(error_code: ErrorMessage, *msg_params) -> dict:
    """ エラー生成

    Examples
    --------
    >>> create_error(messages.INTERNAL_SERVER_ERROR)
    {'error_code': INTERNAL_SERVER_ERROR, 'msg_params': None}

    >>> create_error(messages.E_REGISTRATION, 'ユーザー')
    {'error_code': E_REGISTRATION, 'msg_params': ユーザー}
    """
    return {
        'error_code': error_code(),
        'msg_params': msg_params,
    }

例外ハンドラを実装する

$ touch core/handlers.py
fastapi_sample/core/handlers.py
from fastapi import Request, status
from fastapi.responses import JSONResponse

from exceptions import ApiException
from exceptions.messages import ErrorMessage


async def api_exception_handler(request: Request, exception: ApiException) -> None:
    """ ApiExceptionハンドラ
    """
    return JSONResponse(exception.detail, status_code=exception.status_code)

忘れずにエントリポイントに追加する

main.py
.
..
...
from core.handlers import api_exception_handler
from exceptions import ApiException
...
..
.
# 例外ハンドラの設定
app.add_exception_handler(ApiException, api_exception_handler)
...
..
.

カスタム例外をスローする

fastapi_sample/api/v1/auth.py
from fastapi import Request
from exceptions import (
    ApiException,
    create_error
)
from exceptions.error_messages import ErrorMessage


class AuthAPI:
    """ 認証に関するAPI
    """
    @classmethod
    def login(cls, request: Request):
        """ ログインAPI
        """
        # カスタム例外をスロー
        raise ApiException(create_error(ErrorMessage.FAILURE_LOGIN))
        return 'ログイン成功'

▼フロント側(エラーレスポンスを確認したいだけなのでスルーでOKです)

base-api.js
axios.interceptors.response.use(
  response => {
    return response;
  },
  error => {
    // エラーレスポンスをコンソール表示
    console.log(error.response);
    ...
    ..
    .

カスタム例外がスローされ、フロント側でエラーレスポンスを確認することができました。
image.png

システムエラーについて

アプリが意図的に返す例外は問題ないですが、意図せぬ例外(システムエラー)はどうでしょうか。

fastapi_sample/api/v1/auth.py
from fastapi import Request


class AuthAPI:
    """ 認証に関するAPI
    """
    @classmethod
    def login(cls, request: Request):
        """ ログインAPI
        """
        result = 1 / 0  # 0除算でエラー
        return result

 
結果はCORSエラーが返却された上に、エラーレスポンス(error.response)の中身がundefinedです。
(CORSエラーはバックエンドが返したエラーレスポンスの形式が不正だから・・・??)
image.png

予期せぬエラー用のカスタムExceptionを作成

fastapi_sample/exception/__init__.py
import traceback  # 追加

from fastapi import status

from exception import error_messages


class ApiException(Exception):
    ...
    ..
    .


# 追記
class SystemException(HTTPException):
    """ システム例外
    """
    def __init__(self, e: Exception) -> None:
        self.exc = e
        self.stack_trace = traceback.format_exc()
        self.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        self.detail = [
            {
                'error_code': str(ErrorMessage.INTERNAL_SERVER_ERROR()),
                'error_msg': ErrorMessage.INTERNAL_SERVER_ERROR.text
            }
        ]
        super().__init__(self.status_code, self.detail)

予期せぬエラー用の例外ハンドラを実装

fastapi_sample/core/handlers.py
.
..
...
from exceptions import ApiException, SystemException  # SystemException追加
...
..
.
async def system_exception_handler(request: Request, exception: Exception) -> None:
    """ システム例外ハンドラ
    """
    exception = SystemException(exception)
    return JSONResponse(exception.detail, status_code=exception.status_code)
...
..
.

忘れずにエントリポイントに追加

fastapi_sample/main.py
.
..
...
from core.handlers import api_exception_handler, system_exception_handler  # system_exception_handler追加
...
..
.
# 例外ハンドラの設定
fabeee_one_api.add_exception_handler(ApiException, api_exception_handler)
fabeee_one_api.add_exception_handler(Exception, system_exception_handler)  # 追加

フロント側でシステムエラーのエラーレスポンスを受け取ることができるようになりました。
image.png

ログイン処理の実装

誰でも彼でもAPIを実装できてしまうと困りますよね。
ログイン済みユーザーのみAPIを実行できるようにするために、まずはログイン処理を実装します。

JWT生成関連のユーティリティ実装

認証には「JWT(Json Web Token)」を使用するので、そのユーティリティを実装します。

$ pip3 install "python-jose[cryptography]==3.2.0"  # JWTを扱うためのモジュール
$ mkdir utilities
$ touch utilities/jwt_handler.py

jwt_claims_handler: JWTの元になる情報(クレームセット)を生成するための関数です。
jwt_encode_handler: クレームセットをエンコードしてJWTを生成するための関数です。(クレームセット >>> JWT)
jwt_decord_handler: JWTをデコードしてクレームセットを返す関数です。(クレームセット <<< JWT)

fastapi_sample/utilities/jwt_handler.py
from datetime import datetime, timedelta
from typing import Any, Dict

from jose import jwt  # python-jose

from core.config import get_env
from migrations.models import User

TYPE_ACCESS_TOKEN = 'access_token'
TYPE_REFRESH_TOKEN = 'refresh_token'

PROTECTED_TOKEN_TYPES = (TYPE_ACCESS_TOKEN)


def jwt_claims_handler(user: User, token_type: str = '') -> Dict[str, Any]:
    """ クレームセットを生成

    Args:
        user (User): クレームセット含めるユーザー情報
        token_type (str): トークンタイプ

    Returns:
        Dict[str, Any]: クレームセット

    Raises:
        AssertionError: 不正なトークンタイプが指定された場合
    """
    assert token_type in PROTECTED_TOKEN_TYPES, \
        f'引数token_type には{"".join(PROTECTED_TOKEN_TYPES)}を指定してください'

    claims = {
        'token_type': token_type,
        'user_id': user.id,
    }

    # 「アクセストークン」の有効期限設定
    if claims['token_type'] == TYPE_ACCESS_TOKEN:
        claims['exp'] = datetime.utcnow() + timedelta(seconds=get_env().jwt_access_token_expire)

    return claims


def jwt_encode_handler(claims: dict) -> str:
    """ クレームセットをエンコードしてJWT文字列を返す

    Args:
        claims (dict): クレームセット

    Returns:
        str: JsonWebToken
    """
    return jwt.encode(
        claims,
        get_env().jwt_secret_key,
        get_env().jwt_algorithm
    )


def jwt_decord_handler(jwt_string: str) -> Dict[str, Any]:
    """ JWT文字列をデコードしてクレームセットを返す

    Args:
        jwt_string (str): JWT文字列

    Returns:
        Dict[str, Any]: JWTをデコードして取得したクレームセット
    """
    claims = jwt.decode(
        jwt_string,
        get_env().jwt_secret_key,
        algorithms=get_env().jwt_algorithm,)
    return claims


def jwt_response_handler(access_token: str) -> Dict[str, str]:
    """ JWT文字列を含んだ辞書データを返す

    Args:
        access_token (str): アクセストークン

    Returns:
        Dict[str, str]: JWT認証レスポンス
    """
    return {'token_type': 'bearer', TYPE_ACCESS_TOKEN: access_token}

ログインAPI実装

ログイン用のスキーマ

fastapi_sample/api/schemas/auth.py
# /usr/bin/env python
# -*- coding: utf-8 -*-
"""
authのスキーマ定義
"""

from fastapi.param_functions import Form

from core.config import get_env
from migrations.models import User

MAX_LENGTH_USERNAME = User.username.property.columns[0].type.length
MAX_LENGTH_PASSWORD = User.password.property.columns[0].type.length


class AuthRequestSchema:
    """ 認証に関するスキーマ
    """
    def __init__(
        self,
        username: str = Form(..., max_length=MAX_LENGTH_USERNAME),
        password: str = Form(..., max_length=MAX_LENGTH_PASSWORD)
    ):
        """ 初期処理

        Args:
            username (str):
                ・ユーザー名
                ・必須パラメータ

            password (str):
                ・パスワード
                ・必須パラメータ
        """
        self.username = username
        self.password = password

APIの処理

fastapi_sample/api/v1/auth.py
from typing import Dict

from fastapi import Request

from api.schemas.auth import AuthRequestSchema
from crud.crud_user import CRUDUser
from exceptions import ApiException, create_error
from exceptions.error_messages import ErrorMessage
from migrations.models import User
from utilities.hasher import check_password
from utilities.jwt_handler import (
    jwt_claims_handler,
    jwt_encode_handler,
    jwt_response_handler,
    TYPE_ACCESS_TOKEN,
)


class AuthAPI:
    """ 認証に関するAPI
    """
    @classmethod
    def login(
        cls,
        request: Request,
        schema: AuthRequestSchema
    ) -> Dict[str, str]:
        """ ログインAPI

        Args:
            request (Request): リクエスト情報
            schema (AuthRequestSchema): リクエストボディ

        Returns:
            Dict[str, str]: ユーザー認証結果

        Raises:
            ApiException: メールアドレス または パスワードが未入力の場合
        """
        credentials = {
            'username': schema.username,
            'password': schema.password,
        }

        # メールアドレスとパスワードが入力されている場合、
        # ユーザー認証を実施してアクセストークンとリフレッシュトークンを生成
        if all(credentials.values()):
            user = cls().__authenticate(request, **credentials)
            # アクセストークンのクレームセット取得
            access_token_claims = jwt_claims_handler(
                user,
                token_type=TYPE_ACCESS_TOKEN)

        # メールアドレス または パスワードが未入力の場合はエラー
        else:
            raise ApiException(create_error(ErrorMessage.INVALID_EMAIL_OR_PASSWORD))

        # アクセストークンを返す
        return jwt_response_handler(jwt_encode_handler(access_token_claims))

    def __authenticate(
        self,
        request: Request,
        username: str = None,
        password: str = None
    ) -> User:
        """ ユーザー認証

        Args:
            request (Request): リクエスト情報
            username (str): ユーザー名
            password (str): パスワード

        Returns:
            User: ユーザー情報

        Raises:
            ApiException:
                ・入力されたメールアドレスでユーザーを取得できなかった場合
                ・入力されたパスワードとユーザーのパスワードが一致しなかった場合
                ・入力されたメールアドレスで取得したユーザーが有効でない場合
        """
        user = CRUDUser(request.state.db_session).get_query().filter_by(**{'username': username}).all()

        # ユーザーを取得できなかった場合はエラー
        if not user:
            raise ApiException(create_error(ErrorMessage.FAILURE_LOGIN))

        # パスワードが一致しない もしくは ユーザーが有効でない場合はエラー
        if not check_password(password, user[0].password) or not user[0].is_active:
            raise ApiException(create_error(ErrorMessage.FAILURE_LOGIN))

        return user[0]

ルーター

fastapi_sample/api/endpoints/v1/auth.py
from fastapi import APIRouter, Depends, Request

from api.v1.auth import AuthAPI
from api.schemas.auth import AuthRequestSchema

router = APIRouter()


@router.post('/login/')
async def login(request: Request, schema: AuthRequestSchema = Depends()):
    """ ログイン
    """
    return AuthAPI.login(request, schema)
fastapi_sample/api/endpoints/v1/__init__.py
from fastapi import APIRouter
from api.endpoints.v1 import user, auth  # auth追加

..
..
.
# 追加
api_v1_router.include_router(
    auth.router,
    prefix='/auth',
    tags=['auth'])

ログインしてみる

※ 事前にユーザの登録が必要です。(ここまでのハンズオンでユーザ登録のAPIは実装できているので、そのAPIを実行するだけでよいはず)
image.png
アクセストークンを取得することができました、ログイン成功です👏

認証済みのユーザーのみAPIを実行できるようにする

公式ドキュメント通りに実装していこうと思います。
ただ、エラー時はカスタム例外をスローしたいので「OAuth2PasswordBearerクラス」をラップしたクラスを作り、そちらを利用します。

「OAuth2PasswordBearerクラス」をラップしたクラスを実装

まずは「OAuth2PasswordBearerクラス」をラップしたクラスを実装します。

$ touch utilities/authentication.py
fastapi_sample/utilities/authentication.py
from typing import Optional

from fastapi import security, Request, status
from starlette import authentication

from exceptions import ApiException, create_error
from exceptions.error_messages import ErrorMessage
from migrations.models import User


class OAuth2PasswordBearer(security.OAuth2PasswordBearer):
    """ OAuth2PasswordBearerのラッパー
    """
    async def __call__(self, request: Request) -> Optional[str]:
        """ 呼び出し可能インスタンス
        Args:
            request (Request): リクエスト情報

        Returns:
            Optional[str]: JsonWebToken

        Raises:
            ApiException: ヘッダーに認証情報(Authorization)が含まれていない場合
        """
        authorization: str = request.headers.get('Authorization')
        scheme, param = security.utils.get_authorization_scheme_param(authorization)
        if not authorization or scheme.lower() != 'bearer':
            if self.auto_error:
                raise ApiException(create_error(ErrorMessage.INVALID_TOKEN), status_code=status.HTTP_401_UNAUTHORIZED)
            else:
                return None
        return param


class AuthenticatedUser(authentication.SimpleUser):
    """ 認証済みユーザー
    """
    def __init__(self, user: User) -> None:
        self.id = user.id
        self.username = user.username


class UnauthenticatedUser(authentication.UnauthenticatedUser):
    """ 未認証ユーザー
    """
    pass

認証ミドルウェアを作成する

次は認証のミドルウェアを実装します。

fastapi_sample/middlewares/__init__.py
...
..
.
from fastapi.security.utils import get_authorization_scheme_param  # 追加
from jose import jwt  # 追加
from starlette.middleware import authentication, cors  # authentication追加
...
..
.
from crud.crud_user import CRUDUser  # 追加
...
..
.
from exceptions.error_messages import ErrorMessage  # 追加
from utilities.authentication import AuthenticatedUser, UnauthenticatedUser  # 追加
from utilities.jwt_handler import jwt_decord_handler  # 追加
...
..
.
class AuthenticationBackend(authentication.AuthenticationBackend):
    """ 認証ミドルウェアのバックエンド

    このミドルウェアを認証バックエンドとして使用することで、リクエストのユーザー情報に「request.user」でアクセス可能になる
    """
    async def authenticate(self, request: Request) -> None:
        """ 認証処理

        Args:
            request (Request): リクエスト情報
        """
        authorization: str = request.headers.get('Authorization')
        scheme, access_token = get_authorization_scheme_param(authorization)

        # リクエストヘッダに認証情報が無い場合は「未認証ユーザー」を返す
        if not authorization or scheme.lower() != 'bearer':
            return authentication.AuthCredentials(['unauthenticated']), UnauthenticatedUser()

        # JWTをデコードしてクレームセットを取得
        try:
            claims = jwt_decord_handler(access_token)

        # アクセストークン期限切れ
        except jwt.ExpiredSignatureError:
            raise ApiException(create_error(ErrorMessage.EXPIRED_TOKEN), status_code=status.HTTP_401_UNAUTHORIZED)

        # その他エラーの場合は「未認証ユーザー」を返す
        except Exception as e:
            print(e)
            return authentication.AuthCredentials(['unauthenticated']), UnauthenticatedUser()

        # クレームセットのユーザーIDでユーザーを取得
        user = CRUDUser(request.state.db_session).get_by_id(claims['user_id'])

        # 下記いずれかの場合はエラー
        # ・ユーザーを取得できなかった場合
        # ・ユーザーを取得できたが、非アクティブ
        if not user or not user.is_active:
            raise ApiException(create_error(ErrorMessage.INVALID_TOKEN))
        return authentication.AuthCredentials(['authenticated']), AuthenticatedUser(user)

エントリポイントにも追加する

fastapi_sample/main.py
.
..
...
from starlette.middleware.authentication import AuthenticationMiddleware  # 追加
...
..
.
from middlewares import (
    AuthenticationBackend,  # 追加
    DBSessionMiddleware,
    CORSMiddleware,
    HttpRequestMiddleware
)
...
..
.
# ミドルウェアの設定
app.add_middleware(AuthenticationMiddleware, backend=AuthenticationBackend())  # 追加(HttpRequestMiddlewareより前に追加)
app.add_middleware(HttpRequestMiddleware)
app.add_middleware(DBSessionMiddleware)
app.add_middleware(CORSMiddleware)
...
..
.

※ tests/conftest.pyのテスト用のエントリポイントにも認証ミドルウェアを追加すること

依存関係を作成する

$ mkdir dependencies
$ touch dependencies/__init__.py
fastapi_sample/dependencies/__init__.py
from fastapi import Depends, Request, status

from exceptions import ApiException, create_error
from exceptions.error_messages import ErrorMessage
from utilities.authentication import OAuth2PasswordBearer

OAUTH2_SCHEMA = OAuth2PasswordBearer(tokenUrl='/api/v1/auth/login')


async def login_required(
    request: Request,
    token: str = Depends(OAUTH2_SCHEMA)
) -> None:
    """ ユーザがログインしているかどうか

    Args:
        request (Request): リクエスト情報
        token (str): アクセストークン

    Raises:
        ApiException: ログインに失敗している場合
    """
    if not request.user.is_authenticated:
        raise ApiException((create_error(ErrorMessage.INVALID_TOKEN)), status_code=status.HTTP_401_UNAUTHORIZED)

実行ユーザーを制限したいAPIルーターに依存関係をつける

fastapi_sample/api/endpoints/v1/user.py
from typing import List

from fastapi import APIRouter, Depends, Request  # Depends追加
...
..
.
from dependencies import login_required  # 追加

router = APIRouter()


@router.get('/', response_model=List[UserInDB], dependencies=[Depends(login_required)])  # dependencies=...追加
async def gets(request: Request) -> List[User]:
    ...
    ..
    .


@router.post('/', response_model=UserInDB, dependencies=[Depends(login_required)])  # dependencies=...追加
async def create(request: Request, schema: CreateUser) -> User:
    ...
    ..
    .


@router.put('/{id}/', response_model=UserInDB, dependencies=[Depends(login_required)])  # dependencies=...追加
async def update(request: Request, id: int, schema: UpdateUser) -> User:
    ...
    ..
    .


@router.delete('/{id}/', dependencies=[Depends(login_required)])  # dependencies=...追加
async def delete(request: Request, id: int) -> None:
    ...
    ..
    .

APIドキュメントの対象のエンドポイントに「鍵マーク」がつくようになります。これで成功です。
image.png

ログインせずに実行するともちろん401エラーが返ってきます。
image.png

ログインして実行してみる。

APIドキュメントの「Authorize」ボタンか各エンドポイントの鍵マークを押して、ログイン用のモーダルを表示し、
usernameとpasswordを入力してログインしてください。
image.png
image.png

モーダルはこれに変われば成功です。

ログインユーザーだけにAPIの実行を許可できるようになりました。

image.png

終わり

FastAPIでCRUDを実装してみました。
フルスタックフレームワークのDjango(RestはDRF)ばかり使っていたせいで、初め「なんだこの使いづらいフレームワークは・・・」とか思っていましたが、今ではもうDjangoよりFastAPI派になってしまいました。
ガシガシ環境周りのコードを自分で実装できて楽しいですし、何より成長に繋がりました。
 
公式ドキュメントを読みきれていないので、他にもいい方法はあるかと思いますので、
コメントでこっそり教えてもらえると嬉しいです( ´ノω`)
 
私が所属しているFabeee株式会社はお仕事 と 一緒に働くお仲間を随時募集しております!!!!(宣伝)

話を聞いてみたい方はこちら
SES/受託開発のご依頼についてはこちら

41
32
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
41
32

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?