2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

numpyで累積和の閾値を満たす最小 のインデックスを求める

Posted at

よく使うけどよく忘れるので,メモ

先に結論

argmax()つかうと上手くいく.

やりたいこと

タイトルがわかりづらいかもしれないので,詳しくやりたいことを説明.
numpy.arrayに対して,累積和(cumsum)を計算して,その累積和に対する閾値を超える最小のインデックスとかが欲しくなったりする.
例えば,

In [143]: score
Out[143]: 
array([[  1.05262936e-01,   1.05262936e-01,   1.05262936e-01,
          1.05262936e-01,   1.05262936e-01,   1.57894404e-01,
          1.05262936e-01,   1.05262936e-01,   2.10525873e-06,
          1.05262936e-01],
       [  1.66664722e-01,   1.66664722e-06,   1.66664722e-01,
          2.49997083e-01,   1.66664722e-06,   2.49997083e-01,
          1.66664722e-01,   3.33329444e-06,   3.33329444e-06,
          1.66664722e-06],
       [  9.09090909e-02,   9.09090909e-02,   9.09090909e-02,
          9.09090909e-02,   9.09090909e-02,   1.21212121e-01,
          1.21212121e-01,   9.09090909e-02,   9.09090909e-02,
          1.21212121e-01]])

In [144]: score.shape
Out[144]: (3, 10)

みたいなarrayがあったとして,これの累積和

In [145]: score.cumsum(1)
Out[145]: 
array([[ 0.10526294,  0.21052587,  0.31578881,  0.42105175,  0.52631468,
         0.68420909,  0.78947202,  0.89473496,  0.89473706,  1.        ],
       [ 0.16666472,  0.16666639,  0.33333111,  0.58332819,  0.58332986,
         0.83332694,  0.99999167,  0.999995  ,  0.99999833,  1.        ],
       [ 0.09090909,  0.18181818,  0.27272727,  0.36363636,  0.45454545,
         0.57575758,  0.6969697 ,  0.78787879,  0.87878788,  1.        ]])

がそれぞれ閾値

In [149]: threshold = np.random.random((3, 1))

In [150]: threshold
Out[150]: 
array([[ 0.62732896],
       [ 0.46494853],
       [ 0.54341381]])

よりも大きい要素のうち,一番小さいインデックスが欲しい.
この例だと,score[0, 5], score[1, 3], score[2, 5]になるので,[5, 3, 5]の結果が貰えればいい.

そこで,初めの結論に戻るのだが,numpyargmaxを使うと,いい塩梅に欲しい出力がもらえた.

In [155]: score.cumsum(1) > threshold
Out[155]: 
array([[False, False, False, False, False,  True,  True,  True,  True,
         True],
       [False, False, False,  True,  True,  True,  True,  True,  True,
         True],
       [False, False, False, False, False,  True,  True,  True,  True,
         True]], dtype=bool)

In [156]: np.argmax(score.cumsum(1) > threshold, 1)
Out[156]: array([5, 3, 5], dtype=int64)

argmaxは最大値を取るインデックスのうち,最も小さいインデックスを返すみたい.

だけど...

逆に閾値以下の要素のうち最大のインデックスをもとめたいとき,は無理だった.
[4, 3, 4]が欲しいのだが,

In [176]: score.cumsum(1) < threshold
Out[176]: 
array([[ True,  True,  True,  True,  True, False, False, False, False,
        False],
       [ True,  True,  True, False, False, False, False, False, False,
        False],
       [ True,  True,  True,  True,  True, False, False, False, False,
        False]], dtype=bool)

In [178]: np.argmin(score.cumsum(1) < threshold, 1)
Out[178]: array([5, 3, 5], dtype=int64)

In [179]: np.argmax(score.cumsum(1) < threshold, 1)
Out[179]: array([0, 0, 0], dtype=int64)

こうなってしまう.

np.whereとかで格好良く求めるのが正義?
誰かクールなtipsがあったら教えて下さい.

2
2
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?