LoginSignup
1
0

More than 5 years have passed since last update.

someFunc(arr > n)という表現

Last updated at Posted at 2018-01-16

ほとんどnumpyを使ったことがなく、np.whereに出会って一瞬だけ怯んだ時のメモ。

JavaやJavaScript等の言語をやっている人からすると、表題の関数に渡される引数はtrue(PythonではTrue)か、false(PythonではFalse)になるとお思いではないでしょうか?

ところがどっこい、それ以外が引数に渡される言語があるのです。その一つがPythonです。

Pythonには、特殊メソッドなる素晴らしい機能がありまして、演算子に対する振る舞いを、クラスごとに自前で定義することが可能です。もちろん比較演算子もそうです。

class YJSNPI(object):
    def __init__(self,val):
        self._v = val

    def __eq__(self,other):
        if (self._v == other) in (114514,True):
            return 114514
        else:
            return 1919810

if __name__ == '__main__':
    a = YJSNPI(5)
    b = 5
    print(a == b) # => 114514

    c = YJSNPI(5)
    print(a == c) # => 114514

    d = YJSNPI(10)
    print(a == d) # => 1919810

これと似た感じで、TrueFalse以外のものを返させることができます。

以下ろくにドキュメントも読まず適当に予測して似せてみたwhere関数の実装です。numpyの振る舞いを見る限りもっと複雑なんでしょうけど、とりあえず比較演算子で対象要素、または対象要素のインデックスを取り出すことを目標としました。

class my_list(list):

    def __lt__(self, s):  # self.item < s
        return my_list(v < s for v in self)

    def __le__(self, s):  # self.item <= s
        return my_list(v <= s for v in self)

    def __eq__(self, s):  # self.item == s
        return my_list(v == s for v in self)

    def __ne__(self, s):  # self.item != s
        return my_list(v != s for v in self)

    def __gt__(self, s):  # self.item > s
        return my_list(v > s for v in self)

    def __ge__(self, s):  # self.item >= s
        return my_list(v >= s for v in self)

    def __getitem__(self, index):
        if type(index) in (list, my_list):
            return my_list(v for v, b in zip(self, index) if b)
        else:
            return super().__getitem__(index)

    @staticmethod
    def where(inds, *args):
        if len(args) >= 2:
            return my_list(args[not b] for b in inds)
        elif len(args) == 1:
            raise ValueError('either both or neither of x and y should be given')
        else:
            return my_list(i for i, b in enumerate(inds) if b)

if __name__ == '__main__':
    arr = my_list([8, 1, 7, 2, 6, 3, 5, 4, 5, 3, 7, 2, 8, 1])
    print(my_list.where(arr >= 5, 1, 0))
    # => [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
    print(my_list.where(arr != 5))
    # => [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13]
    # print(my_list.where(arr<5,2)) # => ValueError
    print(arr[arr >= 5]) # => [8, 7, 6, 5, 5, 7, 8]

@shiracamus様からいただいたコメントを元に修正させて頂きました。ありがとうございます。m(_ _)m

1
0
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0