FastAPIは、pythonのWEB APIを作成するための非同期のframeworkで容易に高性能なAPIを作成できるということです。
Tutorialなどを参考に試してみます。
0.前提条件
-
Windows10(WSL2)にUbuntu 20.04がインストールされているものとします。
-
Python, pipはインストール済であるものとします。
-
環境は以下のとおり
category | value |
---|---|
CPU | core i5 10210U 160GHz |
MEMORY | 8GB |
OS | Ubuntu 20.04.1 LTS(WSL2, Windows 10 home 1909) |
Python | 3.8.2 |
database | PostgreSQL 12.4 |
1. CRUD APIの作成
(1). 作成するアプリのファイル構成
usersテーブルのCRUDを作成します。ファイルは6つほど作成しますが、それぞれ少ないコード量となっています。
fastapi-crud-example
│ db.py
│ main.py
│
├─users
│ │ models.py
│ │ router.py
│ └─ schemas.py
│
└─utils
dbutils.py
(2). source
db.py
sqlalchemyは、modelの定義とクエリの生成で使い、
databaseへのアクセスはDatabasesを使います。
import databases
import sqlalchemy
DATABASE = 'postgresql'
USER = 'testuser'
PASSWORD = 'secret'
HOST = 'localhost'
PORT = '5432'
DB_NAME = 'testdb'
DATABASE_URL = '{}://{}:{}@{}:{}/{}'.format(DATABASE, USER, PASSWORD, HOST, PORT, DB_NAME)
# databases
database = databases.Database(DATABASE_URL, min_size=5, max_size=20)
ECHO_LOG = False
engine = sqlalchemy.create_engine(DATABASE_URL, echo=ECHO_LOG)
metadata = sqlalchemy.MetaData()
users/models.py
import sqlalchemy
from db import metadata, engine
users = sqlalchemy.Table(
"users",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, index=True),
sqlalchemy.Column("username", sqlalchemy.String, index=True),
sqlalchemy.Column("email", sqlalchemy.String, index=True),
sqlalchemy.Column("hashed_password", sqlalchemy.String),
sqlalchemy.Column("is_active", sqlalchemy.Boolean(), default=True),
sqlalchemy.Column("is_superuser", sqlalchemy.Boolean(), default=False)
)
metadata.create_all(bind=engine)
users/schemas.py
pydanticのmodelを使って、crud用もmodelを定義します。
from pydantic import BaseModel
# insert用のrequest model。id(自動採番)は入力不要のため定義しない。
class UserCreate(BaseModel):
username: str
email: str
password: str
is_active: bool
is_superuser: bool
# update用のrequest model
class UserUpdate(BaseModel):
id : int
username: str
email: str
password: str
is_active: bool
is_superuser: bool
# select用のrequest model。selectでは、パスワード不要のため定義しない。
class UserSelect(BaseModel):
username: str
email: str
is_active: bool
is_superuser: bool
users/router.py
crudの主要部になります。select, insert, update, deleteのいずれも、10行足らずで実装しています。コーディング量が少ないのはとてもいいですね。
import hashlib
from fastapi import APIRouter, Depends
from typing import List
from starlette.requests import Request
from .models import users
from .schemas import UserCreate, UserUpdate, UserSelect
from databases import Database
from utils.dbutils import get_connection
router = APIRouter()
# 入力したパスワード(平文)をハッシュ化して返します。
def get_users_insert_dict(user):
pwhash=hashlib.sha256(user.password.encode('utf-8')).hexdigest()
values=user.dict()
values.pop("password")
values["hashed_password"]=pwhash
return values
# usersを全件検索して「UserSelect」のリストをjsonにして返します。
@router.get("/users/", response_model=List[UserSelect])
async def users_findall(request: Request, database: Database = Depends(get_connection)):
query = users.select()
return await database.fetch_all(query)
# usersをidで検索して「UserSelect」をjsonにして返します。
@router.get("/users/find", response_model=UserSelect)
async def users_findone(id: int, database: Database = Depends(get_connection)):
query = users.select().where(users.columns.id==id)
return await database.fetch_one(query)
# usersを新規登録します。
@router.post("/users/create", response_model=UserSelect)
async def users_create(user: UserCreate, database: Database = Depends(get_connection)):
# validatorは省略
query = users.insert()
values = get_users_insert_dict(user)
ret = await database.execute(query, values)
return {**user.dict()}
# usersを更新します。
@router.post("/users/update", response_model=UserSelect)
async def users_update(user: UserUpdate, database: Database = Depends(get_connection)):
# validatorは省略
query = users.update().where(users.columns.id==user.id)
values=get_users_insert_dict(user)
ret = await database.execute(query, values)
return {**user.dict()}
# usersを削除します。
@router.post("/users/delete")
async def users_delete(user: UserUpdate, database: Database = Depends(get_connection)):
query = users.delete().where(users.columns.id==user.id)
ret = await database.execute(query)
return {"result": "delete success"}
utils/dbutils.py
from starlette.requests import Request
# middlewareでrequestに格納したconnection(Databaseオブジェクト)を返します。
def get_connection(request: Request):
return request.state.connection
main.py
from fastapi import FastAPI
from db import database
from users.router import router as userrouter
from starlette.requests import Request
app = FastAPI()
# 起動時にDatabaseに接続する。
@app.on_event("startup")
async def startup():
await database.connect()
# 終了時にDatabaseを切断する。
@app.on_event("shutdown")
async def shutdown():
await database.disconnect()
# users routerを登録する。
app.include_router(userrouter)
# middleware state.connectionにdatabaseオブジェクトをセットする。
@app.middleware("http")
async def db_session_middleware(request: Request, call_next):
request.state.connection = database
response = await call_next(request)
return response
(3). 起動と確認
powershellから「uvicorn main:app --reload」を入力してenter
uvicorn main:app --reload
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO: Started reloader process [14108]
INFO: Started server process [21068]
INFO: Waiting for application startup.
INFO: Connected to database postgresql://testuser:********@localhost:5432/fastapidb
(4). Swaggerで確認
Swagger UIでapiにをテストできます。便利ですね。
chromeでhttp://127.0.0.1:8000/docsにアクセスしてみます。
(5). 最後に 感想
FastAPIはapiに特化したfreameworkという印象を受けましたが、jinja2などを使ったtemplate engineなども使えるし、oauth2などの認証機能も備わっています。FastAPIいい感じです。
2020.10.18 追加
2. さらに高速にしたい
(1) gunicornを使う
運用では、gunicornを使用することで、workerやthreadを調整して動かすことができます。
※Unix系の環境用なのでWindows環境では使えません。
pip install gunicorn
gunicorn --workers=4 --threads=2 -k uvicorn.workers.UvicornWorker main:app --log-level warning
(2) json responseではorjsonを使う
orjsonを使うことで、より高速にjsonを処理できます。
pip install orjson
from fastapi.responses import ORJSONResponse
@router.get("/hello")
def resp_orjon():
return ORJSONResponse({"Hello": "World"})
(3) データベースに「asyncpg」を利用する
ormにこだわらなければ、asyncpgを使うことでより高速になります。
※ asyncpgはPostgreSQLのframeworkです。
from fastapi import FastAPI
from fastapi.responses import ORJSONResponse
import asyncpg
import settings
db_pool = None
app = FastAPI()
@app.on_event("startup")
async def startup_event():
global db_pool
db_pool = await asyncpg.create_pool(
user=settings.db_user,
password=settings.db_password,
database=settings.db_database_name,
host=settings.db_host,
port=settings.db_port,
min_size=settings.db_min_size,
max_size=settings.db_max_size
)
@app.on_event("shutdown")
async def shutdown_event():
db_pool.terminate()
PERSON_QUERY="""
select id, name, age from public.person
"""
@app.get('/persons')
async def personlist():
async with db_pool.acquire() as connection:
async with connection.transaction():
data = [dict(record) async for record in connection.cursor(PERSON_QUERY)]
return ORJSONResponse(data)
3. asyncpg Tips
asyncpgには、テーブルやクエリをCSVに出力する機能が使えます。
(1) copy_from_table
copy_from_tableを使えば、テーブルを直接csvファイルに出力できます。
async with db_pool.acquire() as connection:
async with connection.transaction():
await connection.copy_from_table("person", output='person.csv', format='csv')
return FileResponse('person.csv', filename='ユーザ.csv')
(2) copy_from_query
copy_from_queryも同様、テーブルを直接csvファイルに出力できます。
async with db_pool.acquire() as connection:
async with connection.transaction():
await connection.copy_from_query("select * from person where id=$1", 1 , output='query.csv', format='csv')
return FileResponse('person.csv', filename='ユーザ.csv')
4. APIKeyCookieを使った認証
APIKeyCookieを使った認証です。
①フォルダ/ファイル構成
├── authentication.py
├── main.py
└── settings.py
②authentication.py
from fastapi import HTTPException, Depends
from fastapi.security import APIKeyCookie
from starlette import status
from jose import jwt
from passlib.context import CryptContext
from settings import secret_key
cookie_security = APIKeyCookie(name="session")
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
users = {}
"""
本番ではDBから取得して設定するなど、環境に応じて編集してください。
"""
async def get_db_user(username):
if username in users:
return users[username]
return None
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
async def authenticate(username, password):
errors=[]
token=None
user = await get_db_user(username)
if user:
db_password = user["password"]
if verify_password(password, db_password):
token = jwt.encode({"username": username, 'role': user["role"]}, secret_key)
else:
errors.append('username or password invalid')
else:
errors.append('username or password invalid')
return errors, token
def get_current_user(session: str = Depends(cookie_security)):
try:
payload = jwt.decode(session, secret_key)
print(payload)
return payload["username"]
except Exception:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Invalid authentication"
)
def get_password_hash(password):
return pwd_context.hash(password)
"""
for test
"""
def set_fake_users():
# username=admin, password=secret
# username=guest, password=secret
password = get_password_hash('secret')
users["admin"] = {"password": password, 'role': 'admin'}
password = get_password_hash('secret')
users["guest"] = {"password": password, 'role': 'guest'}
print(users)
set_fake_users()
③main.py
from fastapi import FastAPI, Form, HTTPException, Depends
from fastapi.responses import HTMLResponse, RedirectResponse
from starlette import status
from authentication import authenticate, get_current_user
app = FastAPI()
@app.get("/")
async def root_page():
return HTMLResponse(
"""
<h1>Please Login</h1>
<a href='/login'>Login</a>
"""
)
@app.get("/login")
async def login_page():
return HTMLResponse(
"""
<h1>Login</h1>
<form action="/login" method="post">
Username: <input type="text" name="username" required>
<br>
Password: <input type="password" name="password" required>
<br>
<input type="submit" value="Login">
</form>
"""
)
@app.post("/login")
async def login(username: str = Form(...), password: str = Form(...)):
err, token = await authenticate(username, password)
if not token:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Invalid user or password"
)
response = HTMLResponse("""
<h1>Home</h1>
<p>welcome to mysite</p>
<a href='/private'>private</a>
""")
response.set_cookie("session", token)
return response
@app.get("/private")
async def read_private(username: str = Depends(get_current_user)):
return HTMLResponse(f"""
<h1>welcome to secret page</h1>
<p>your name: {username}</p>
<form action='/logout' method='get'>
<button type='submit'>logout</button>
</form>
""")
@app.get("/logout")
async def read_private():
response = HTMLResponse("""
<h1>Logout</h1>
<a href='/login'>Login</a>
""")
response.set_cookie("session", '')
return response
2021.01.24 追加
#5. CSRF Protection
starlette-wtfを使ってCSRF Protectionを実装する例です。
①install
pip install starlette-wtf
②Directories and Files
├── main.py
└── templates
└── index.html
③main.py
secret_key, csrf_secretには、安全なtoken文字列を設定してください。
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette_wtf import StarletteForm, CSRFProtectMiddleware, csrf_protect
from wtforms import StringField
from wtforms.validators import DataRequired
class MyForm(StarletteForm):
name = StringField('name', validators=[DataRequired()])
templates = Jinja2Templates(directory="templates")
app = FastAPI(middleware=[
Middleware(SessionMiddleware, secret_key='secret_key_token'),
Middleware(CSRFProtectMiddleware, csrf_secret='csrf_secret_token')
])
@app.route('/', methods=['GET', 'POST'])
@csrf_protect
async def index(request:Request):
form = await MyForm.from_formdata(request)
print(type(form.name), form.name)
if await form.validate_on_submit():
return HTMLResponse(f"<h1>{form.name.data}</h1>")
return templates.TemplateResponse("index.html", {"request": request, "form": form})
確認
実行して、表示された画面で「Submit」ボタンを押すと、入力した値が表示される。
protectionの挙動については、以下のように、curlを実行すると 403 forbiddenとなり、以下のエラーを取得できる。
# curl -X POST -H "Content-Type: application/json" -d '{"name":"aaaaa"}' localhost:8000/
{"detail":"The CSRF token is missing."}