LoginSignup
3
2

More than 1 year has passed since last update.

[Python] boto3でAthenaを使ったコードをMockでテストする

Posted at

AthenaのQueryを投げるコード

run.py
import boto3
from time import sleep

# Athena database the queries run against.
DATABASE = 'test_db'
# Maximum number of rows requested per result page.
PAGE_SIZE = 1000


class AthenaResult:
    """Iterator yielding the raw ``get_query_results`` responses page by page.

    ``next_token`` tracks the paging state:
    ``False`` means no page has been requested yet, a non-empty value is
    the token for the following page, and ``None`` means the last page has
    already been returned.
    """

    def __init__(self, client, query_execution_id):
        self.query_execution_id = query_execution_id
        self.client = client
        self.next_token = False  # nothing fetched yet

    def __iter__(self):
        return self

    def __next__(self):
        token = self.next_token
        if token is None:
            # The previous response carried no NextToken: nothing left.
            raise StopIteration

        request = dict(
            QueryExecutionId=self.query_execution_id,
            MaxResults=PAGE_SIZE,
        )
        if token:
            # Only pages after the first one are requested with a token.
            request["NextToken"] = token

        page = self.client.get_query_results(**request)
        self.next_token = page.get('NextToken')
        return page


class QueryException(Exception):
    """Raised when an Athena query ends in the FAILED or CANCELLED state."""


def athena_execute(query, use_paginator=True):
    """Run *query* on Athena, wait for it to finish and return its result pages.

    Args:
        query: SQL text to execute.
        use_paginator: When true, return the boto3 paginator's page
            iterator for ``get_query_results``; otherwise return an
            ``AthenaResult`` iterator.

    Returns:
        An iterable of ``get_query_results`` response pages.

    Raises:
        QueryException: If the query ends in the FAILED or CANCELLED state.
    """
    import uuid  # local import: only needed to build the per-call request token

    athena = boto3.client('athena')
    response = athena.start_query_execution(
        QueryString=query,
        # The token must be unique per logical request. A fixed token makes
        # Athena treat every call as a retry of the first request: the same
        # query gets the old execution back, a different query is rejected.
        ClientRequestToken=str(uuid.uuid4()),
        QueryExecutionContext={
            'Database': DATABASE
        },
        WorkGroup='dev'
    )
    query_execution_id = response["QueryExecutionId"]
    print(f"{query_execution_id=}")

    # Poll until the query reaches a terminal state.
    while True:
        response = athena.get_query_execution(QueryExecutionId=query_execution_id)
        state = response["QueryExecution"]["Status"]["State"]
        if state == "SUCCEEDED":
            break
        elif state in ("FAILED", "CANCELLED"):
            raise QueryException("Query Failed.")
        else:
            # Any other state means the query is still in progress.
            sleep(5)

    if use_paginator:
        paginator = athena.get_paginator('get_query_results')
        return paginator.paginate(
            QueryExecutionId=query_execution_id,
            PaginationConfig={
                'PageSize': PAGE_SIZE,
                'MaxItems': 100000
            }
        )
    else:
        return AthenaResult(athena, query_execution_id)


def process_row(row, header=None):
    """Turn one result row into a dict keyed by column name.

    Args:
        row: Sequence of cell values for a single row.
        header: Sequence of column names, in the same order as *row*.
            When omitted, the values are keyed by their position instead,
            so the declared default no longer crashes inside ``zip``.

    Returns:
        A dict mapping each column name (or index) to the row's value.
    """
    if header is None:
        return dict(enumerate(row))
    return dict(zip(header, row))


def main():
    """Run a sample query and feed every data row to ``process_row``.

    The first row of the first page is taken as the column header; all
    other rows are processed one at a time and counted.
    """
    res_iter = athena_execute(
        "SELECT * FROM test_db.test_table LIMIT 10000;",
        False
    )
    header = None
    tot_processed_rows = 0
    for page_idx, results_page in enumerate(res_iter):
        for item_idx, row in enumerate(results_page['ResultSet']['Rows']):
            # NULL cells arrive as a datum without 'VarCharValue'; .get()
            # maps them to None instead of raising KeyError.
            values = [d.get('VarCharValue') for d in row['Data']]
            if page_idx == 0 and item_idx == 0:
                header = values
                continue
            process_row(values, header)
            tot_processed_rows += 1
    print(f"{tot_processed_rows=}")



