LoginSignup
4
1

python: 自作クラスをnumpy ndarrayに使用する

Last updated at Posted at 2020-04-24

https://numpy.org/doc/stable/user/basics.subclassing.html#module-numpy.doc.subclassing (公式)を参考にすれば良いでしょう。

以下は簡単な実装例です。

class Element:
    def __init__(self, r, c):
        self.r = r
        self.c = c
        
    def __repr__(self):
        return f'Element({self.r},{self.c})'

r行c列目の要素がElement(r, c)であるようなnumpy.ndarrayを作りたい、とします。


MyNdarrayクラスをnp.ndarrayを継承して以下のように作ります。

import numpy as np

class MyNdarray(np.ndarray):
    def __new__(cls, dimension):
        shape = (dimension, dimension)
        return super().__new__(cls, shape, dtype=Element)

    def __init__(self, dimension):
        for r in range(dimension):
            for c in range(dimension):
                self[r, c] = Element(r, c)

__new__ でshapeが(dimension * dimension) の numpy ndarray instanceを作成して, __init__ で初期操作(コンストラクタ)が行われます.

a = MyNdarray(3)
>> MyNdarray([[Element(0,0), Element(0,1), Element(0,2)],
           [Element(1,0), Element(1,1), Element(1,2)],
           [Element(2,0), Element(2,1), Element(2,2)]], dtype=object)
a[0, 0]
>> Element(0,0)

ndarrayのスライシング操作や転置機能をそのまま使うことができるのがいいですね.

a[:, 0:2]
>> MyNdarray([[Element(0,0), Element(0,1)],
           [Element(1,0), Element(1,1)],
           [Element(2,0), Element(2,1)]], dtype=object)
a.T
>> MyNdarray([[Element(0,0), Element(1,0), Element(2,0)],
           [Element(0,1), Element(1,1), Element(2,1)],
           [Element(0,2), Element(1,2), Element(2,2)]], dtype=object)


また, MyNdarray クラスに, 属性を追加したいときは__array_finalize__ 関数を使います.

import numpy as np

class MyNdarray(np.ndarray):
    def __new__(cls, dimension):
        shape = (dimension, dimension)
        obj = super().__new__(cls, shape, dtype=Element)
        obj.dimension = dimension
        return obj
    
    def __init__(self, dimension):
        for r in range(dimension):
            for c in range(dimension):
                self[r, c] = Element(r, c)
                
    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.dimension = getattr(obj, 'dimension', None)
a = MyNdarray(3)
a.dimension
>>> 3
4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1