4
5

More than 1 year has passed since last update.

Pythonにオーバーロードを実装する。

Last updated at Posted at 2023-01-20

はじめに

Pythonにはないオーバーロード。
typing@overloadも所詮は型チェッカー。引数の型に応じて分岐してくれる機能はない。
ならば作ってみよう。

追記
1/20 一応こんなものはあるけど、第1引数しか見てない?
functools.singledispatch

1/22 アノテーションありなし混合に対応。これによりクラスメソッドにも使用可能。

オーバーロードの実装

def overload(f):
    from inspect import signature, _empty

    if overload.__dict__.get("funcs") is None:
        overload.__dict__.update({"funcs": {}})
    annotations = tuple([(i, j.annotation) for i, j in signature(f).parameters.items()])
    if overload.funcs.get(f.__qualname__):
        overload.funcs[f.__qualname__].update({annotations: f})
    else:
        overload.funcs.update({f.__qualname__: {annotations: f}})

    def fork(*args, **kwargs):
        check = {}
        for anno, h in overload.funcs[f.__qualname__].items():
            d = dict(anno)
            if len(anno) < len(args) + len(kwargs):
                continue
            argcheck = {
                name: typ == _empty or isinstance(i, typ)
                for (name, typ), i in zip(anno, args)
            }
            kwargcheck = {
                name: d[name] == _empty or isinstance(i, d[name])
                for name, i in kwargs.items()
            }
            if all((argcheck | kwargcheck).values()):
                check[h] = len(d.keys() - (argcheck | kwargcheck).keys())
        if not check:
            raise NameError(
                f"name '{f.__qualname__}({', '.join([type(i).__name__ for i in args]+[':'.join([i, type(j).__name__]) for i, j in kwargs.items()])})' is not defined"
            )
        g = min(check, key=check.get)
        return g(*args, **kwargs)

    return fork

使い方

関数をデコレートするだけ。アノテーションは型名で。

@overload
def f(a:list, b:str):
    return a, b

@overload
def f(a:int, b:int):
    return a + b

@overload
def f(a:list, b:tuple):
    return {b: a}
f([1], "2") # -> ([1], '2')

f(1, 2) # -> 3

f([1], (2,)) # -> {(2,): [1]}

f("1", "2") # -> 
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In [3], line 1
----> 1 f("1", "2")

Cell In [1], line 30, in overload.<locals>.fork(*args, **kwargs)
     28         check[h] = len(d.keys() - (argcheck | kwargcheck).keys())
     29 if not check:
---> 30     raise NameError(
     31         f"name '{f.__name__}({', '.join([type(i).__name__ for i in args]+[':'.join([i, type(j).__name__]) for i, j in kwargs.items()])})' is not defined"
     32     )
     33 g = min(check, key=check.get)
     34 return g(*args, **kwargs)

NameError: name 'f(str, str)' is not defined

クラスメソッド

class MyClass:
    @overload
    def __init__(self, a:str, b:int):
        self.ab = (a, b)
        
    @overload
    def __init__(self, a:int|float, b:int): # 3.10以降ではor演算子によるUnionTypeに対応
        self.ab = a + b
MyClass("a", 1).ab # -> ("a", 1)
MyClass(1, 2).ab # -> 3
MyClass(1., 2).ab # -> 3.0

仕組み

overload関数のメンバ変数として関数の辞書を用意し、名前とアノテーションで管理しています。
引数の方をチェックして、矛盾がないものの中で最も適合している関数を返します。

おわりに

細かな検証・調整はしていないので、使い方により想定通りに動かない問題が残っているかもしれません。
発見の際は報告いただけると幸いです。

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5