# Run the sample query only when this file is executed as a script.
if __name__ == '__main__':
    main()
  1. athena_execute(<sql>) で実行するとIteratorが返るので、全ての結果に対してRowごとに処理を施せる。結果の行が多いときにはメモリに全て乗らないので、行ごとの処理を process_row(row, header) のように書けば良い。
  2. athena_executeの2つ目の引数use_paginator
    1. Trueの場合、boto3の get_paginator を使ってPageIteratorを返す。
    2. Falseの場合は、カスタムのAthenaResult というイテレータを返す。

テスト

Pythonでテストを書く場合、Athenaに実際Queryを投げずにテストを書いてAthena以外のロジックをテストしたい。 unittestのpatchを利用してboto3のclientをモックすることで可能である。

上のカスタムのAthenaResultは、start_query_executionと同じboto3.clientでget_query_resultsを呼んでいるので、clientを一つモックするだけでテストが可能。
自分で、Queryの結果として返ってくるものをgenerate_athena_resultで定義する必要がある。今回は、1回目の呼び出しでは NextToken を返し、リクエストに NextToken が含まれる2回目の呼び出しでは NextToken を返さないようにして、Iterator部分をテストできるようにする。

test_boto3.py
from unittest.mock import patch
from run import athena_execute, process_row

# Column names; sent as the first row of the first page.
FIELDS = ['name', 'age']
# Fixture rows, split into the two pages the fake client returns.
DATA = [
    [['a', 1], ['b', 2], ['c', 3], ['d', 4], ['e', 5], ['f', 6], ['g', 7], ['h', 8], ['i', 9]],
    [['j', 10], ['k', 11], ['l', 12], ['m', 13]],
]


def _to_athena_rows(rows):
    """Wrap every cell of every row in a ``{"VarCharValue": cell}`` datum."""
    return [
        {"Data": [{"VarCharValue": cell} for cell in row]}
        for row in rows
    ]


def generate_athena_result(**kwargs):
    """Fake ``get_query_results``: return page 1, or page 2 when a token is given.

    Without ``NextToken`` the response holds the header row plus ``DATA[0]``
    and carries a ``NextToken``. With ``NextToken`` it holds ``DATA[1]`` and
    no token, which ends the iteration.

    Each row yields one datum per cell. The earlier form
    ``[{"VarCharValue": d for d in row}]`` was a dict comprehension inside a
    list, so every row collapsed to a single datum holding its last cell.
    """
    if kwargs.get("NextToken"):
        return {
            "ResultSet": {
                "Rows": _to_athena_rows(DATA[1])
            }
        }
    return {
        "ResultSet": {
            # First page starts with the field names.
            "Rows": _to_athena_rows([FIELDS] + DATA[0])
        },
        "NextToken": "next_token"
    }


def mock_athena(boto3_client):
    """Configure the patched ``boto3.client`` to act as a fake Athena client.

    The client obtained from ``boto3_client('athena')`` is set up so that
    the query starts with a fixed execution id, reports SUCCEEDED at once,
    and serves result pages from ``generate_athena_result``.

    Returns:
        The configured fake client.
    """
    fake_client = boto3_client('athena')
    succeeded = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}}
    started = {"QueryExecutionId": "123456789"}

    fake_client.start_query_execution.return_value = started
    fake_client.get_query_execution.return_value = succeeded
    fake_client.get_query_results.side_effect = generate_athena_result
    return fake_client


@patch("run.boto3.client")
def test_athena_execute(boto3_client):
    """Iterate both mocked result pages and check the largest ``age`` seen."""
    mock_athena(boto3_client)
    pages = athena_execute("test sql", False)

    oldest = 0
    for page_no, page in enumerate(pages):
        for row_no, row in enumerate(page['ResultSet']['Rows']):
            values = [datum['VarCharValue'] for datum in row['Data']]
            if (page_no, row_no) == (0, 0):
                # The very first row carries the column names.
                header = values
                continue
            record = process_row(values, header)
            oldest = max(record['age'], oldest)

    assert oldest == 13

実行:

pytest test_boto3.py -s
======================================= test session starts =======================================
platform darwin -- Python 3.8.0, pytest-6.2.1, py-1.10.0, pluggy-0.13.1
rootdir: /Users/masato-naka/repos/bebit/gram-store-clean-up
collected 1 item                                                                                  

test_boto3.py query_execution_id='123456789'
.

======================================== 1 passed in 0.19s ========================================

その他

  • get_paginatorのテストはやってない. Iterator自体をMockするとテストの意味無くなりそうだからやってない。
  • moto というMock AWS ServicesにAthenaがあるからそっち使ったほうがいいかも。

参考

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2