84
Help us understand the problem. What are the problem?

More than 1 year has passed since last update.

posted at

updated at

FastAPIを使ってCRUD APIを作成する

FastAPIは、pythonのWEB APIを作成するための非同期のframeworkで容易に高性能なAPIを作成できるということです。

Web Framework Benchmarks

Tutorialなどを参考に試してみます。

0.前提条件

  • Windows10(WSL2)にUbuntu 20.04がインストールされているものとします。
  • Python, pipはインストール済であるものとします。

  • 環境は以下のとおり

category value
CPU core i5 10210U 160GHz
MEMORY 8GB
OS Ubuntu 20.04.1 LTS(WSL2, Windows 10 home 1909)
Python 3.8.2
database PostgreSQL 12.4

1. CRUD APIの作成

(1). 作成するアプリのファイル構成

usersテーブルのCRUDを作成します。ファイルは6つほど作成しますが、それぞれ少ないコード量となっています。

fastapi-crud-example
│  db.py
│  main.py
│  
├─users
│  │  models.py
│  │  router.py
│  └─ schemas.py
│          
└─utils
      dbutils.py

(2). source

db.py

sqlalchemyは、modelの定義とクエリの生成で使い、
databaseへのアクセスはDatabasesを使います。

db.py
import databases
import sqlalchemy

DATABASE = 'postgresql'
USER = 'testuser'
PASSWORD = 'secret'
HOST = 'localhost'
PORT = '5432'
DB_NAME = 'testdb'

DATABASE_URL = '{}://{}:{}@{}:{}/{}'.format(DATABASE, USER, PASSWORD, HOST, PORT, DB_NAME)

# databases
database = databases.Database(DATABASE_URL, min_size=5, max_size=20)

ECHO_LOG = False

engine = sqlalchemy.create_engine(DATABASE_URL, echo=ECHO_LOG)

metadata = sqlalchemy.MetaData()

users/models.py

models.py
import sqlalchemy
from db import metadata, engine

users = sqlalchemy.Table(
    "users",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, index=True),
    sqlalchemy.Column("username", sqlalchemy.String, index=True),
    sqlalchemy.Column("email", sqlalchemy.String, index=True),
    sqlalchemy.Column("hashed_password", sqlalchemy.String),
    sqlalchemy.Column("is_active", sqlalchemy.Boolean(), default=True),
    sqlalchemy.Column("is_superuser", sqlalchemy.Boolean(), default=False)
)

metadata.create_all(bind=engine)

users/schemas.py

pydanticのmodelを使って、crud用もmodelを定義します。

schemas.py
from pydantic import BaseModel

# insert用のrequest model。id(自動採番)は入力不要のため定義しない。
class UserCreate(BaseModel):
    username: str
    email: str
    password: str
    is_active: bool
    is_superuser: bool

# update用のrequest model
class UserUpdate(BaseModel):
    id : int
    username: str
    email: str
    password: str
    is_active: bool
    is_superuser: bool

# select用のrequest model。selectでは、パスワード不要のため定義しない。
class UserSelect(BaseModel):
    username: str
    email: str
    is_active: bool
    is_superuser: bool

users/router.py

crudの主要部になります。select, insert, update, deleteのいずれも、10行足らずで実装しています。コーディング量が少ないのはとてもいいですね。

router.py
import hashlib

from fastapi import APIRouter, Depends
from typing import List
from starlette.requests import Request

from .models import users
from .schemas import UserCreate, UserUpdate, UserSelect

from databases import Database

from utils.dbutils import get_connection

router = APIRouter()

# 入力したパスワード(平文)をハッシュ化して返します。
def get_users_insert_dict(user):
    pwhash=hashlib.sha256(user.password.encode('utf-8')).hexdigest()
    values=user.dict()
    values.pop("password")
    values["hashed_password"]=pwhash
    return values

# usersを全件検索して「UserSelect」のリストをjsonにして返します。
@router.get("/users/", response_model=List[UserSelect])
async def users_findall(request: Request, database: Database = Depends(get_connection)):
    query = users.select()
    return await database.fetch_all(query)

# usersをidで検索して「UserSelect」をjsonにして返します。
@router.get("/users/find", response_model=UserSelect)
async def users_findone(id: int, database: Database = Depends(get_connection)):
    query = users.select().where(users.columns.id==id)
    return await database.fetch_one(query)

# usersを新規登録します。
@router.post("/users/create", response_model=UserSelect)
async def users_create(user: UserCreate, database: Database = Depends(get_connection)):
    # validatorは省略
    query = users.insert()
    values = get_users_insert_dict(user)
    ret = await database.execute(query, values)
    return {**user.dict()}

