LoginSignup
6
8

More than 5 years have passed since last update.

【Python3.7】dataclassがHashableを継承できない話

Posted at

問題提起

Python 3.7 で新たに登場したdataclassデコレータは、クラスに様々なメソッドを自動的に付加してくれる。
例えば、次のようにすることで、__init__()__hash__()__eq__()を適切に実装したnamedtupleのようなものを作ってくれる。

dataclass_example.py
from dataclasses import dataclass

@dataclass(eq=True, frozen=True)
class Point3D:
    """Immutable struct for 3-D points"""
    x: float
    y: float
    z: float = 0.0  # default value

if __name__ == '__main__':
    p = Point3D(x=3.14, y=2.72)
    print(p)  # -> Point3D(x=3.14, y=2.72, z=0.0)
    q = Point3D(4.6, 9.3, 0.0)
    print(p == q)  # -> False
    another_p = Point3D(3.14, 2.72)
    print(p == another_p)  # -> True
    d = dict(p='Point3D can be', q='a key of a dictionary')

このインスタンス、pqは、当然hashableである。

make_sure_to_be_hashable.py
from dataclass_example import Point3D
from typing import Hashable

p = Point3D(x=3.14, y=2.72)
print(hash(p))  # -> -7688538011378258102
print(isinstance(p, Hashable))  # -> True

ところが、このクラスはHashableを継承できない。
以下のコードはエラーになるのである。

actually_its_not_hashable.py
from dataclasses import dataclass
from typing import Hashable

@dataclass(eq=True, frozen=True)
class HashablePoint3D(Hashable):  # Explicitly inherit from Hashable (actually it's a protocol)
    x: float
    y: float
    z: float = 0.0

if __name__ == '__main__':
    p = HashablePoint3D(x=3.14, y=2.72)
    # -> TypeError: Can't instantiate abstract class HashablePoint3D with abstract methods __hash__

Point3Dはインスタンスが__hash__()メソッドを持つにもかかわらず、Hashableを明示的に継承することができない。
Hashableに限らず、dataclassによって付加されるメソッドを要求するほとんどの Abstract Base Class が、継承できずエラーになるはずだ。

理由

デコレータはただのラッパーで、__new__()を上書きするだけなのだが、その前にABCMetaの監査が入ってしまうため、__hash__()が実装される前にHashableかどうかの判定が行われてしまい、エラーが出てしまう。
dataclassはクラスの生成時ではなく、値の生成時に作用しているということだ。

以下では、このことをもう少し詳しく説明する。

復習: デコレータの中身

デコレータとは、メソッドやクラスに作用するラッパーである。

minimal_decorator_example.py
from typing import Callable, Any

def sample_decorator(target_function: Callable[[int], str]) -> Callable[[int], str]:
    # decorated_function = ( lambda x: decoration(target_function(x)) )
    def decorated_function(x: int) -> str:
        original_return_value = target_function(x)
        return f'Decorated! Original return value is: { original_return_value }'

    return decorated_function

@sample_decorator
def sample(x: int) -> str:
    return f'x is {x}'

if __name__ == '__main__':
    print(sample(1))  # -> Decorated! Original return value is: x is 1

sample_decoratorは、対象であるsampleメソッドを受け取って、それをラップした新たなメソッドを返す。
@sample_decoratorを付けることで、元のsampleメソッドにそのラッパーを作用させたメソッドをsampleとして改めて定義することができる。

復習: Abstract Base Class

Abstract Base Classとは、型チェックのないPythonにおいて、その型が特定のメソッドをサポートしていることを保証する構文である。

minimal_abc_example.py
from abc import ABC, abstractmethod

class Button:
    def push(self) -> str:
        raise NotImplementedError()

class UnsafeButton(Button):
    """Ouch you unfortunately forget to implement push()"""
    pass

class ABCButton(ABC):  # Inheritance from ABC is the syntax sugar of metaclass=abc.ABCMeta
    @abstractmethod  # Declare that this method needs implementation
    def push(self) -> str:
        raise NotImplementedError()

class SafeButton(ABCButton)
    """Again you forget to implement push()"""
    pass

def do_stuff_and_push_the_button(button):
    do_your_complicated_stuff(button)
    button.push()

if __name__ == '__main__':
    unsafe_button = UnsafeButton()  # -> (No error)
    do_stuff_and_push_the_button(unsafe_button)  # -> NotImplementedError
    safe_button = SafeButton()  # -> TypeError: Can't instantiate abstract class SafeButton with abstract methods push

