Servant でアクセス制御のWeb DSLを作ってみる

More than 1 year has passed since last update.

この記事は CAMPHOR- Advent Calendar 2015 14日目の記事です

3期OBの@lotz84_です!
CAMPHOR-ではHaskell人材の育成担当をしています :stuck_out_tongue_closed_eyes:

Haskellって?

知らない人のために簡単な解説をしておくと、Haskellは純粋関数型のプログラミング言語で、強力な型と柔軟に扱える関数が特徴です。昔から存在する言語であり実用的にも使われていますがPHPやJavaなどの他のプログラミング言語の影に隠れてそれほど知名度が高いわけでは無いと思います。しかし最近 Immutable や Stateless などをキーワードに関数型のプログラミング言語が再び注目されるようになってきています。

実際のHaskellのコードを見てみましょう。

quicksort :: Ord a => [a] -> [a]  
quicksort [] = []  
quicksort (x:xs) = quicksort [a | a <- xs, a <= x]
                     ++ [x]
                       ++ quicksort [a | a <- xs, a > x]

これはクイックソートを実装したものです1。1行目は型の宣言になっていて、quicksortは順序Ordが定まっている型aのリスト[a]を引数にして[a]を返り値に持つ関数です。2~5行目は関数の実装になっていて、2行目は空リスト[]が来た時は空リスト[]を返します。3~5行目が1番大事な部分で、まず引数のリストから先頭の要素を取り出し(x:xs)、残りのリストから先頭の要素よりも小さい要素を取り出したリストを再びソートしquicksort [a | a <- xs, a <= x]、その後に先頭の要素を挿入して++ [x]、残りのリストから先頭の要素よりも大きい要素を取り出したリストを再びソートしたものをくっつける++ quicksort [a | a <- xs, a > x]というクイックソートの定義をそのままの実装になっています。

やりたいことが短く綺麗に実装できるのがHaskellのいいところです。さらに強力な型がプログラムの安全性を保証してくれます。まさにいいことずくめの言語です。Haskellを使ってる企業の例としてはFacebookPrezi、日本だとGree, NTTデータ, Tsuru Capitalなどが有名です。

もしHaskellが気になってきたら以下の文献から読んでみてください。

Mac で Homebrew を使っているなら今すぐ

$ brew install haskell-stack
$ stack setup
$ alias ghc="stack ghc"       # コンパイラ
$ alias ghci="stack ghci"     # REPL
$ alias runghc="stack runghc" # インタプリタ

で環境を整えましょう!

アクセス制御のServant DSL

さて本題です。先日Haskell Advent Calendar 2015【型レベルWeb DSL】 Servantの紹介という記事を書きました。Servantは型レベルでWebアプリの振る舞いを記述するDSLでWebアプリを作るのをとても簡単にしてくれます。この記事を書いてる時、同じくCAMPHOR-の@shotarokと話しててアクセス制限のあるAPIを型で表現できないだろうかという話題になりました。例えばこのユーザーは読み込み権限があるからアクセス出来るけど別のユーザーは権限がないのでアクセス出来ないといったようなことです。Servant DSLは自分で作るのも簡単そうだったので早速挑戦してみました。レッツ 型レベルプログラミング!

作るAPIはこんな感じです。

type API = "averages" :> Get '[JSON] [Score]
         :<|> AuthToken '[Readable] :> "records" :> Get '[JSON] [Record]
         :<|> AuthToken '[Readable, Writable] :> "record" :> ReqBody '[JSON] Record :> Post '[JSON] Record

GET /averages には誰でもアクセスできますが、GET /recordsPOST /recordはそれぞれ適切な権限を持ったユーザーじゃないとアクセスできません。作るAPIはテストの成績を表示するようなものを想定しています。平均点は誰でも見れるけど個人情報のアクセスは制限がある感じです。まずはデータ構造を定義していきましょう。

data Subject = English
             | Math
             | Physics
             deriving (Show, Generic, Eq, Ord)

instance FromJSON Subject
instance ToJSON Subject

type Score = (Subject, Float)
type Name = String
type Record = (Name, [Score])