# usersを更新します。
@router.post("/users/update", response_model=UserSelect)
async def users_update(user: UserUpdate, database: Database = Depends(get_connection)):
    # validatorは省略
    query = users.update().where(users.columns.id==user.id)
    values=get_users_insert_dict(user)
    ret = await database.execute(query, values)
    return {**user.dict()}

# usersを削除します。
@router.post("/users/delete")
async def users_delete(user: UserUpdate, database: Database = Depends(get_connection)):
    query = users.delete().where(users.columns.id==user.id)
    ret = await database.execute(query)
    return {"result": "delete success"}

utils/dbutils.py

dbutils.py
from starlette.requests import Request

# middlewareでrequestに格納したconnection(Databaseオブジェクト)を返します。
def get_connection(request: Request):
    return request.state.connection

main.py

main.py
from fastapi import FastAPI
from db import database
from users.router import router as userrouter
from starlette.requests import Request

app = FastAPI()

# 起動時にDatabaseに接続する。
@app.on_event("startup")
async def startup():
    await database.connect()

# 終了時にDatabaseを切断する。
@app.on_event("shutdown")
async def shutdown():
    await database.disconnect()

# users routerを登録する。
app.include_router(userrouter)

# middleware state.connectionにdatabaseオブジェクトをセットする。
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
    request.state.connection = database
    response = await call_next(request)
    return response

(3). 起動と確認

powershellから「uvicorn main:app --reload」を入力してenter

uvicorn main:app --reload
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO: Started reloader process [14108]
INFO: Started server process [21068]
INFO: Waiting for application startup.
INFO: Connected to database postgresql://testuser:********@localhost:5432/fastapidb

(4). Swaggerで確認

Swagger UIでapiにをテストできます。便利ですね。

chromeでhttp://127.0.0.1:8000/docsにアクセスしてみます。

image.png

(5). 最後に 感想

FastAPIはapiに特化したfreameworkという印象を受けましたが、jinja2などを使ったtemplate engineなども使えるし、oauth2などの認証機能も備わっています。FastAPIいい感じです。

2020.10.18 追加

2. さらに高速にしたい

(1) gunicornを使う

運用では、gunicornを使用することで、workerやthreadを調整して動かすことができます。
※Unix系の環境用なのでWindows環境では使えません。

console
pip install gunicorn
console
gunicorn --workers=4 --threads=2  -k uvicorn.workers.UvicornWorker main:app --log-level warning

(2) json responseではorjsonを使う

orjsonを使うことで、より高速にjsonを処理できます。

pip install orjson
hello.py
from fastapi.responses import ORJSONResponse
@router.get("/hello")
def resp_orjon():
    return ORJSONResponse({"Hello": "World"})

(3) データベースに「asyncpg」を利用する

ormにこだわらなければ、asyncpgを使うことでより高速になります。
※ asyncpgはPostgreSQLのframeworkです。

example.py
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse
import asyncpg
import settings

db_pool = None

app = FastAPI()

@app.on_event("startup")
async def startup_event():
    global db_pool
    db_pool = await asyncpg.create_pool(
        user=settings.db_user,
        password=settings.db_password,
        database=settings.db_database_name,
        host=settings.db_host,
        port=settings.db_port,
        min_size=settings.db_min_size,
        max_size=settings.db_max_size
    )

@app.on_event("shutdown")
async def shutdown_event():
    db_pool.terminate()

PERSON_QUERY="""
select id, name, age from public.person
"""

@app.get('/persons')
async def personlist():
    async with db_pool.acquire() as connection:
        async with connection.transaction():
            data = [dict(record) async for record in connection.cursor(PERSON_QUERY)]
            return ORJSONResponse(data)

3. asyncpg Tips

asyncpgには、テーブルやクエリをCSVに出力する機能が使えます。

(1) copy_from_table

copy_from_tableを使えば、テーブルを直接csvファイルに出力できます。

db.py
    async with db_pool.acquire() as connection:
        async with connection.transaction():
            await connection.copy_from_table("person", output='person.csv', format='csv')
            return FileResponse('person.csv', filename='ユーザ.csv')

(2) copy_from_query

copy_from_queryも同様、テーブルを直接csvファイルに出力できます。

db.py
    async with db_pool.acquire() as connection:
        async with connection.transaction():
            await connection.copy_from_query("select * from person where id=$1", 1 , output='query.csv', format='csv')
            return FileResponse('person.csv', filename='ユーザ.csv')

4. APIKeyCookieを使った認証

APIKeyCookieを使った認証です。

①フォルダ/ファイル構成

