この記事を3行で
- AWS X-Rayをpytestで使うと便利
- 関数の通過や例外の発生をassertでテストできる
- X-Rayの可視化にFlameGraphを使えば、各関数の実行時間が分かりやすい
この記事を書く理由
AWS X-Rayが便利なので、AWS環境へのデプロイの前でも使える使い方を紹介したい。
完成後の挙動
この記事で作成する単体テストを、Pytestで実行すると、
単体テストが吐き出したX-Rayのデータをもとに、下のようなグラフがローカルのPC上に作成されます。
FlameGraphと呼ばれているグラフです。炎のように下から上に伸びていくことが特徴です。
グラフの縦の方向は関数の呼び出しを表しています。
たとえばこのグラフなら、下から上に読んで、lambda_handler関数がnetwork_process関数を呼び出して、そこからgoogle.co.jpへのリクエストをかけていることが分かります。
グラフの横の方向は、時間を指しています。たとえばこのグラフなら、実行時間全体の半分程度がnetwork_processであること、network_processのうちの2/3程度がgoogleへのリクエストだということが分かります。
実行速度の上で、どこが問題になっているのか、どこを改善すれば早くなるのかが分かりやすく、フローの関係や呼ばれた順序も明確になります。
実行環境
- python 3.12
今回使ったソースコードはこちらにあります
X-Rayを実装する
まずはX-Rayを使って、簡単なAWS Lambdaのソースを書きます。
サンプルソースとして、以下のような処理の流れを書きます。
①API GatewayからPOSTのリクエストを受ける
②値を検証する。検証の結果が不正なら400を返す
③urlをたたいて、外部のサイトを取ってくる
④書き込み処理をする
⑤結果を返す
X-Rayの実装はシンプルで、defの頭に@xray_recorder.capture
をつけること、関数全体の先頭でpatch_all
を実行すること、この2つだけで実装できます。
from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.core import patch_all
from pydantic import BaseModel
import time
import requests
# X-Rayの基本設定をする
xray_recorder.configure(service="Application")
patch_all() # ライブラリの監視をする
# 入力チェック用のクラス
class InputClass(BaseModel):
value: str
url: str
@xray_recorder.capture("def validation") # X-Rayの設定
def validation(event):
"""入力チェック、値の検証"""
time.sleep(0.1)
return InputClass.model_validate_json(event["body"])
@xray_recorder.capture("def network_process") # X-Rayの設定
def network_process(input: InputClass):
"""通信処理"""
res = requests.get(input.url)
time.sleep(0.4)
return res.text
@xray_recorder.capture("def write_process") # X-Rayの設定
def write_process(result):
"""取得したファイルの書き込み処理"""
time.sleep(0.6)
@xray_recorder.capture("def lambda_handler") # X-Rayの設定
def lambda_handler(event, context):
"""Lambdaのエントリポイント"""
try:
# 入力を検証する
input = validation(event)
except Exception:
# 入力チェックエラーなら400を返す
return {
"statusCode": 400,
"body": "Invalid input",
}
# ネットワークの処理
result = network_process(input)
# ファイルへの書き込み処理
write_process(result)
# 結果を返す
return {
"statusCode": 200,
"body": "Hello World",
}
requirements.txtでは、aws-xray-sdkをインストールしておきます。
サンプル処理が使うライブラリ(requests: HTTP通信のライブラリ、pydantic: バリデータ)もインストールします。
aws-xray-sdk==2.13.0
requests==2.31.0
pydantic==2.7.1
実装したら、SAMのテンプレートにTracingとTracingEnabledを足して、X-Rayを有効化します。
AWSTemplateFormatVersion: "2010-09-09"
Transform: AWS::Serverless-2016-10-31
Description: >
flame chart
Sample SAM Template for flame chart
# More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst
Globals:
Function:
Timeout: 9
+ Tracing: Active
Api:
+ TracingEnabled: True
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function # More info about Function Resource: https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md#awsserverlessfunction
Properties:
CodeUri: project/
Handler: app.lambda_handler
Runtime: python3.12
Architectures:
- x86_64
Events:
HelloWorld:
Type: Api # More info about API Event Source: https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md#api
Properties:
Path: /hello
Method: post
以上の手順を踏んで実装すると、トレースマップが記録されるようになります。
SAMでデプロイしてから処理を実行して、マネジメントコンソールからトレースマップを確認すると、以下のような図が見えます。
Lambdaの実行元、Lambdaが呼び出した接続先(googleのサイト)の流れが図に現れています。
また、それぞれの処理にどれだけの時間がかかっているのかも分かります。
Pytestに実装を書き加える
このX-Rayのデータを、マネジメントコンソールなしで可視化できるようにします。
X-Rayをpytestから実行する方法はシンプルです。
from aws_xray_sdk.core import xray_recorder
from project.app import lambda_handler # テスト対象の関数をインポート
def test_app():
# 単体試験用のセグメントを開始する
# f.__name__は実行中の関数名
# sampling=1.0を入れないと、ダミーセグメントで単体試験が失敗するので必要
seg = xray_recorder.begin_segment(f.__name__, sampling=1.0)
# 対象の関数を実行する
result = lambda_handler(
{"body": json.dumps({"url": "https://google.co.jp", "value": "abc"})}, {}
)
# 単体試験用のセグメントを終了する
# context.end_segmentを使うと、X-Rayデータの送信処理が実行されない
xray_recorder.context.end_segment()
# X-Rayのデータはseg変数に入っている
# または、xray_recorder.current_segment()を実行するとseg変数と同じものが取れる
xray_data = json.loads(seg.serialize())
テストの対象関数を、begin_segment
とend_segment
で挟むことでX-Rayを使って単体試験ができます。X-Rayの結果はserializeをかけるとJSON形式で取得できます。
結果をそのままassertに使ってもいいのですが、グラフにすると便利です。
JavaScriptにFlameGraphを作るライブラリがあります。
X-RayとJavaScriptのライブラリはキー名と値のフォーマットが違うので、そこだけ整形します。
def to_d3_dataframe_format(segment: dict):
""" D3のデータフレーム形式に変換する """
# サブセグメント名、実行時間を取得する
name = segment.get("name", "-")
value = int((segment.get("end_time", 0) - segment.get("start_time", 0)) * 1000)
# サブセグメント名、実行時間を記録する
result = {
"name": name,
"value": value,
"children": [
to_d3_dataframe_format(seg)
for seg in segment.get("subsegments", [])
]
}
# 結果を返す
return result
これでX-Rayのローカルでの可視化ができます。
関数化する
テストケース全てで個別にsegmentの作成を実装するのは大変なので、関数化します。
pytestのtestsディレクトリに、次のようなファイルを作ります。
from aws_xray_sdk.core import xray_recorder
from functools import wraps
import json
from string import Template
from pathlib import Path
# 出力用のディレクトリを指定する
viewer_directory = Path(__file__).parent.parent / "viewer"
with open(viewer_directory / "template.html", encoding="utf-8") as fp:
TEMPLATE = Template(fp.read())
def observer(f):
""" X-Rayのテストを、監視可能な状態で実行する """
@wraps(f)
def target_function(*args, **kwargs):
# 単体試験用のセグメントを開始する
seg = xray_recorder.begin_segment(f.__name__, sampling=1.0)
# 対象の関数を実行する
response = f(*args, **kwargs)
# 単体試験用のセグメントを終了する
xray_recorder.context.end_segment()
# FlameGraph用のデータを取得する
flatten_list = []
d3_flamegraph_data = to_d3_dataframe_format(json.loads(seg.serialize()), flatten_list, indent=0)
# HTML形式のファイルにFlameGraphのデータを書き込む
with open(viewer_directory / f"{f.__name__}.html", "w", encoding="utf-8") as fp:
fp.write(TEMPLATE.substitute(
VIEW_DATA=json.dumps(d3_flamegraph_data), LIST_DATA=json.dumps(flatten_list)
))
return response
return target_function
def to_d3_dataframe_format(segment: dict, flatten_list, indent: int):
""" D3のデータフレーム形式に変換する """
# サブセグメント名、実行時間を取得する
name = segment.get("name", "-")
value = int((segment.get("end_time", 0) - segment.get("start_time", 0)) * 1000)
# 表データ用のリストに追加する
flatten_list.append(f"{"".join([" " for _ in range(indent)])}{name} ({value} ms)")
# サブセグメント名、実行時間を記録する
result = {
"name": name,
"value": value,
"children": [
to_d3_dataframe_format(seg, flatten_list, indent=indent + 1)
for seg in segment.get("subsegments", [])
]
}
# 結果を返す
return result
また、viewerのディレクトリに、下のようなtemplate.html
を作成しておきます。
<!DOCTYPE html>
<html lang="ja-jp">
<head>
<meta charset="utf-8" />
<link
rel="stylesheet"
type="text/css"
href="https://cdn.jsdelivr.net/npm/d3-flame-graph@4.1.3/dist/d3-flamegraph.css"
/>
</head>
<body>
<div id="chart"></div>
<div id="list" style="margin-top: 1rem"></div>
<script type="text/javascript" src="https://d3js.org/d3.v7.js"></script>
<script
type="text/javascript"
src="https://cdn.jsdelivr.net/npm/d3-flame-graph@4.1.3/dist/d3-flamegraph.min.js"
></script>
<script type="text/javascript">
const ViewData = $VIEW_DATA;
const ListData = $LIST_DATA;
const chart = flamegraph().width(960);
chart.setLabelHandler(function (d) {
return [d.data.name, "(" + d.data.value + "ms)"].join(" ");
});
d3.select("#chart").datum(ViewData).call(chart);
d3.select("#list")
.selectAll("div")
.data(ListData)
.enter()
.append("pre")
.text((d) => d);
</script>
</body>
</html>
この関数を作ると、単体テストに@observer
を付けるだけでFlameGraphの出力までされるようになります。
from project.app import lambda_handler
from .tools import observer
import json
+ @observer
def test_lambda_handler():
"""
単体試験: 正常系(正しい想定の入力を渡す)
期待結果: 200が返ること、全ての処理を通ること
"""
result = lambda_handler(
{"body": json.dumps({"url": "https://google.co.jp", "value": "abc"})}, {}
)
X-Rayのデータをassertする
X-Rayのデータには、通過した関数の名前だけでなく、処理の途中で投げられた例外の内容も記録されています。current_segmentにassertをかけることで、関数の通過や例外の内容の試験を簡単に書くことができます。
@observer
def test_invalid_parameter():
"""
単体試験: 準正常系(必須パラメータのない入力を渡す)
期待結果: 400が返ること、Network以降の処理が通らないこと
"""
result = lambda_handler({"body": json.dumps({"url": "https://google.co.jp"})}, {})
# 実行結果のX-Rayのセグメント情報をJSONで取得する
current_segment = xray_recorder.current_segment().serialize()
# 実行済みの処理を検証する
assert "def lambda_handler" in current_segment
assert "def validation" in current_segment
# 必須チェックで例外があったことを検証する
assert "exceptions" in current_segment # 例外があったことを確認する
assert "1 validation error for InputClass" in current_segment
assert "value" in current_segment # 年齢の項目が誤っている
assert "Field required" in current_segment # 必須チェックエラーである
# ネットワーク通信以降の処理が実行されていないことを検証する
assert not ("def network_process" in current_segment)
assert not ("def write_process" in current_segment)
# レスポンスを検証する
assert result["statusCode"] == 400
実行された関数が想定通りか、例外の内容が想定通りか、といった試験が書きやすく、分かりやすい単体試験になります。
まとめ
pytestでX-Rayを使うことで、パフォーマンスの記録や問題の特定が簡単になります。実装時点のパフォーマンスをgitで管理する、実装を変えながらパフォーマンスを比較する、異常系のパフォーマンスを計測する、といったこともできるようになります。
FlameGraphは便利です。
ぜひX-Rayと組み合わせて使っていただければと思います。
あらためて、今回使ったソースコードはこちらにあります