はじめに
Pythonでは、メソッドやモジュール内の関数を上書きするのは比較的簡単です。Qiitaにもいくつか記事があります。
これらの手法は、以下のように属性アクセスをいじっているだけです。
オブジェクトとしての関数を弄っているわけではありません。
class SomeClass:
def original_method(self):
print('call original_method')
def new_method(self):
print('call new_method')
some_instance1 = SomeClass()
some_instance2 = SomeClass()
# some_instance1 のメソッドを書き換える
some_instance1.original_method = type(some_instance1.original_method)(new_method, some_instance1)
some_instance1.original_method() # new_method() が呼ばれる
some_instance2.original_method() # original_method() が呼ばれる
# すべてのインスタンスのメソッドを書き換える
SomeClass.original_method = new_method
some_instance1.original_method() # new_method() が呼ばれる
some_instance2.original_method() # new_method() が呼ばれる
import some_module
def new_func():
print('call new_func')
## モジュール内の関数を書き換える
some_module.original_func = new_func
some_module.original_func() # new_func() が呼ばれる
大抵はこれらの手法でなんとかなります。
しかし、属性アクセスを弄っても上書きできない場合もあります。
# 属性を上書きする前に取り出されていると、
original_method = some_instance1.original_method
# 属性を上書きしても、
type(some_instance1).original_method = new_method
# 先に取り出されたほうには影響がない
original_method() # もとの original_method() が呼ばれる
import some_module
from some_module import original_func
# モジュール内の関数でも同じ
some_module.original_func = new_func
original_func() # original_func() が呼ばれる
やりたいこと(使い方)
先に属性を取り出されていようが、プログラム全域に渡って上書きしたい。
import some_module
from some_module import original_func # たとえ先に取り出されていても
def new_func():
print('call new_func')
overwrite_func(some_module.original_func, new_func) # あとから上書きして
original_func() # ここで new_func() が呼ばれてほしい
できたもの(試作)
def overwrite_func(orig, new):
from uuid import uuid4
kw = 'kw' + str(uuid4()).replace('-', '')
exec("def outer():\n " + '='.join(list(orig.__code__.co_freevars) + ['None'])
+ "\n def inner(*args, " + kw + "=new, **kwargs):\n "
+ ','.join(orig.__code__.co_freevars)
+ "\n return " + kw + "(*args, **kwargs)\n return inner",
locals(), globals())
inner = outer()
orig.__code__ = inner.__code__
orig.__defaults__ = inner.__defaults__
orig.__kwdefaults__ = inner.__kwdefaults__
最初は __code__
を上書きするだけだろうと思っていたのですが、__code__.co_freevars
の個数(関数内で定義された関数が内部で使っている外側の関数の変数の個数?)が一致していないと代入できないようだったので exec
でfreevarsの数を合わせています。
できたもの(完成版)
試作版だとシグネチャが失われてしまうので、できるだけ残したバージョンです。
ただし、__code__.co_freevars
の調整のために __overwrite_func
がキーワード引数に追加されます。
def overwrite_func(orig, new, signature=None):
import inspect
from types import FunctionType
from textwrap import dedent
assert isinstance(orig, FunctionType), (orig, type(orig))
assert isinstance(new, FunctionType), (new, type(new))
if signature is None:
signature = inspect.signature(orig)
params = [
(str(p).split(':')[0].split('=')[0], p)
for k, p in signature.parameters.items()
if k != '__overwrite_func'
]
default = {p.name: p.default for _, p in params}
anno = {p.name: p.annotation for _, p in params}
args_kwargs = [
k if k[0] == '*' or p.kind == p.POSITIONAL_ONLY else k + '=' + k
for k, p in params
]
signature_ = [
(k + (':anno["' + k + '"]' if p.annotation != p.empty else '')
+ ('=default["' + k + '"]' if p.default != p.empty else ''),
not (p.kind == p.VAR_KEYWORD or p.kind == p.KEYWORD_ONLY))
for k, p in params
]
signature__ = [s for s, positional in signature_ if positional]
signature__.append('__overwrite_func=new')
signature__.extend(s for s, positional in signature_ if not positional)
signature__ = '(' + ', '.join(signature__) + ')'
if signature.return_annotation is not inspect.Signature.empty:
anno['return'] = signature.return_annotation
signature__ += ' -> anno["return"]'
source = dedent("""
def outer():
""" + '='.join(list(orig.__code__.co_freevars) + ['None']) + """
def inner""" + signature__ + """:
""" + ', '.join(orig.__code__.co_freevars) + """
return __overwrite_func(""" + ', '.join(args_kwargs) + """)
return inner
""")
globals_ = {}
exec(source, dict(new=new, default=default, anno=anno), globals_)
inner = globals_['outer']()
globals_.clear()
orig.__code__ = inner.__code__
orig.__defaults__ = inner.__defaults__
orig.__kwdefaults__ = inner.__kwdefaults__
orig.__annotations__ = inner.__annotations__
注意
今回作った関数は万能ではありません。
__code__
を持っていない特殊な関数や、__call__
が実装された呼び出し可能オブジェクトなどには無力です。
相手に合わせて使ってください。
overwrite_func(print, new_func) # assert は無効にしている
# → AttributeError: 'builtin_function_or_method' object has no attribute '__code__'
上書きする側の関数への参照が __overwrite_func
に溜まっていくので、メモリリークにお気をつけて。
応用例
def copy_func(f):
"""https://stackoverflow.com/questions/13503079"""
import functools
import types
assert isinstance(f, types.FunctionType), (f, type(f))
g = types.FunctionType(
f.__code__,
f.__globals__,
name=f.__name__,
argdefs=f.__defaults__,
closure=f.__closure__,
)
g.__kwdefaults__ = f.__kwdefaults__
functools.update_wrapper(g, f)
return g
def add_hook(func, pre_call=None, post_call=None, except_=None, finally_=None):
import inspect
func_sig = inspect.signature(func)
func_copy = copy_func(func)
def hook(*args, **kwargs):
bound_args = func_sig.bind(*args, **kwargs)
if pre_call is not None:
pre_call(func_copy, bound_args)
try:
return_ = func_copy(*args, **kwargs)
except BaseException as e:
if except_ is not None:
except_(func_copy, bound_args, e)
raise
else:
if post_call is not None:
post_call(func_copy, bound_args, return_)
finally:
if finally_ is not None:
finally_(func_copy, bound_args)
return return_
overwrite_func(func, hook)
コールバック関数を後付できます。
def callback(f, args, result):
print(result)
add_hook(original_func, post_call=callback)
original_func() # original_func() が呼ばれる前に callback() が呼ばれる。
おわりに
できることはできましたが、やらなくて済むならやらないほうが良いです。
__code__
まわりの仕様を理解しきれていないので、テストケースはたぶん足りていません。
動かない事例があれば教えて下さい。