Help us understand the problem. What is going on with this article?

[Python3] 関数のコードオブジェクトを書き換える

はじめに

Pythonでは、メソッドやモジュール内の関数を上書きするのは比較的簡単です。Qiitaにもいくつか記事があります。

これらの手法は、以下のように属性アクセスをいじっているだけです。
オブジェクトとしての関数を弄っているわけではありません。

class SomeClass:
    def original_method(self):
        print('call original_method')

def new_method(self):
    print('call new_method')

some_instance1 = SomeClass()
some_instance2 = SomeClass()

# some_instance1 のメソッドを書き換える
some_instance1.original_method = type(some_instance1.original_method)(new_method, some_instance1)  
some_instance1.original_method()  # new_method() が呼ばれる
some_instance2.original_method()  # original_method() が呼ばれる

# すべてのインスタンスのメソッドを書き換える
SomeClass.original_method = new_method
some_instance1.original_method()  # new_method() が呼ばれる
some_instance2.original_method()  # new_method() が呼ばれる
import some_module

def new_func():
    print('call new_func')

## モジュール内の関数を書き換える
some_module.original_func = new_func
some_module.original_func()  # new_func() が呼ばれる

大抵はこれらの手法でなんとかなります。
しかし、属性アクセスを弄っても上書きできない場合もあります。

# 属性を上書きする前に取り出されていると、
original_method = some_instance1.original_method

# 属性を上書きしても、
type(some_instance1).original_method = new_method

# 先に取り出されたほうには影響がない
original_method()  # もとの original_method() が呼ばれる
import some_module
from some_module import original_func

# モジュール内の関数でも同じ
some_module.original_func = new_func
original_func()  # original_func() が呼ばれる

やりたいこと(使い方)

先に属性を取り出されていようが、プログラム全域に渡って上書きしたい。

import some_module
from some_module import original_func  # たとえ先に取り出されていても

def new_func():
    print('call new_func')

overwrite_func(some_module.original_func, new_func)  # あとから上書きして
original_func()  # ここで new_func() が呼ばれてほしい

できたもの(試作)

def overwrite_func(orig, new):
    from uuid import uuid4
    kw = 'kw' + str(uuid4()).replace('-', '')
    exec("def outer():\n " + '='.join(list(orig.__code__.co_freevars) + ['None']) 
         + "\n def inner(*args, " + kw + "=new, **kwargs):\n  " 
         + ','.join(orig.__code__.co_freevars)
         + "\n  return " + kw + "(*args, **kwargs)\n return inner",
         locals(), globals())
    inner = outer()
    orig.__code__ = inner.__code__
    orig.__defaults__ = inner.__defaults__
    orig.__kwdefaults__ = inner.__kwdefaults__

最初は __code__ を上書きするだけだろうと思っていたのですが、__code__.co_freevars の個数(関数内で定義された関数が内部で使っている外側の関数の変数の個数?)が一致していないと代入できないようだったので exec でfreevarsの数を合わせています。

できたもの(完成版)

試作版だとシグネチャが失われてしまうので、できるだけ残したバージョンです。
ただし、__code__.co_freevars の調整のために __overwrite_func がキーワード引数に追加されます。

def overwrite_func(orig, new, signature=None):
    import inspect
    from types import FunctionType
    from textwrap import dedent
    assert isinstance(orig, FunctionType), (orig, type(orig))
    assert isinstance(new, FunctionType), (new, type(new))
    if signature is None:
        signature = inspect.signature(orig)
    params = [
        (str(p).split(':')[0].split('=')[0], p)
        for k, p in signature.parameters.items()
        if k != '__overwrite_func'
    ]
    default = {p.name: p.default for _, p in params}
    anno = {p.name: p.annotation for _, p in params}
    args_kwargs = [
        k if k[0] == '*' or p.kind == p.POSITIONAL_ONLY else k + '=' + k 
        for k, p in params
    ]
    signature_ = [
        (k + (':anno["' + k + '"]' if p.annotation != p.empty else '')
         + ('=default["' + k + '"]' if p.default != p.empty else ''),
         not (p.kind == p.VAR_KEYWORD or p.kind == p.KEYWORD_ONLY))
        for k, p in params
    ]
    signature__ = [s for s, positional in signature_ if positional]
    signature__.append('__overwrite_func=new')
    signature__.extend(s for s, positional in signature_ if not positional)
    signature__ = '(' + ', '.join(signature__) + ')'
    if signature.return_annotation is not inspect.Signature.empty:
        anno['return'] = signature.return_annotation
        signature__ += ' -> anno["return"]'
    source = dedent("""
    def outer():
        """ + '='.join(list(orig.__code__.co_freevars) + ['None']) + """
        def inner""" + signature__ + """:
            """ + ', '.join(orig.__code__.co_freevars) + """
            return __overwrite_func(""" + ', '.join(args_kwargs) + """)
        return inner
    """)
    globals_ = {}
    exec(source, dict(new=new, default=default, anno=anno), globals_)
    inner = globals_['outer']()
    globals_.clear()
    orig.__code__ = inner.__code__
    orig.__defaults__ = inner.__defaults__
    orig.__kwdefaults__ = inner.__kwdefaults__
    orig.__annotations__ = inner.__annotations__

注意

今回作った関数は万能ではありません。
__code__ を持っていない特殊な関数や、__call__ が実装された呼び出し可能オブジェクトなどには無力です。
相手に合わせて使ってください。

overwrite_func(print, new_func)  # assert は無効にしている
# → AttributeError: 'builtin_function_or_method' object has no attribute '__code__'

上書きする側の関数への参照が __overwrite_func に溜まっていくので、メモリリークにお気をつけて。

応用例

def copy_func(f):
    """https://stackoverflow.com/questions/13503079"""
    import functools
    import types
    assert isinstance(f, types.FunctionType), (f, type(f))
    g = types.FunctionType(
        f.__code__,
        f.__globals__, 
        name=f.__name__,
        argdefs=f.__defaults__,
        closure=f.__closure__,
    )
    g.__kwdefaults__ = f.__kwdefaults__
    functools.update_wrapper(g, f)
    return g

def add_hook(func, pre_call=None, post_call=None, except_=None, finally_=None):
    import inspect
    func_sig = inspect.signature(func)
    func_copy = copy_func(func)

    def hook(*args, **kwargs):
        bound_args = func_sig.bind(*args, **kwargs)
        if pre_call is not None:
            pre_call(func_copy, bound_args)
        try:
            return_ = func_copy(*args, **kwargs)
        except BaseException as e:
            if except_ is not None:
                except_(func_copy, bound_args, e)
            raise
        else:
            if post_call is not None:
                post_call(func_copy, bound_args, return_)
        finally:
            if finally_ is not None:
                finally_(func_copy, bound_args)
        return return_

    overwrite_func(func, hook)

コールバック関数を後付できます。

def callback(f, args, result):
    print(result)

add_hook(original_func, post_call=callback)
original_func()  # original_func() が呼ばれる前に callback() が呼ばれる。

おわりに

できることはできましたが、やらなくて済むならやらないほうが良いです。
__code__ まわりの仕様を理解しきれていないので、テストケースはたぶん足りていません。
動かない事例があれば教えて下さい。

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away