LoginSignup
0
4

More than 3 years have passed since last update.

Pythonで関数実行時に引数の型アノテーションをチェックしてエラーにする

Posted at

はじめに

Pythonをお使いの方なら既知ですが、Pythonでは型による強制力はありません。
型アノテーション(typing)がPython3.5で追加されるまでは型の記述もできませんでした。
コメントや変数名で空気を読むしかできないわけです。
(個人的に辞書は最悪でした)

なので型アノテーションが実装されてからは本当に助かっています。
もう癖になっているので書かない方が違和感があります。
(同じ思いの方も多いかと思います笑)

ただ依然として強制力はなく、mypyやVSCode拡張のPylanceでチェックができるに留まっています。
第三者にモジュールとして提供する場合はIFの型チェックはできないわけで実装の中でケアしないといけないです。
こういったケア(チェック処理)をわざわざ1つ1つに対して書くのは手間で本質的ではないので,
簡潔に書きたいというのがモチベーションです。

アプローチ

Pythonにはデコレータという便利な機能があります。
デコレータを使うと関数の実行前に処理を実行できるので、
チェック処理をこの中で行ってあげるというアプローチになります。

なので各関数ではデコレータだけ書いてしまえばOKということになります。

実装

以下のようなデコレータ関数を定義します。
error引数には型アノテーションの不一致があった際のErrorを指定できます。
check_all_collectionはCollection型の引数をチェックする際に全件チェックするかどうかを指定できます。

"""
引数の型をチェックするデコレータ定義ファイル
"""
import functools
import inspect
from typing import Any, Union, Callable, _GenericAlias


def check_args_type(error: Exception = TypeError, check_all_collection: bool = False):
    """
    引数の型がアノテーションの型と一致しているかチェックを行うデコレータ関数
    Args:
        error: 不一致時のエラークラス
        check_all_collection: コレクション型の中身を全てチェックするか
    """
    def _decorator(func: Callable):
        @functools.wraps(func)
        def args_type_check_wrapper(*args, **kwargs):
            sig = inspect.signature(func)
            try:
                for arg_key, arg_val in sig.bind(*args, **kwargs).arguments.items():
                    # アノテーションがタイプでない/空の場合は判定しない
                    annotation = sig.parameters[arg_key].annotation
                    if not isinstance(annotation, type) and not isinstance(annotation, _GenericAlias):
                        continue
                    if annotation == inspect._empty:
                        continue

                    # 一致判定
                    # Generic系のタイプだった場合は派生形・一部が一致していればOK
                    is_match = __check_generic_alias(annotation, arg_val, check_all_collection)
                    if not is_match:
                        message = f"引数'{arg_key}'の型が正しくありません。annotaion:{annotation} request:{type(arg_val)}"
                        raise error(message)
            except TypeError as exc:
                raise error("引数の型か数が一致しません。") from exc
            return func(*args, **kwargs)
        return args_type_check_wrapper
    return _decorator

def __check_generic_alias(
    annotation: Union[_GenericAlias, type],
    request: Any,
    check_all_collection: bool = False
):
    """
    GenericAliasの型チェック
    Args:
        annotation: アノテーションタイプ
        request: リクエスト
        check_all_collection: コレクション型の中身を全てチェックするか
    """
    # Anyの場合は型チェックしない
    if annotation == Any:
        return True

    # 型チェック
    request_type = type(request)
    if isinstance(annotation, _GenericAlias):
        if annotation.__origin__ == request_type:    # for collection ...list, dict, set
            # -----------
            # list
            # -----------
            if annotation.__origin__ == list and request:
                _annotation = annotation.__args__[0]
                if check_all_collection:    # 全件チェックの場合は1つずつ確認
                    for _request in request:
                        is_match = __check_generic_alias(
                            _annotation, _request, check_all_collection
                        )
                        if not is_match:
                            return False
                    return True

                else:   # 全件チェックでない場合は先頭を取り出して確認
                    return __check_generic_alias(
                        _annotation, request[0], check_all_collection
                    )

            # -----------
            # dict
            # -----------
            if annotation.__origin__ == dict and request:
                _annotation_key = annotation.__args__[0]
                _annotation_value = annotation.__args__[1]
                if check_all_collection:    # 全件チェックの場合は1つずつ確認
                    for _request in request.keys():
                        is_match = __check_generic_alias(
                            _annotation_key, _request, check_all_collection
                        )
                        if not is_match:
                            return False
                    for _request in request.values():
                        is_match = __check_generic_alias(
                            _annotation_value, _request, check_all_collection
                        )
                        if not is_match:
                            return False
                    return True

                else:   # 全件チェックでない場合は先頭を取り出して確認
                    is_match_key = __check_generic_alias(
                        _annotation_key, list(request.keys())[0], check_all_collection
                    )
                    is_match_value = __check_generic_alias(
                        _annotation_value, list(request.values())[0], check_all_collection
                    )
                    is_match = is_match_key and is_match_value
                    return is_match

            # 中身が存在してない場合,originがあっていればOKとする
            if not request:
                return True

        else:
            # list/dictの場合はoriginが一致していないとエラーとする
            origin = annotation.__origin__
            if origin == list or origin == dict:
                return False
            # それ以外は再帰的にチェック
            else:
                for arg in annotation.__args__:
                    is_match = __check_generic_alias(arg, request)
                    if is_match:
                        return True
    else:
        # BoolはintのサブクラスなのでissubclassでTureとなる
        # 本来意味合いが違うのでNGとしたい
        if request_type == bool and annotation == int:
            return False
        return issubclass(request_type, annotation)
    return False

使用例はその1。

# 一番シンプルなパターン
@check_args_type()
def test(value: int, is_valid: bool) -> float:
    """
    (省略)
    """
    return 0.0

def main():
    # OK
    result = test(5, True)

    # NG -> TypeError
    result = test(0.0, False)

    # NG2 -> TypeError
    result = test(1, "True")

使用例その2。

# Collectionの中身を全てチェックするパターン
@check_args_type(check_all_collection=True)
def test2(value: List[int]) -> List[float]:
    """
    (省略)
    """
    return [0.0]

def main():
    # OK
    result = test2([0, 5, 10, 20])

    # NG -> TypeError
    result = test([0.0, 5.0, 10.0, 20.0])

    # NG2 -> TypeError
    result = test([0, 5, "test"])

Enumやgeneratorなど考慮不足の型があるかと思いますが、基本的な型であれば網羅できているかと思います。
(必要であれば追加していく形でお願いします)

まとめ

関数実行時に引数の型アノテーションをチェックしてエラーにする方法を紹介しました。
これで型による強制力が発揮できます。
厳密性が求められる場面(IFなどの境界)では使えるかなと思います。

PS)
契約プログラミングのように値まで確認できれば最高なので、拡張予定です。

0
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
4