問題
2×4×3の3次元の配列xがあります。
3次元目のインデックス1の値をキーに、2次元目を安定で降順に並べ替えたい。
import numpy as np
np.random.seed(0)
x = np.random.randint(0, 10, (2, 4, 3))
print(x)
"""
以下の3次元配列を、3次元目のindex 1をキーに2次元目を(安定で)降順にソートしたい。
[[[5 0 3]
[3 7 9]
[3 5 2]
[4 7 6]]
[[8 8 1]
[6 7 7]
[8 1 5]
[9 8 9]]]
↑この列の降順にソートしたい
欲しい結果
[[[[3 7 9]
[4 7 6]
[3 5 2]
[5 0 3]]
[[8 8 1]
[9 8 9]
[6 7 7]
[8 1 5]]]]
"""
argsort
numpy には argsort と言うものがあり、ndarrayそのものをソートするのではなく、ソートしたときのインデックスを返してくれる。
rindex = (-x[:, :, 1]).argsort()
print(rindex)
"""
[[1 3 2 0]
[0 3 1 2]]
ここまでは期待通り
"""
argsortで検索すると、2次元の配列に対する例が多く、3次元以上の例がなかなか見つけられない。
また、argsortのソートは昇順と決まっており、降順が欲しいときはインデックスに::-1
を指定しろとか書いてある。
しかし、ソートキーに同じ値が存在する場合、インデックスに::-1
を指定すると結果が安定ではなくなってしまうのだ。
print(x[:, ::-1, 1].argsort())
"""
[[3 1 0 2]
[1 2 0 3]]
結果が正しくない
"""
print(x[:, :, 1].argsort()[:,::-1])
"""
[[3 1 2 0]
[3 0 1 2]]
結果が安定でない
"""
さて、こうして得られたrindexを使ってxにインデクシングしてやれば欲しい結果が得られそうなのであるが、うまく取り出せない。
print(x[rindex])
"""
IndexError: index 3 is out of bounds for axis 0 with size 2
"""
print(x[:, rindex])
"""
[[[[3 7 9]
[4 7 6]
[3 5 2]
[5 0 3]]
[[5 0 3]
[4 7 6]
[3 7 9]
[3 5 2]]]
[[[6 7 7]
[9 8 9]
[8 1 5]
[8 8 1]]
[[8 8 1]
[9 8 9]
[6 7 7]
[8 1 5]]]]
"""
苦肉の策で、以下の方法で取り出すことができるが、何か違う気がする。
sx = np.array([x[i, rindex[i]] for i in range(x.shape[0])])
print(sx)
"""
[[[3 7 9]
[4 7 6]
[3 5 2]
[5 0 3]]
[[8 8 1]
[9 8 9]
[6 7 7]
[8 1 5]]]
"""
人に聞いてみた
社内のslackで質問してみたところ、いろいろなやり方が寄せられたので、一つ一つ内容を理解してみたい。
なお、案の名前は私が勝手につけたので、提案していただいた方にはネーミングの責任はない。
(0) 原案
sx = np.array([x[i, rindex[i]] for i in range(x.shape[0])])
(1) sortedを使う案
sx = np.array([ sorted(x[i], key=lambda z: z[1], reverse=True) for i in range(x.shape[0]) ])
python組み込みのsortedとlistを使う案。
(2) take_along_axisを使う案
sx = np.take_along_axis(x, np.argsort(-x[..., 1], axis=1)[..., None], axis=1)
- python - Sorting 4D numpy array but keeping one axis tied together - Stack Overflowから
-
...
って何?→x[:, :, 1]
とx[..., 1]
が同義らしい。 -
..., None
って何?→None
はnumpy.newaxis
と同義で、次元を増やすらしい。 - take_along_axisって何?→第1引数の配列から、第2引数で指定されたindexに従って取り出す?
- マニュアルには、以下の説明があります。(頭が理解を拒否したのでここで逃げます)
# This is equivalent to (but faster than) the following use of ndindex and s_,
# which sets each of ii and kk to a tuple of indices:
Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:]
J = indices.shape[axis] # Need not equal M
out = np.empty(Ni + (J,) + Nk)
for ii in ndindex(Ni):
for kk in ndindex(Nk):
a_1d = a [ii + s_[:,] + kk]
indices_1d = indices[ii + s_[:,] + kk]
out_1d = out [ii + s_[:,] + kk]
for j in range(J):
out_1d[j] = a_1d[indices_1d[j]]
(3) diag_indicesを使う案
sx = x[:, (-x[:, :, 1]).argsort()][np.diag_indices(x.shape[0], ndim=2)]
-
np.diag_indices って何?→対角線(diagonal)上の要素にアクセスするためのインデックスを生成する。ndimはデフォルト2なので、
, ndim=2
は省略可能。 - つまり、2次元のndarray aがあったとして、
a[(0, 1), (2, 3)]
と言うインデックス指定をされると、aの0行2列目と、1行3列目の2要素の配列を返す。
a = np.arange(16).reshape(4, 4)
"""
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]]
"""
print(a[(0, 1), (2, 3)])
"""
[2 7]
"""
(4) identityを使う案
sx = x[:, (-x[:, :, 1]).argsort()][np.identity(x.shape[0]) > 0]
- identity って何?→対角線に1が並んだ行列を生成する。
- アイディアとしては、diag_indicesと同じか。
(5) 正統派インデックス指定案
sx = x[np.arange(x.shape[0])[:, None], (-x).argsort(axis=1)[:, :, 1]]
- (3), (4)の案が一度大きな配列を作ってから対角要素を取り出すのに対して、この案はxの1次元目の要素を取り出すindexと、argsortの返り値の2次元目のindexを使ってxの要素を取り出している。
- たぶん私がもともとやりたかったのはこの案。
- 以下のように書き換えても良いか(これを案(5)'とします。)
sx = x[np.arange(x.shape[0])[:, None], (-x[..., 1]).argsort()]
速度比較
上の例はxが2×4×3の例だが、実際に使いたかったのは10×65536×3のデータだったので両方のケースでipythonの%timeitで時間を計ってみた。
環境は python 3.6.8 + numpy 1.17.2 であるが、実際の実行時間は言語やライブラリ、ランタイムのバージョンや実行環境、データの量や質によるので、あくまで参考として見ていただきたい。
案 | 2×4×3 | 10×65536×3 |
---|---|---|
(0) | 8.69 µs ± 443 ns | 103 ms ± 4.33 ms |
(1) | 9.4 µs ± 377 ns | 419 ms ± 2.15 ms |
(2) | 14.4 µs ± 525 ns | 28.5 ms ± 133 µs |
(3) | 6.6 µs ± 75.6 ns | 95.2 ms ± 1.3 ms |
(4) | 10.4 µs ± 990 ns | 96 ms ± 896 µs |
(5) | 5.06 µs ± 727 ns | 69 ms ± 424 µs |
(5)' | 4.34 µs ± 38.2 ns | 20.2 ms ± 688 µs |
今回のデータだと、案(5)'が一番速いと言う結果になりました。
※ 案(1)だけ、10×65536×3のときのソート結果が合わない形になりました。2×4×3だと安定のソートができていたので、ちょっと理由がわかりませんでした。