インスタンス生成時に怒ってくれるので、そのクラスのインスタンスが作れている段階で全てのabstractmethodが実装されていることを保証できる。

復習: Pythonの__new__()__init__()

Pythonにおけるコンストラクタは2種類ある。__new__(cls)clsのインスタンスを生成し、__init__(self, ...)は生成したクラスのインスタンスselfに対して初期化処理を行う。

minimal_constructor_example.py
class Test:
    def __new__(cls):
        instance = super().__new__(cls)
        print('new')
        return instance

    def __init__(self):
        super().__init__()
        print('init')

if __name__ == '__main__':
    t = Test()  # -> new, init

大事なのは、必ず__new__()を呼び出した後に__init__()が呼ばれること。
__new__()より先に、生成されるインスタンスに対して処理を加えることはできない。

本題

対象となるクラスを再掲する。

actually_its_not_hashable.py
from dataclasses import dataclass
from typing import Hashable

@dataclass(eq=True, frozen=True)
class HashablePoint3D(Hashable):
    ...

この@dataclassはデコレータであるから、クラスを受け取ってクラスを返すメソッドである。
その中身は、たぶん「受け取ったクラスの__new__()を、以下のように書き換える」という内容なのだろう。

the_newer_new_method.py
def __new__(new_cls):
    instance = old_cls.__new___()
    instance.__hash__ = ...
    instance.__eq__ = ...
    instance.__awesome_methods__ = ...
    return instance

すなわち、__new__()メソッドでインスタンスを生成する際に、種々のメソッドを付加するというものである。

ところが、このデコレータの作用対象であるold_clsは、Hashableを継承している。

Hashable_may_be_like_this.py
from abc import ABC, abstractmethod

class Hashable(ABC):
    @abstractmethod
    def __hash__(self) -> int:
        raise NotImplementedError()

したがって、

instance = old_cls.__new___()

の中身は、次のようになっている。

old_class_new_may_be_like_this.py
def __new__(old_cls):
    instance = object.__new__(old_cls)
    ABC_abstract_implementaion_check(instance)  # -> Here we have NOT implemented the __hash__() yet!

この段階ではまだ__hash__()は実装されていないため、abc.ABCによってTypeErrorが送出されてしまう。

すなわち、
1. 対象となるクラスの__new__()、すなわちデコレータの__new__()が呼び出される。
2. デコレータの__new__()内でsuper().__new__()、すなわちABC__new__()が呼び出される。
3. ABC__new__()内でsuper().__new__()、すなわちobject__new__()が呼び出される。
4. object__new__()によってインスタンスが生成される。
5. ABC__new__()による型検査が入り、この時点で全てのabstractmethodを実装していないとTypeErrorが送出される。
6. デコレータの__new__()によって種々のメソッドが付与される。

という順番で処理が行われ、5.の段階でエラーになってしまう。

これが事件の真相である。

対策

対策と言っても、Hashableを明示的に継承するシンプルな方法は、たぶん存在しない。
__new__の代わりにdataclassに直接メソッドを付与するようなメタクラス操作をするデコレータを自分で書くとか、dataclassでラップされたクラスをHashableとは別に継承するとか、とにかくダサい方法しかないと思う。

だが、実際問題、Hashableを明示的に継承する必要はない。
我々にはProtocolという強力な武器があるからである。

protocol.py
from dataclasses import dataclass
from typing import Hashable
from abc import abstractmethod
from typing_extensions import Protocol

@dataclass(eq=True, frozen=True)
class Point3D:
    x: float
    y: float
    z: float = 0.0

class MyHashable(Protocol):
    @abstractmethod
    def __hash__(self) -> int:
        raise NotImplementedError()

def get_hashable_and_return_none(x: Hashable) -> None: pass
def get_myhashable_and_return_none(x: MyHashable) -> None: pass

if __name__ == '__main__':
    p = Point3D(4.6, 9.3)
    get_hashable_and_return_none(p)  # Mypy (static type checker) finds no error
    get_myhashable_and_return_none(p)  # Also no error

Mypyは賢いので、デコレータが付与するメソッドもきちんと理解してくれる。
明示的に継承できないのは若干読みづらいかもしれないが、静的型チェック時にはProtocolを利用するようにすることで一応解決する。

(どうもmypyのHashableの確認はかなり怪しい挙動をするようなのだが、そこには目を瞑るほかないようだ。)

この記事が皆さんのPython生活の一助となれば幸いである。

6
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
8