Subjectは教科を表す型で リクエスト・レスポンスで扱うためにFromJSON, ToJSONのインスタンスにしています。Score, Name, Recordはそれぞれリストやタプルで作られているので自動的にこれらのインスタンスになっています。

次に権限を表すデータ構造を定義します。

{-# LANGUAGE DataKinds #-}

data Permission = Readable
                | Writable
                deriving (Show, Eq)

dataを使って型(Permission)と値(Readable,Writable)を宣言すると同時に、DataKinds拡張を使ってカインド(Permission)と型(Readable, Writable)も定義しています。これで'[Readable, Writable]のような型レベルリストが使えるようになります。次にこの型レベルリストを値レベルで利用するための仕掛けを用意します。

class AllPermission (list :: [Permission]) where
  allPermission :: Proxy list -> [Permission]

instance AllPermission '[] where
  allPermission _ = []

instance (AllPermission ps) => AllPermission (Readable ': ps) where
  allPermission _ = Readable : allPermission (Proxy :: Proxy ps)

instance (AllPermission ps) => AllPermission (Writable ': ps) where
  allPermission _ = Writable : allPermission (Proxy :: Proxy ps)

これを使うと

allPermission (Proxy :: Proxy '[Readable, Writable]) == [Readable, Writable]

のように型レベルの情報を値レベルで取り出すことができます。これらを使ってServant DSLを作っていきましょう!

data AuthToken (permissions :: [Permission])

instance (AllPermission permissions, HasServer sublayout) => HasServer ((AuthToken permissions) :> sublayout) where
  type ServerT (AuthToken permissions :> sublayout) m = (TokenList -> EitherT ServantErr IO ()) -> ServerT sublayout m

  route _ subserver request respond = do
    let checkAuth tokenList =
          case lookup "Auth-Token" (requestHeaders request) >>= flip lookup tokenList of
            Just permissions -> if all . map (`elem` permissions) $ allPermission (Proxy :: Proxy permissions)
                                  then pure ()
                                  else left err403
            Nothing -> left err403
    route (Proxy :: Proxy sublayout) (subserver checkAuth) request respond

この様にHasServerのインスタンスにすることによってserve関数で扱えるようになります。serve関数の型は

serve :: HasServer layout => Proxy layout -> Server layout -> Application

のようになっていてServerHasServerの定義で

type ServerT layout m :: *
type Server layout = ServerT layout (EitherT ServantErr IO) Source

というような型族として宣言されています。結果としてAuthTokenをAPIの定義に使った場合(TokenList -> EitherT ServantErr IO ()) -> ServerT sublayout mが対応する処理としてユーザーが定義するものになります。第一引数の関数が権限をチェックする関数でrouteの定義の中でcheckAuthとして宣言されているものです。これはリクエストのヘッダーからAuth-Tokenの値を取り出して事前に与えられたアクセス権限のリストからそのトークンが持っているアクセス権限を探して妥当なリクエストかどうかチェックしています。今回はアクセス権限のリストを直接渡していますが実際はDBへのコネクション等を渡してこの中でDBに問い合わせるところまで実装するほうがいいでしょう。

以上のコードをまとめて動くものにすると以下のようになります

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Main where

import GHC.Generics
import Control.Monad.IO.Class
import Control.Monad.Trans.Either
import Control.Concurrent.STM
import Data.Aeson
import Data.ByteString (ByteString)
import Data.Function
import Data.List
import Network.Wai (requestHeaders)
import qualified Network.Wai.Handler.Warp as Warp
import Servant

data Subject = English
             | Math
             | Physics
             deriving (Show, Generic, Eq, Ord)

instance FromJSON Subject
instance ToJSON Subject

type Score = (Subject, Float)
type Name = String
type Record = (Name, [Score])

data Permission = Readable
                | Writable
                deriving (Show, Eq)

class AllPermission (list :: [Permission]) where
  allPermission :: Proxy list -> [Permission]

instance AllPermission '[] where
  allPermission _ = []

instance (AllPermission ps) => AllPermission (Readable ': ps) where
  allPermission _ = Readable : allPermission (Proxy :: Proxy ps)

instance (AllPermission ps) => AllPermission (Writable ': ps) where
  allPermission _ = Writable : allPermission (Proxy :: Proxy ps)

data AuthToken (permissions :: [Permission])

instance (AllPermission permissions, HasServer sublayout) => HasServer ((AuthToken permissions) :> sublayout) where
  type ServerT (AuthToken permissions :> sublayout) m = (TokenList -> EitherT ServantErr IO ()) -> ServerT sublayout m

  route _ subserver request respond = do
    let checkAuth tokenList =
          case lookup "Auth-Token" (requestHeaders request) >>= flip lookup tokenList of
            Just permissions -> if all id . map (`elem` permissions) $ allPermission (Proxy :: Proxy permissions)
                                  then pure ()
                                  else left err403
            Nothing -> left err403
    route (Proxy :: Proxy sublayout) (subserver checkAuth) request respond

type API = "averages" :> Get '[JSON] [Score]
         :<|> AuthToken '[Readable] :> "records" :> Get '[JSON] [Record]
         :<|> AuthToken '[Readable, Writable] :> "record" :> ReqBody '[JSON] Record :> Post '[JSON] Record

api :: Proxy API
api = Proxy

type TokenList = [(ByteString, [Permission])]

tokenList :: TokenList
tokenList = [ ("d8e8fca2dc0f896fd7cb4cb0031ba249", [])
            , ("cbec921310a6d6d6c4b4e4493f199f28", [Readable])
            , ("c89b95af796787bede99f5f857a7825f", [Readable, Writable])
            ]

server :: TVar [Record] -> Server API
server recordsTVar = getAverages :<|> getRecords :<|> postRecord
  where
    getAverages = do
        records <- liftIO . atomically . readTVar $ recordsTVar
        let f (xs@(x:_)) = (fst x, (sum . map snd $ xs) / (fromIntegral $ length xs))
        pure . map f . groupBy ((==) `on` fst) . sort . concat . map snd $ records
    getRecords checkAuth = do
        checkAuth tokenList
        records <- liftIO . atomically . readTVar $ recordsTVar
        pure records
    postRecord checkAuth record = do
        checkAuth tokenList
        liftIO . atomically . modifyTVar recordsTVar $ (record:)
        pure record

records :: [Record]
records = [ ("rin", [(English, 45), (Math, 88), (Physics, 87)])
          , ("len", [(English, 54), (Math, 45), (Physics, 11)])
          , ("ran", [(English, 86), (Math, 36), (Physics, 43)])
          ]

main :: IO ()
main = do
  recordsTVar <- atomically $ newTVar records
  putStrLn "Listening on port 8080"
  Warp.run 8080 $ serve api (server recordsTVar)

起動していろいろ実験してみてください

$ curl http://localhost:8080/averages
[["English",61.666668],["Math",56.333332],["Physics",47]]

$ curl -H "Auth-Token: aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" http://localhost:8080/records
$ #403

$ curl -H "Auth-Token: cbec921310a6d6d6c4b4e4493f199f28" http://localhost:8080/records
[["rin",[["English",45],["Math",88],["Physics",87]]],["len",[["English",54],["Math",45],["Physics",11]]],["ran",[["English",86],["Math",36],["Physics",43]]]]

以上、アクセス制御を行うServant DSLを実装してみました。要は型レベルリストを値レベルリストに変換してるだけですが別のServant DSLを実装する参考になればと思います。実装したDSLで微妙だなと思うのはチェックの実装を強制できてないところでしょうか。実装する人はcheckAuthを使わずに捨てればコンパイラに怒られること無くビルドできてしまいます。しかし実行時にDBのコネクション情報などを渡そうと思うとこの実装になってしまったのでもしもっといい実装等あれば教えて下さい :sweat: あと今回作ったDSLはANDの条件は表現できますがORの条件は作れないのでもし拡張して更にSemiringのインスタンスにしたよ!みたいな記事を書いてくださったらご一報お願いしますー


  1. 有名なクイックソートの例なのですが同時に実は効率は良くないというのも有名です(参考: Haskellの神話