1
1

More than 3 years have passed since last update.

[Python]namedtupleをdataclassに変更(_replace, _fieldsの互換)

Posted at

今回はnamedtupleをdataclassに変更した際にnamedtuple._replacenamedtuple._fieldsの代替をどのように実装すれば良いかを自分が実際に移行した際のプログラムを用いて見ていきます.

import dataclasses
from dataclasses import dataclass
from collections import namedtuple

namedtupleの実装

DataSupplyUnitLegacyはある時刻の複数銘柄の株価データをひとまとまりにするためのクラスです.フィールドについてはあまり気にしないで下さい.
namedtupleを継承しているので,部分的に値を変えた新しいオブジェクトを返す_replaceメソッドと設定したアトリビュートの名前のタプルを返す_fieldsプロパティを持っています.

field_list = ["names",  # 銘柄名
              "key_currency_index",  # 基軸通貨のインデックス
              "datetime",  # データの日時
              "window",  # データのウィンドウ
              "open_array",  # [銘柄名, ウィンドウ(時間)]に対応する始値
              "close_array",  # [銘柄名, ウィンドウ(時間)]に対応する終値
              "high_array",  # [銘柄名, ウィンドウ(時間)]に対応する高値
              "low_array",  # [銘柄名, ウィンドウ(時間)]に対応する低値
              "volume_array"  # [銘柄名, ウィンドウ(時間)]に対応する取引量
             ]

DataSupplyUnitBaseLegacy = namedtuple("DataSupplyUnitBase", field_list)


class DataSupplyUnitLegacy(DataSupplyUnitBaseLegacy):
    """
    DataSupplierによって提供されるデータクラス
    """    
    def __str__(self):
        return_str = "DataSupplyUnit( \n"
        for field_str in self._fields:
            return_str += field_str + "="
            return_str += str(getattr(self, field_str)) + "\n"
        return_str += ")"
        return return_str

dataclassの実装

DataSupplyUnitLegacyをdataclassに変更したものがDataSupplyUnitです.
namedtuple._replaceに対応するのはdataclasses.replaceです.一方namedtuple._fieldsに対応するのはdataclasses.fieldsです.しかし,dataclasses.fieldsが返すのはdataclasses.Fieldのタプルなので名前を取得するには.nameでアクセスします.
よって,以下のように実装できます.

@dataclass
class DataSupplyUnit:
    """
    DataSupplierによって提供されるデータクラス
    """
    names: np.ndarray # 銘柄名
    key_currency_index: int  # 基軸通貨のインデックス
    datetime: datetime.datetime  # データの日時
    window: np.ndarray  # データのウィンドウ
    open_array: np.ndarray  # [銘柄名, ウィンドウ(時間)]に対応する始値
    close_array: np.ndarray # [銘柄名, ウィンドウ(時間)]に対応する終値
    high_array: np.ndarray  # [銘柄名, ウィンドウ(時間)]に対応する高値
    low_array: np.ndarray  # [銘柄名, ウィンドウ(時間)]に対応する低値
    volume_array: np.ndarray  # [銘柄名, ウィンドウ(時間)]に対応する取引量

    def _replace(self, **kwargs):
        """
        namedtupleとの互換性のため
        """
        return dataclasses.replace(self, **kwargs)


    def __str__(self):
        return_str = "DataSupplyUnit( \n"
        for field in dataclasses.fields(self):
            return_str += field.name + "="
            return_str += str(getattr(self, field.name)) + "\n"
        return_str += ")"
        return return_str

自分は_fieldsを外部から利用することは無いので実装していませんが,もし実装する場合は

    @property
    def _fields(self):
        return tuple([field.name for field in dataclasses.fields(self)])

とすればいいと思います.

おまけ

DataSupplyUnitをフィールドのnp.ndarrayを含めてコピーするメソッドをcopyとして,部分的にしか利用しない場合にコピーするメソッドをpartialとして以下のように実装しています.もっといい書き方があれば教えてほしいです,特にクラス名を利用してコンストラクトしているのが良くないと思います.

    def copy(self):
        """
        自身のコビーを返す.ndarrayのプロパティの場合はそのコビーを保持する.
        """
        arg_dict = {}
        for field in dataclasses.fields(self):
            field_value = getattr(self, field.name)
            if isinstance(field_value, np.ndarray):
                field_value = field_value.copy()

            arg_dict[field.name] = field_value

        return DataSupplyUnit(**arg_dict)

    def partial(self, *args):
        """
        str:
            フィールド名
        メモリ等の状況によって,自身の部分的なコビーを返す.
        引数に与えられなかったプロパティはNoneとなる.
        """
        arg_dict = {}
        for field in dataclasses.fields(self):
            if field.name in args:
                field_value = getattr(self, field.name)
                if isinstance(field_value, np.ndarray):
                    field_value = field_value.copy()
            else:
                field_value = None

            arg_dict[field.name] = field_value

        return DataSupplyUnit(**arg_dict)

まとめ

namedtuple_replace_fieldsを互換性のあるままdataclassに変更する場合は以下のように実装できます.

    def _replace(self, **kwargs):
        return dataclasses.replace(self, **kwargs)

    @property
    def _fields(self):
        return tuple([field.name for field in dataclasses.fields(self)])
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1