ほとんどnumpy
を使ったことがなく、np.where
に出会って一瞬だけ怯んだ時のメモ。
JavaやJavaScript等の言語をやっている人からすると、表題の関数に渡される引数はtrue
(PythonではTrue
)か、false
(PythonではFalse
)になるとお思いではないでしょうか?
ところがどっこい、それ以外が引数に渡される言語があるのです。その一つがPythonです。
Pythonには、特殊メソッドなる素晴らしい機能がありまして、演算子に対する振る舞いを、クラスごとに自前で定義することが可能です。もちろん比較演算子もそうです。
class YJSNPI(object):
def __init__(self,val):
self._v = val
def __eq__(self,other):
if (self._v == other) in (114514,True):
return 114514
else:
return 1919810
if __name__ == '__main__':
a = YJSNPI(5)
b = 5
print(a == b) # => 114514
c = YJSNPI(5)
print(a == c) # => 114514
d = YJSNPI(10)
print(a == d) # => 1919810
これと似た感じで、True
やFalse
以外のものを返させることができます。
以下ろくにドキュメントも読まず適当に予測して似せてみたwhere
関数の実装です。numpy
の振る舞いを見る限りもっと複雑なんでしょうけど、とりあえず比較演算子で対象要素、または対象要素のインデックスを取り出すことを目標としました。
class my_list(list):
def __lt__(self, s): # self.item < s
return my_list(v < s for v in self)
def __le__(self, s): # self.item <= s
return my_list(v <= s for v in self)
def __eq__(self, s): # self.item == s
return my_list(v == s for v in self)
def __ne__(self, s): # self.item != s
return my_list(v != s for v in self)
def __gt__(self, s): # self.item > s
return my_list(v > s for v in self)
def __ge__(self, s): # self.item >= s
return my_list(v >= s for v in self)
def __getitem__(self, index):
if type(index) in (list, my_list):
return my_list(v for v, b in zip(self, index) if b)
else:
return super().__getitem__(index)
@staticmethod
def where(inds, *args):
if len(args) >= 2:
return my_list(args[not b] for b in inds)
elif len(args) == 1:
raise ValueError('either both or neither of x and y should be given')
else:
return my_list(i for i, b in enumerate(inds) if b)
if __name__ == '__main__':
arr = my_list([8, 1, 7, 2, 6, 3, 5, 4, 5, 3, 7, 2, 8, 1])
print(my_list.where(arr >= 5, 1, 0))
# => [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
print(my_list.where(arr != 5))
# => [0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, 13]
# print(my_list.where(arr<5,2)) # => ValueError
print(arr[arr >= 5]) # => [8, 7, 6, 5, 5, 7, 8]
@shiracamus様からいただいたコメントを元に修正させて頂きました。ありがとうございます。m(_ _)m