概要
GoでWebアプリケーションを作るにあたり、gRPCとgrpc-gatewayを利用して作っています。
ここで何か全APIに共通の処理を書きたい場合、grpc-gatewayにミドルウェアを作成し、そこで処理をしてしまうことが多いです。今回はそのミドルウェアのテストを書くやり方をまとめます。
なおミドルウェアはgrpc-gatewayやgRPCに依存しているものではなく、net/httpを使っているミドルウェアであれば同様にテストが書けるはずです。
業務のロジックが含まれていて割愛している所も多く、また同様の事をしている例も多々あるかと思いますが、
実際に使われているものに近いミドルウェアとそのテストとして、何かしら参考になれば幸いです。
テストするミドルウェア
以下はアプリバージョンを渡してもらい、最低アプリバージョン以下だとエラーを返すというミドルウェアです。
実際のアプリケーションでは強制アップデートをかけるために利用しています。
package gateway
import (
"fmt"
"net/http"
"strconv"
"github.com/andfactory/xxx-webapp/domain/model"
"github.com/andfactory/xxx-webapp/domain/errors/code"
"github.com/andfactory/xxx-webapp/domain/errors"
"github.com/andfactory/xxx-webapp/library/env"
)
const (
slackTitleAppVersionInvalid = "appVersion-invalid"
headerKeyAppVersion = "App-Version"
)
var minimumAppVersionIos int
var minimumAppVersionAndroid int
func init() {
minimumAppVersionIos = env.GetMinimumAppVersionIos()
minimumAppVersionAndroid = env.GetMinimumAppVersionAndroid()
}
// getAppVersionHeader クライアントのアプリバージョンチェックを実施するミドルウェアを取得する
func getAppVersionHeader(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//不要なログ出力を避けるため、healthCheckとドキュメントルートではこのチェックをおこなわない
if r.RequestURI == "/health_check" || r.RequestURI == "/" {
h.ServeHTTP(w, r)
return
}
deviceTypeStr := r.Header.Get(headerKeyDeviceType)
deviceType, err := model.ConvertStringToDeviceType(deviceTypeStr)
if err != nil {
err := errors.WrapApplicationError(err, code.InvalidDevice, fmt.Sprintf("invalid device type: '%v'", deviceTypeStr))
setUnencryptedErrorResponse(w, slackTitleAppVersionInvalid, http.StatusBadRequest, err)
return
}
appVersionStr := r.Header.Get(headerKeyAppVersion)
appVersion, err := strconv.Atoi(appVersionStr)
if err != nil {
err := errors.WrapApplicationError(err, code.InvalidAppVersion, fmt.Sprintf("invalid application version: '%v'", appVersionStr))
setUnencryptedErrorResponse(w, slackTitleAppVersionInvalid, http.StatusBadRequest, err)
return
}
var minimumAppVersion int
switch deviceType {
case model.DeviceTypeIOS:
minimumAppVersion = minimumAppVersionIos
case model.DeviceTypeAndroid:
minimumAppVersion = minimumAppVersionAndroid
}
if appVersion < minimumAppVersion {
err := errors.NewApplicationError(code.NeedUpdateApplication, fmt.Sprintf("%s Application version too low. got %d want %d", deviceType, appVersion, minimumAppVersion))
setUnencryptedErrorResponse(w, slackTitleAppVersionInvalid, http.StatusBadRequest, err)
return
}
h.ServeHTTP(w, r)
})
}
テストコード
上記のミドルウェアに対しては、以下のようにテストを書くことができます。
package gateway_test
import (
"bytes"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/andfactory/xxx-webapp/adapter/grpc/presenter"
"github.com/andfactory/xxx-webapp/domain/errors/code"
"github.com/andfactory/xxx-webapp/infra/grpc/gateway"
)
//TestAppVersionSkip 特定のpassで処理をスキップする部分のテスト
func TestAppVersionSkip(t *testing.T) {
ts := httptest.NewServer(gateway.GetAppVersionHeader(GetTestHandler()))
defer ts.Close()
tests := []struct {
name string
pass string
isError bool
expectedCode code.ErrorCode
}{
{
name: "ルート",
pass: "/",
isError: false,
},
{
name: "ヘルスチェック",
pass: "/health_check",
isError: false,
},
{
name: "通常",
pass: "/test",
isError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var u bytes.Buffer
u.WriteString(string(ts.URL))
u.WriteString(tt.pass)
req, _ := http.NewRequest("GET", u.String(), nil)
req.Header.Set(gateway.GetHeaderKeyDeviceType(), "invalidDeviceType")
req.Header.Set(gateway.GetHeaderKeyAppVersion(), "0")
res, err := gateway.Client.Do(req)
if err != nil {
t.Fatalf("request faiulure %v", err)
}
if res != nil {
defer res.Body.Close()
}
if tt.isError {
var d presenter.ErrorResponse
if err := json.NewDecoder(res.Body).Decode(&d); err != nil {
t.Fatalf("request faiulure %v", err)
}
if d.Body.ErrorCode != code.InvalidDevice {
t.Fatalf("return want to be %v but returned %v", tt.expectedCode, d.Body.ErrorCode)
}
} else {
b, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("request faiulure %v", err)
}
if string(b) != "OK" {
t.Fatalf("return want to be OK but returned %v", string(b))
}
}
})
}
}
//TestAppVersion appVersionでチェックする処理全般のテスト
func TestAppVersion(t *testing.T) {
ts := httptest.NewServer(gateway.GetAppVersionHeader(GetTestHandler()))
defer ts.Close()
var u bytes.Buffer
u.WriteString(string(ts.URL))
u.WriteString("/test")
gateway.SetMinimumAppVersionIos(50)
gateway.SetMinimumAppVersionAndroid(150)
tests := []struct {
name string
deviceType string
appVersion string
isError bool
expectedCode code.ErrorCode
}{
{
name: "不正なデバイス",
deviceType: "",
appVersion: "50",
isError: true,
expectedCode: code.InvalidDevice,
},
{
name: "不正なデバイス",
deviceType: "iOS",
appVersion: "50",
isError: true,
expectedCode: code.InvalidDevice,
},
{
name: "不正なデバイス",
deviceType: "3",
appVersion: "50",
isError: true,
expectedCode: code.InvalidDevice,
},
{
name: "iOS不正なバージョン",
deviceType: "1",
appVersion: "",
isError: true,
expectedCode: code.InvalidAppVersion,
},
{
name: "iOS不正なバージョン",
deviceType: "1",
appVersion: "1.1.1",
isError: true,
expectedCode: code.InvalidAppVersion,
},
{
name: "iOS強制アップデート",
deviceType: "1",
appVersion: "49",
isError: true,
expectedCode: code.NeedUpdateApplication,
},
{
name: "iOSミニマム",
deviceType: "1",
appVersion: "50",
isError: false,
},
{
name: "iOSミニマムより大きい",
deviceType: "1",
appVersion: "51",
isError: false,
},
{
name: "android不正なバージョン",
deviceType: "2",
appVersion: "",
isError: true,
expectedCode: code.InvalidAppVersion,
},
{
name: "android不正なバージョン",
deviceType: "2",
appVersion: "1.1.1",
isError: true,
expectedCode: code.InvalidAppVersion,
},
{
name: "android強制アップデート",
deviceType: "2",
appVersion: "149",
isError: true,
expectedCode: code.NeedUpdateApplication,
},
{
name: "androidミニマム",
deviceType: "2",
appVersion: "150",
isError: false,
},
{
name: "androidミニマムより大きい",
deviceType: "2",
appVersion: "151",
isError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, _ := http.NewRequest("GET", u.String(), nil)
req.Header.Set(gateway.GetHeaderKeyDeviceType(), tt.deviceType)
req.Header.Set(gateway.GetHeaderKeyAppVersion(), tt.appVersion)
res, err := gateway.Client.Do(req)
if err != nil {
t.Fatalf("request faiulure %v", err)
}
if res != nil {
defer res.Body.Close()
}
if tt.isError {
var d presenter.ErrorResponse
if err := json.NewDecoder(res.Body).Decode(&d); err != nil {
t.Fatalf("request faiulure %v", err)
}
if d.Body.ErrorCode != tt.expectedCode {
t.Fatalf("return want to be %v but returned %v", tt.expectedCode, d.Body.ErrorCode)
}
} else {
b, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("request faiulure %v", err.Error())
}
if string(b) != "OK" {
t.Fatalf("return want to be OK but returned %v", string(b))
}
}
})
}
}
func GetTestHandler() http.HandlerFunc {
fn := func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte("OK"))
return
}
return http.HandlerFunc(fn)
}
privateな情報にテストからアクセスできるようにexport_test.goを作成します。
package gateway
import (
"net/http"
)
var Client = new(http.Client)
var GetAppVersionHeader = getAppVersionHeader
func SetApplicationAppVersionIos(i int) {
applicationAppVersionIos = i
}
func SetApplicationAppVersionAndroid(i int) {
applicationAppVersionAndroid = i
}
func GetHeaderKeyDeviceType() string {
return headerKeyDeviceType
}
func GetHeaderKeyAppVersion() string {
return headerKeyAppVersion
}
解説
ミドルウェアのテストをするには、テストしたいミドルウェアのみを実行するサーバを作れば実現できます。
以下のようにエラーがなかった時用のハンドラを用意し、
func GetTestHandler() http.HandlerFunc {
fn := func(rw http.ResponseWriter, req *http.Request) {
rw.Write([]byte("OK"))
return
}
return http.HandlerFunc(fn)
}
テストしたミドルウェアを通してサーバを立ててあげます。
ts := httptest.NewServer(gateway.GetAppVersionHeader(GetTestHandler()))
defer ts.Close()
urlの設定は以下のようにすれば実現できます
var u bytes.Buffer
u.WriteString(string(ts.URL))
u.WriteString(tt.pass)
req, _ := http.NewRequest("GET", u.String(), nil)
gRPCとgrpc-getewayを使うときは共通のパラメータを送るときはhttpHeaderに設定し、gRPC飲めたデータとして処理しています。headerへの設定は以下のようにします。
req.Header.Set(gateway.GetHeaderKeyDeviceType(), tt.deviceType)
req.Header.Set(gateway.GetHeaderKeyAppVersion(), tt.appVersion)
これで、APIにアクセスします。なおクライアントはexport_test.goで作成して使いまわしています。appVersion_test.goで作っても良いのですが、他のミドルウェアのテストでも活用したいのでこのようになってます。
res, err := gateway.Client.Do(req)
if err != nil {
t.Fatalf("request faiulure %v", err)
}
if res != nil {
defer res.Body.Close()
}
あとはレスポンスの内容をチェックしてあげればOKです。
エラーの場合は特定の型のレスポンスを返すようにしてあるので、それをパースしてコードが意図したものになっていればOK。エラーでない場合はOKが返ってくれば正常です。
if tt.isError {
var d presenter.ErrorResponse
if err := json.NewDecoder(res.Body).Decode(&d); err != nil {
t.Fatalf("request faiulure %v", err)
}
if d.Body.ErrorCode != tt.expectedCode {
t.Fatalf("return want to be %v but returned %v", tt.expectedCode, d.Body.ErrorCode)
}
} else {
b, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("request faiulure %v", err.Error())
}
if string(b) != "OK" {
t.Fatalf("return want to be OK but returned %v", string(b))
}
}
参考
export_test.goを作って非公開の変数や関数を扱うやり方は以下で詳しく解説されてます。
非公開(unexported)な機能を使ったテスト
以下の記事でも同様の事が書かれています。
Unit Testing Golang HTTP Middleware