3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

numpyの3次元配列のソート

Last updated at Posted at 2020-03-09

問題

2×4×3の3次元の配列xがあります。
3次元目のインデックス1の値をキーに、2次元目を安定で降順に並べ替えたい。

import numpy as np

np.random.seed(0)
x = np.random.randint(0, 10, (2, 4, 3))
print(x)
"""
以下の3次元配列を、3次元目のindex 1をキーに2次元目を(安定で)降順にソートしたい。
[[[5 0 3]
  [3 7 9]
  [3 5 2]
  [4 7 6]]

 [[8 8 1]
  [6 7 7]
  [8 1 5]
  [9 8 9]]]
     ↑この列の降順にソートしたい

欲しい結果
[[[[3 7 9]
   [4 7 6]
   [3 5 2]
   [5 0 3]]

  [[8 8 1]
   [9 8 9]
   [6 7 7]
   [8 1 5]]]]
"""

argsort

numpy には argsort と言うものがあり、ndarrayそのものをソートするのではなく、ソートしたときのインデックスを返してくれる。

rindex = (-x[:, :, 1]).argsort()
print(rindex)
"""
[[1 3 2 0]
 [0 3 1 2]]
ここまでは期待通り
"""

argsortで検索すると、2次元の配列に対する例が多く、3次元以上の例がなかなか見つけられない。
また、argsortのソートは昇順と決まっており、降順が欲しいときはインデックスに::-1を指定しろとか書いてある。
しかし、ソートキーに同じ値が存在する場合、インデックスに::-1を指定すると結果が安定ではなくなってしまうのだ。

print(x[:, ::-1, 1].argsort())
"""
[[3 1 0 2]
 [1 2 0 3]]
結果が正しくない
"""
print(x[:, :, 1].argsort()[:,::-1])
"""
[[3 1 2 0]
 [3 0 1 2]]
結果が安定でない
"""

さて、こうして得られたrindexを使ってxにインデクシングしてやれば欲しい結果が得られそうなのであるが、うまく取り出せない。

print(x[rindex])
"""
IndexError: index 3 is out of bounds for axis 0 with size 2
"""
print(x[:, rindex])
"""
[[[[3 7 9]
   [4 7 6]
   [3 5 2]
   [5 0 3]]

  [[5 0 3]
   [4 7 6]
   [3 7 9]
   [3 5 2]]]


 [[[6 7 7]
   [9 8 9]
   [8 1 5]
   [8 8 1]]

  [[8 8 1]
   [9 8 9]
   [6 7 7]
   [8 1 5]]]]
"""

苦肉の策で、以下の方法で取り出すことができるが、何か違う気がする。

sx = np.array([x[i, rindex[i]] for i in range(x.shape[0])])
print(sx)
"""
[[[3 7 9]
  [4 7 6]
  [3 5 2]
  [5 0 3]]

 [[8 8 1]
  [9 8 9]
  [6 7 7]
  [8 1 5]]]
"""

人に聞いてみた

社内のslackで質問してみたところ、いろいろなやり方が寄せられたので、一つ一つ内容を理解してみたい。
なお、案の名前は私が勝手につけたので、提案していただいた方にはネーミングの責任はない。

(0) 原案

sx = np.array([x[i, rindex[i]] for i in range(x.shape[0])])

(1) sortedを使う案

sx = np.array([ sorted(x[i], key=lambda z: z[1], reverse=True) for i in range(x.shape[0]) ])

python組み込みのsortedとlistを使う案。

(2) take_along_axisを使う案

sx = np.take_along_axis(x, np.argsort(-x[..., 1], axis=1)[..., None], axis=1)
# This is equivalent to (but faster than) the following use of ndindex and s_,
# which sets each of ii and kk to a tuple of indices:
Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:]
J = indices.shape[axis]  # Need not equal M
out = np.empty(Ni + (J,) + Nk)

for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        a_1d       = a      [ii + s_[:,] + kk]
        indices_1d = indices[ii + s_[:,] + kk]
        out_1d     = out    [ii + s_[:,] + kk]
        for j in range(J):
            out_1d[j] = a_1d[indices_1d[j]]

(3) diag_indicesを使う案

sx = x[:, (-x[:, :, 1]).argsort()][np.diag_indices(x.shape[0], ndim=2)]
  • np.diag_indices って何?→対角線(diagonal)上の要素にアクセスするためのインデックスを生成する。ndimはデフォルト2なので、, ndim=2は省略可能。
  • つまり、2次元のndarray aがあったとして、a[(0, 1), (2, 3)]と言うインデックス指定をされると、aの0行2列目と、1行3列目の2要素の配列を返す。
a = np.arange(16).reshape(4, 4)
"""
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
"""
print(a[(0, 1), (2, 3)])
"""
[2 7]
"""

(4) identityを使う案

sx = x[:, (-x[:, :, 1]).argsort()][np.identity(x.shape[0]) > 0]
  • identity って何?→対角線に1が並んだ行列を生成する。
  • アイディアとしては、diag_indicesと同じか。

(5) 正統派インデックス指定案

sx = x[np.arange(x.shape[0])[:, None], (-x).argsort(axis=1)[:, :, 1]]
  • (3), (4)の案が一度大きな配列を作ってから対角要素を取り出すのに対して、この案はxの1次元目の要素を取り出すindexと、argsortの返り値の2次元目のindexを使ってxの要素を取り出している。
  • たぶん私がもともとやりたかったのはこの案。
  • 以下のように書き換えても良いか(これを案(5)'とします。)
sx = x[np.arange(x.shape[0])[:, None], (-x[..., 1]).argsort()]

速度比較

上の例はxが2×4×3の例だが、実際に使いたかったのは10×65536×3のデータだったので両方のケースでipythonの%timeitで時間を計ってみた。
環境は python 3.6.8 + numpy 1.17.2 であるが、実際の実行時間は言語やライブラリ、ランタイムのバージョンや実行環境、データの量や質によるので、あくまで参考として見ていただきたい。

2×4×3 10×65536×3
(0) 8.69 µs ± 443 ns 103 ms ± 4.33 ms
(1) 9.4 µs ± 377 ns 419 ms ± 2.15 ms
(2) 14.4 µs ± 525 ns 28.5 ms ± 133 µs
(3) 6.6 µs ± 75.6 ns 95.2 ms ± 1.3 ms
(4) 10.4 µs ± 990 ns 96 ms ± 896 µs
(5) 5.06 µs ± 727 ns 69 ms ± 424 µs
(5)' 4.34 µs ± 38.2 ns 20.2 ms ± 688 µs

今回のデータだと、案(5)'が一番速いと言う結果になりました。

※ 案(1)だけ、10×65536×3のときのソート結果が合わない形になりました。2×4×3だと安定のソートができていたので、ちょっと理由がわかりませんでした。

3
0
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?