├── authentication.py
├── main.py
└── settings.py

②authentication.py

authentication.py
from fastapi import HTTPException, Depends
from fastapi.security import APIKeyCookie
from starlette import status
from jose import jwt
from passlib.context import CryptContext
from settings import secret_key

cookie_security = APIKeyCookie(name="session")

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

users = {}

"""
本番ではDBから取得して設定するなど、環境に応じて編集してください。
"""
async def get_db_user(username):
    if username in users:
        return users[username]
    return None

def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

async def authenticate(username, password):
    errors=[]
    token=None
    user = await get_db_user(username)
    if user:
        db_password = user["password"]
        if verify_password(password, db_password):
            token = jwt.encode({"username": username, 'role': user["role"]}, secret_key)
        else:
            errors.append('username or password invalid')
    else:
        errors.append('username or password invalid')
    return errors, token

def get_current_user(session: str = Depends(cookie_security)):
    try:
        payload = jwt.decode(session, secret_key)
        print(payload)
        return payload["username"]
    except Exception:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Invalid authentication"
        )

def get_password_hash(password):
    return pwd_context.hash(password)

"""
for test
"""
def set_fake_users():
    # username=admin, password=secret
    # username=guest, password=secret
    password = get_password_hash('secret')
    users["admin"] = {"password": password, 'role': 'admin'}
    password = get_password_hash('secret')
    users["guest"] = {"password": password, 'role': 'guest'}
    print(users)

set_fake_users()

③main.py

main.py
from fastapi import FastAPI, Form, HTTPException, Depends
from fastapi.responses import HTMLResponse, RedirectResponse
from starlette import status
from authentication import authenticate, get_current_user

app = FastAPI()

@app.get("/")
async def root_page():
    return HTMLResponse(
        """
        <h1>Please Login</h1>
        <a href='/login'>Login</a>
        """
    )

@app.get("/login")
async def login_page():
    return HTMLResponse(
        """
        <h1>Login</h1>
        <form action="/login" method="post">
        Username: <input type="text" name="username" required>
        <br>
        Password: <input type="password" name="password" required>
        <br>
        <input type="submit" value="Login">
        </form>
        """
    )

@app.post("/login")
async def login(username: str = Form(...), password: str = Form(...)):
    err, token = await authenticate(username, password)
    if not token:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Invalid user or password"
        )
    response = HTMLResponse("""
        <h1>Home</h1>
        <p>welcome to mysite</p>
        <a href='/private'>private</a>
        """)
    response.set_cookie("session", token)
    return response

@app.get("/private")
async def read_private(username: str = Depends(get_current_user)):
    return HTMLResponse(f"""
        <h1>welcome to secret page</h1>
        <p>your name: {username}</p>
        <form action='/logout' method='get'>
            <button type='submit'>logout</button>
        </form>
        """)

@app.get("/logout")
async def read_private():
    response = HTMLResponse("""
        <h1>Logout</h1>
        <a href='/login'>Login</a>
    """)
    response.set_cookie("session", '')
    return response

2021.01.24 追加

5. CSRF Protection

starlette-wtfを使ってCSRF Protectionを実装する例です。

①install

bash
pip install starlette-wtf

②Directories and Files

Directories,Files
├── main.py
└── templates
    └── index.html

③main.py

secret_key, csrf_secretには、安全なtoken文字列を設定してください。

main.py
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette_wtf import StarletteForm, CSRFProtectMiddleware, csrf_protect
from wtforms import StringField
from wtforms.validators import DataRequired

class MyForm(StarletteForm):
    name = StringField('name', validators=[DataRequired()])

templates = Jinja2Templates(directory="templates")

app = FastAPI(middleware=[
    Middleware(SessionMiddleware, secret_key='secret_key_token'),
    Middleware(CSRFProtectMiddleware, csrf_secret='csrf_secret_token')
])

@app.route('/', methods=['GET', 'POST'])
@csrf_protect
async def index(request:Request):
    form = await MyForm.from_formdata(request)
    print(type(form.name), form.name)
    if await form.validate_on_submit():
        return HTMLResponse(f"<h1>{form.name.data}</h1>")

    return templates.TemplateResponse("index.html", {"request": request, "form": form})

確認

実行して、表示された画面で「Submit」ボタンを押すと、入力した値が表示される。
protectionの挙動については、以下のように、curlを実行すると 403 forbiddenとなり、以下のエラーを取得できる。

# curl -X POST -H "Content-Type: application/json" -d '{"name":"aaaaa"}' localhost:8000/
{"detail":"The CSRF token is missing."}
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Sign upLogin
84
Help us understand the problem. What are the problem?