AWSを使用した自作モジュールを作ることが結構あるのですが、開発の初期段階から直接awsを実行すると時間もかかるしコストもかかります。
なのでできる限りローカルで開発可能なunittestを導入したいと思います。
ここでは自作モジュールを作る時の対象としてなりやすい、aws関連のモジュールのテスト方法を書いていきたいと思います。
機能フロー
- インスタンスがrunning状態であれば stoppedの状態にする
- stoppedの状態になった時点でインスタンスタイプの変更をする
- 起動をし直しrunning状態まで待つ( state: stoppedなら起動はしない)
機能のポイント
インスタンスの起動停止部分
インスタンスの状態を引数として、起動または停止後にステータスが変更されるまで待つ、というロジックにします。
def ensure_state(self,state,wait_time=10,wait_timeout=600):
passed_time = 0
if state != self.get_instance_state():
if state == 'stopped':
self.changed = True
retry(lambda: self.client.stop_instances(InstanceIds=[self.instance_id]))
elif state == 'running':
self.changed = True
retry(lambda: self.client.start_instances(InstanceIds=[self.instance_id]))
while passed_time <=wait_timeout:
if state == self.get_instance_state():
return True
time.sleep(wait_time)
return False
def get_instance_state(self):
res = retry(lambda: self.client.describe_instance_status(InstanceIds=[self.instance_id], IncludeAllInstances=True))
return res["InstanceStatuses"][0]["InstanceState"]["Name"]
インスタンタイプを変更する部分
ここはシンプルにインスタンタイプを変更するだけになります。
def change_instance_type(self):
self.changed = True
res = self.client.modify_instance_attribute(InstanceId=self.instance_id,
InstanceType={'Value':self.new_instance_type})
def check_instance_type(self):
res = retry(lambda: self.client.describe_instances(InstanceIds=[self.instance_id]))
return res["Reservations"][0]["Instances"][0]['InstanceType'] == self.new_instance_type
retry処理
上記のawsへの処理に対して retry処理を入れています。
ここでのretry処理は exponential backoff
と呼ばれるもので、APIコールが失敗した時に指数的にリトライ時間を延ばしていく、というものになります。
AWSの公式ドキュメントにもリトライ処理を進めるものがあります。
https://docs.aws.amazon.com/ja_jp/general/latest/gr/api-retries.html
よく見るエラーとしては RequestLimitExceeded
などですね。こういうエラーに対してのエラーハンドリング処理になります。
def retry(f, retries=5):
'''
This is wrapper for exponential backoff
'''
exception = None
for i in range(retries):
try:
return f()
except Exception as e:
exception = e
time.sleep(2 ** i)
raise exception
テストの前準備
ansibleのmoduleをテストする時はいくつかのpatchを事前に差し込む必要があります。
以下のansibleドキュメントにもありますが、
- moudleの引数を渡す部分
- module自体の終了時に意図的に例外を発生させる
が少なくとも必要になります。
詳しい仕様は公式ドキュメントに書いてあります。
https://docs.ansible.com/ansible/latest/dev_guide/testing_units_modules.html
moudleの引数を渡す部分
ansibleのmoduleの引数は ansible.module_utils.basic
で制御をしているため、以下のように引数を渡すための関数を作成します。
def set_module_args(args):
"""prepare arguments so that they will be picked up during module creation"""
args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
basic._ANSIBLE_ARGS = to_bytes(args)
module自体の終了時に意図的に例外を発生させる
ansibleのmoduleでは終了時に exit_json
または fail_json
を実行しています。これらの関数はsys.exit(0)
で処理を終了させてしまうため、意図的に例外を発生させることでテストを可能にします。
ポイントは
- unittestの
setUp
でexit_json
とfail_json
に対してpatchを当てて、返り値をAnsibleExitJson
かAnsibleFailJson
にする
です。
class AnsibleExitJson(Exception):
"""Exception class to be raised by module.exit_json and caught by the test case"""
pass
class AnsibleFailJson(Exception):
"""Exception class to be raised by module.fail_json and caught by the test case"""
pass
def exit_json(*args, **kwargs):
"""function to patch over exit_json; package return data into an exception"""
if 'changed' not in kwargs:
kwargs['changed'] = False
raise AnsibleExitJson(kwargs)
def fail_json(*args, **kwargs):
"""function to patch over fail_json; package return data into an exception"""
kwargs['failed'] = True
raise AnsibleFailJson(kwargs)
class TestMyModule(unittest.TestCase):
def setUp(self):
self.mock_module_helper = patch.multiple(basic.AnsibleModule,
exit_json=exit_json,
fail_json=fail_json)
self.mock_module_helper.start()
self.addCleanup(self.mock_module_helper.stop)
テストケース
ここまで来てやっとテストケースについて考えてみます。
- 最終的に指定をしたstateになっているか
- インスタンスタイプが変わっているか
この二つをテストしてみましょう。
motoを使ってmockをする
ec2やaws関連のリソースに対してはmotoというlibraryが便利です。
以下のように @mock_ec2
というアノテーションをつけるだけで boto3.client
へのAPIアクセスが全てモックされます。
import boto3
from unittest import mock
from moto import mock_ec2
def add_servers(instance_types,server_count=3):
ami_id = 'ami-1234abcd'
client = boto3.client('ec2', region_name='ap-northeast-1')
result = []
result.append(client.run_instances(
ImageId=ami_id,
InstanceType=instance_types,
MaxCount=server_count,
MinCount=server_count
))
return result[0]['Instances'][0]['InstanceId']
@mock_ec2
def test_change_instance_types(self):
実際のテストロジック
以下のようなテストロジックになります
- moduleの実行
-
describe_instances
でインスタンス情報を再度取得 - 期待した値に変更されているかを確認
class TestMyModule(unittest.TestCase):
def setUp(self):
self.mock_module_helper = patch.multiple(basic.AnsibleModule,
exit_json=exit_json,
fail_json=fail_json)
self.mock_module_helper.start()
self.addCleanup(self.mock_module_helper.stop)
@mock_ec2
def test_change_instance_types(self):
#expected values
current_instance_type = 'c5.large'
expected_status = 'running'
expected_instance_type = 'c5.xlarge'
region = 'ap-northeast-1'
#Make ec2 mock
instance_id = add_servers(current_instance_type)
#Setup args for ansible
set_module_args({
'aws_access_key': 'AWSKEYIOSFODNN7EXAMPLE',
'aws_secret_key': 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY',
'instance_id': instance_id,
'state': expected_status,
'new_instance_type': expected_instance_type,
'region': region
})
#Call actual ansible module
with self.assertRaises(AnsibleExitJson):
test_instance_type_change.main()
#Confirm 1. check instance type 2. status is running
client = boto3.client('ec2', region_name='ap-northeast-1')
for ins in client.describe_instances(InstanceIds=[instance_id])['Reservations'][0]['Instances']:
self.assertEqual(ins['State']['Name'], expected_status)
self.assertEqual(ins['InstanceType'], expected_instance_type)
実行してみる
問題なくtestが実行されていることが分かります。
root@3560cc651526:/tmp/library# python3 test_module.py
----------------------------------------------------------------------
Ran 1 test in 0.541s
OK
完成したmoudleとテスト
自作モジュール部分
import time
try:
import boto3
import botocore
HAS_BOTO3_API = True
except ImportError:
HAS_BOTO3_API = False
from ansible.module_utils.ec2 import aws_common_argument_spec
class ChangeInstanceType(object):
def __init__(self, module):
self.changed = False
self.module = module
self.aws_access_key = module.params.get('aws_access_key')
self.aws_secret_key = module.params.get('aws_secret_key')
self.region = module.params.get('region')
self.instance_id = module.params.get('instance_id')
self.new_instance_type = module.params.get('new_instance_type')
self.state =module.params.get('state')
if not HAS_BOTO3_API:
self.module.fail_json(changed=False, msg="Python package boto3 is required")
try:
self.client = boto3.client(
'ec2',
aws_access_key_id=self.aws_access_key,
aws_secret_access_key=self.aws_secret_key,
region_name=self.region
)
except botocore.exceptions.ClientError as e:
self.module.fail_json(changed=self.changed, msg="Cannot initialize connection to ec2: {}".format(e))
def ensure_state(self,state,wait_time=10,wait_timeout=600):
passed_time = 0
if state != self.get_instance_state():
if state == 'stopped':
self.changed = True
retry(lambda: self.client.stop_instances(InstanceIds=[self.instance_id]))
elif state == 'running':
self.changed = True
retry(lambda: self.client.start_instances(InstanceIds=[self.instance_id]))
while passed_time <=wait_timeout:
if state == self.get_instance_state():
return True
time.sleep(wait_time)
return False
def get_instance_state(self):
res = retry(lambda: self.client.describe_instance_status(InstanceIds=[self.instance_id], IncludeAllInstances=True))
return res["InstanceStatuses"][0]["InstanceState"]["Name"]
def change_instance_type(self):
self.changed = True
res = self.client.modify_instance_attribute(InstanceId=self.instance_id,
InstanceType={'Value':self.new_instance_type})
def check_instance_type(self):
res = retry(lambda: self.client.describe_instances(InstanceIds=[self.instance_id]))
return res["Reservations"][0]["Instances"][0]['InstanceType'] == self.new_instance_type
def main(self):
if self.check_instance_type():
self.changed = False
return False
if not self.ensure_state('stopped'):
self.module.fail_json(changed=self.changed, msg="Module is failed when ensure the state")
self.change_instance_type()
if not self.ensure_state(self.state):
self.module.fail_json(changed=self.changed, msg="Module is failed when ensure the state")
def retry(f, retries=5):
'''
This is wrapper for exponential backoff
'''
exception = None
for i in range(retries):
try:
return f()
except Exception as e:
exception = e
time.sleep(2 ** i)
raise exception
def main():
argument_spec = aws_common_argument_spec()
argument_spec.update(dict(
aws_access_key=dict(required=True, type='str'),
aws_secret_key=dict(required=True, type='str', no_log=True),
region=dict(choices=['us-east-1', 'us-west-2', 'us-west-1', 'eu-west-1', 'eu-central-1', 'ap-southeast-1', 'ap-northeast-1', 'ap-southeast-2', 'ap-northeast-2', 'ap-south-1', 'sa-east-1']),
instance_id=dict(required=True, type='str'),
state=dict(type="str", choices=["running", "stopped"], default="running"),
new_instance_type=dict(required=True, type='str')
))
module = AnsibleModule(argument_spec=argument_spec)
instance_change_type = ChangeInstanceType(module)
instance_change_type.main()
module.exit_json(changed=instance_change_type.changed)
from ansible.module_utils.basic import *
if __name__ == '__main__':
main()
テスト部分
import json
import sys
import boto3
import botocore
import time
import threading
from pathlib import Path
from unittest import mock
from moto import mock_ec2
from ansible.compat.tests import unittest
from ansible.compat.tests.mock import patch
from ansible.module_utils import basic
from ansible.module_utils._text import to_bytes
import test_instance_type_change
def add_servers(instance_types,server_count=3):
ami_id = 'ami-1234abcd'
client = boto3.client('ec2', region_name='ap-northeast-1')
result = []
result.append(client.run_instances(
ImageId=ami_id,
InstanceType=instance_types,
MaxCount=server_count,
MinCount=server_count
))
return result[0]['Instances'][0]['InstanceId']
def stop_servers(instance_ids):
client = boto3.client('ec2', region_name='ap-northeast-1')
client.stop_instances(InstanceIds=instance_ids)
def set_module_args(args):
"""prepare arguments so that they will be picked up during module creation"""
args = json.dumps({'ANSIBLE_MODULE_ARGS': args})
basic._ANSIBLE_ARGS = to_bytes(args)
class AnsibleExitJson(Exception):
"""Exception class to be raised by module.exit_json and caught by the test case"""
pass
class AnsibleFailJson(Exception):
"""Exception class to be raised by module.fail_json and caught by the test case"""
pass
def exit_json(*args, **kwargs):
"""function to patch over exit_json; package return data into an exception"""
if 'changed' not in kwargs:
kwargs['changed'] = False
raise AnsibleExitJson(kwargs)
def fail_json(*args, **kwargs):
"""function to patch over fail_json; package return data into an exception"""
kwargs['failed'] = True
raise AnsibleFailJson(kwargs)
class TestMyModule(unittest.TestCase):
def setUp(self):
self.mock_module_helper = patch.multiple(basic.AnsibleModule,
exit_json=exit_json,
fail_json=fail_json)
self.mock_module_helper.start()
self.addCleanup(self.mock_module_helper.stop)
@mock_ec2
def test_change_instance_types(self):
#expected values
current_instance_type = 'c5.large'
expected_status = 'running'
expected_instance_type = 'c5.xlarge'
region = 'ap-northeast-1'
#Make ec2 mock
instance_id = add_servers(current_instance_type)
#Setup args for ansible
set_module_args({
'aws_access_key': 'AWSKEYIOSFODNN7EXAMPLE',
'aws_secret_key': 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY',
'instance_id': instance_id,
'state': expected_status,
'new_instance_type': expected_instance_type,
'region': region
})
#Call actual ansible module
with self.assertRaises(AnsibleExitJson):
test_instance_type_change.main()
#Confirm 1. check instance type 2. status is running
client = boto3.client('ec2', region_name='ap-northeast-1')
for ins in client.describe_instances(InstanceIds=[instance_id])['Reservations'][0]['Instances']:
self.assertEqual(ins['State']['Name'], expected_status)
self.assertEqual(ins['InstanceType'], expected_instance_type)
if __name__ == "__main__":
unittest.main()
おわりに
ansibleのテストは標準のunittestで実行できないものが結構あるので厄介ですね。
ansibleの公式のrepositoryでも各moduleに対してのテストがコミットされているので、それを参考に自作モジュールを作っていくのも良さそうです。