はじめに
私が運用している GraphQL API では、API クライアント単位で
一時間ごとの呼び出し回数に制限を設けています。いわゆるサービスクォータです。
最近、(回数制限があるが故に?)「1 回のクエリでなるべく多くの情報を引き出そうと巨大な(高コストな)クエリを投げてくる」行為が問題になりつつあります。
ここは利用者のモラルに任せていましたが、考えが甘かったようです。
そこで、API クライアントごとに制限を設ける方法を模索したので、その方法を書き記しておきます。
実装サンプルは GitHub に公開しています。
実装した環境
- java: 17
- spring-boot: 3.0.5
- spring-graphql: 1.1.3
実現イメージ
- API クライアントは一意に識別できる (今回は API Key で識別)
- API クライアントごとに「どれだけ高コストなクエリを発行できるか」という情報をサーバ側で管理している
- 例えば、有料プランの契約者は複雑度 100 のクエリまで許容されるのに対し、フリープランはその半分の複雑度のクエリしか実行できないとか、そういうイメージ
- コスト上限に抵触した場合、GraphQL クエリの実行は拒否され、エラー応答する
1. API クライアントを識別する
API クライアントが識別できないことには始まらないので、まずはそこから実装します。
今回は WebFlux を使っているので、reactor の Context に情報を格納することでスレッド間で情報を引き渡せるようにします。
WebMVC を使っている場合は ThreadLocal + ThreadLocalAccessor を用意すればよいでしょう。
1-1. API クライアントの情報を定義する
API クライアントは、自身が実行可能なクエリの深さ (maxDepth)、複雑度 (maxComplexity) を持った record クラスとします。
public record Client(String apiKey, int maxDepth, int maxComplexity) {
}
1-2. WebFilter で HTTP ヘッダを解析し API クライアントを特定する
WebFilter で HTTP ヘッダを解析し、 API クライアントを特定できたら Context に設定します。
@Component
@Order(0)
@RequiredArgsConstructor
public class ClientDetectionWebFilter implements WebFilter {
private final ClientRepository clientRepository;
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
final String apiKey = exchange.getRequest().getHeaders().getFirst("X-Api-Key");
return Mono.justOrEmpty(apiKey)
.filter(StringUtils::hasLength)
.flatMap(clientRepository::findByApiKey)
.flatMap(
client ->
chain
.filter(exchange)
.contextWrite(context -> context.putNonNull(Client.class, client)))
.switchIfEmpty(writeErrorResponse(exchange));
}
...
}
これで Context#get を使って API クライアント情報を参照できるようになりました。
2. GraphQL クエリの複雑度に応じてアクセス拒否する
2-1. Instrumentation を実装する
もともと graphql-java には
- クエリの深さ(ネストレベル)に応じてリクエストを拒否することができる graphql.analysis.MaxQueryDepthInstrumentation
- クエリの複雑度に応じてリクエストを拒否できる graphql.analysis.MaxQueryComplexityInstrumentation
というものが用意されています。
ただ、これらは制限値をコンストラクタで与える形で固定されているので、動的に判定できるような実装を用意します。
public class QueryComplexityMonitoringInstrumentation extends SimpleInstrumentation {
public interface OnQueryComplexityCalculatedListener {
void onQueryComplexityCalculated(QueryComplexityCalculatedEvent event);
}
private final FieldComplexityCalculator fieldComplexityCalculator;
private final OnQueryComplexityCalculatedListener listener;
@Override
public InstrumentationContext<ExecutionResult> beginExecuteOperation(
InstrumentationExecuteOperationParameters parameters) {
QueryTraverser queryTraverser = newQueryTraverser(parameters.getExecutionContext());
final int depth =
queryTraverser.reducePreOrder(
(env, acc) -> Math.max(getPathLength(env.getParentEnvironment()), acc), 0);
Map<QueryVisitorFieldEnvironment, Integer> valuesByParent = new LinkedHashMap<>();
queryTraverser.visitPostOrder(
new QueryVisitorStub() {
@Override
public void visitField(QueryVisitorFieldEnvironment env) {
int childsComplexity = valuesByParent.getOrDefault(env, 0);
int value = calculateComplexity(env, childsComplexity);
valuesByParent.compute(
env.getParentEnvironment(),
(key, oldValue) -> Optional.ofNullable(oldValue).orElse(0) + value);
}
});
final OperationDefinition.Operation operationType =
parameters.getExecutionContext().getOperationDefinition().getOperation();
final int totalComplexity = valuesByParent.getOrDefault(null, 0);
GraphQLContext context = parameters.getExecutionContext().getGraphQLContext();
// この中で複雑度の判定をする
listener.onQueryComplexityCalculated(
QueryComplexityCalculatedEvent.builder()
.operationType(operationType)
.depth(depth)
.complexity(totalComplexity)
.context(context)
.build());
return SimpleInstrumentationContext.noOp();
}
...
}
2-2. FieldComplexityCalculator を実装する
次に、クエリの複雑度の計算です。
FieldComplexityCalculator という interface は用意されているものの具象クラスが無いので、自分で実装することにします。
以下の仕様に決めました。
- アクセスするフィールド 1 つあたり
+1
- List フィールドは
{子要素の合計複雑度} * 10
- ただし、Relay の Connection 型で取得件数 (N) が指定されている場合、
{子要素の合計複雑度} * N
計算例
query {
users(first: 5) { # (2 + 1 + 121) + 1 = 125
pageInfo { # 1 + 1 = 2
hasNextPage # +1
}
totalCount # +1
edges { # 24 * 5 + 1 = 121
node { # (1 + 1 + 21) + 1 = 24
id # +1
name # +1
favorites { # (1 + 1) * 10 + 1 = 21
name # +1
description # +1
}
}
}
}
}
複雑度は下層から計算していきます。
この例では複雑度 = 125 ということになります。
上記の仕様をもとに、実装はこうなりました。
public class DefaultFieldComplexityCalculator implements FieldComplexityCalculator {
private final int defaultListWeight;
public DefaultFieldComplexityCalculator() {
this(10);
}
public DefaultFieldComplexityCalculator(int defaultListWeight) {
this.defaultListWeight = defaultListWeight;
}
@Override
public int calculate(FieldComplexityEnvironment environment, int childComplexity) {
final int weight = calculateWeight(environment);
return childComplexity * weight + 1;
}
protected int calculateWeight(FieldComplexityEnvironment environment) {
GraphQLOutputType type = environment.getFieldDefinition().getType();
if (type instanceof GraphQLNonNull) {
type = (GraphQLOutputType) ((GraphQLNonNull) type).getWrappedType();
}
if (!(type instanceof GraphQLList)) {
return 1;
}
if ("edges".equals(environment.getField().getName())) {
FieldComplexityEnvironment parentEnvironment = environment.getParentEnvironment();
GraphQLOutputType parentType = parentEnvironment.getFieldDefinition().getType();
if (isImplementsInterfaceNamed(parentType, "Connection")) {
Map<String, Object> parentArgs = parentEnvironment.getArguments();
return IntStream.of(
// for Cursor based Connection
Objects.requireNonNullElse((Integer) parentArgs.get("first"), -1),
Objects.requireNonNullElse((Integer) parentArgs.get("after"), -1),
// for Offset based Connection
Objects.requireNonNullElse((Integer) parentArgs.get("limit"), -1))
.filter(i -> i >= 0)
.max()
.orElse(defaultListWeight);
}
}
return defaultListWeight;
}
...
}
計算ルールは工夫の余地がありそうです。
ググると、@cost(weight: 2)
のような directive を用意して、計算コストが特に高いフィールドには特別に重み付けをするような例も見受けられました。
今回はそこまで細かいケアをするつもりは無いので、これでヨシとします。
3. アプリに設定する
Configuration を設定します。
複雑度に応じた挙動はここで実装しています。
WebFilter で設定した API クライアントの情報を参照して、アクセス可否のチェックを行います。
制限に抵触した場合はアクセス拒否したいので graphql.execution.AbortExecutionException をスローしています。
@Configuration(proxyBeanMethods = false)
public class Config {
@Bean
public Instrumentation queryComplexityMonitoringInstrumentation() {
return new QueryComplexityMonitoringInstrumentation(
new DefaultFieldComplexityCalculator(), queryComplexityCalculatedListener());
}
private OnQueryComplexityCalculatedListener queryComplexityCalculatedListener() {
return event -> {
final Client client = event.context().get(Client.class);
if (event.depth() > client.maxDepth()) {
throw new AbortExecutionException(
"Maximum query depth exceeded " + event.depth() + " > " + client.maxDepth());
}
if (event.complexity() > client.maxComplexity()) {
throw new AbortExecutionException(
"Maximum query complexity exceeded "
+ event.complexity()
+ " > "
+ client.maxComplexity());
}
};
}
}
動作確認
bootRun して API を動かしてみます。
成功例
API Key: key1
は制限が緩めなのでアクセスが許可されます。
$ curl http://localhost:8080/graphql \
-H 'Content-Type: application/json' \
-H 'X-Api-Key: key1' \
-d '{"query": "query { users(first: 5) { pageInfo { hasNextPage } totalCount edges { node { id name favorites { name description } } } } }"}' | jq .
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 633 100 497 100 136 37217 10184 --:--:-- --:--:-- --:--:-- 70333
{
"data": {
"users": {
"pageInfo": {
"hasNextPage": false
},
"totalCount": 3,
"edges": [
{
"node": {
"id": "u1",
"name": "taro",
"favorites": [
{
"name": "u1-fav1",
"description": "fav1 desc"
},
{
"name": "u1-fav2",
"description": "fav2 desc"
}
]
}
},
{
"node": {
"id": "u2",
"name": "jiro",
"favorites": [
{
"name": "u2-fav1",
"description": "fav1 desc"
},
{
"name": "u2-fav2",
"description": "fav2 desc"
}
]
}
},
{
"node": {
"id": "u3",
"name": "saburo",
"favorites": [
{
"name": "u3-fav1",
"description": "fav1 desc"
},
{
"name": "u3-fav2",
"description": "fav2 desc"
}
]
}
}
]
}
}
}
アクセス拒否例
API Key: key2
は制限がきつく、アクセスが拒否されました。
$ curl http://localhost:8080/graphql \
-H 'Content-Type: application/json' \
-H 'X-Api-Key: key2' \
-d '{"query": "query { users(first: 5) { pageInfo { hasNextPage } totalCount edges { node { id name favorites { name description } } } } }"}' | jq .
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 257 100 121 100 136 9746 10954 --:--:-- --:--:-- --:--:-- 42833
{
"errors": [
{
"message": "Maximum query complexity exceeded 125 > 100",
"extensions": {
"classification": "ExecutionAborted"
}
}
]
}
これで、期待通りの動作を実装できました。
まとめ
クエリの複雑度に応じてアクセスを制限する仕組みを実装しました。
- API クライアントの識別は WebFilter で行い、結果を reactor の Context で引き回す。
- クエリの複雑度は Instrumentation で計算し、必要なら AbortExecutionException をスローすることでアクセスを拒否する。