はじめに
numpyには、axisを指定する関数が多数あります。
2次元以下ならば、縦軸、横軸と覚えれば問題ありませんが、それ以上の次元になると分かりづらくなります。
多次元の場合のaxisの考え方が自分なりに分かったので記事にしてみます。
この記事ではsort()の挙動を例にとって考えます。
環境
Python 3.13.9 (tags/v3.13.9:8183fa5, Oct 14 2025, 14:09:13) [MSC v.1944 64 bit (AMD64)]
numpy 2.2.4
numpy.sort
numpy.sort(a, axis=-1, kind=None, order=None, *, stable=None)
-
a: array_like
ソートしたい配列 -
axis: int or None, optional
どの軸(axis)でソートしたいか指定する。
デフォルト値は、-1。
Noneの場合は、配列を平坦化してからソートする。
-1の場合は、最後の軸でソートする。 -
kind: {‘quicksort’, ‘mergesort’, ‘heapsort’, ‘stable’}, optional
ソートアルゴリズム。
デフォルト値は、"quicksort" -
order: str or list of str, optional
配列のdtypeがstructured data typeの場合、どの項目を優先してソートするか指定する。
昇順・降順の指定ではない。 -
stable: bool, optional
安定ソートにするかどうか。
デフォルト値は、None
None,Falseの場合は、安定かどうかは保証されない。
Trueにすると、kind="stable"に設定される。
ndarray.sort(axis=-1, kind=None, order=None)
-
axis: int, optional
どの軸(axis)でソートしたいか指定する。
デフォルト値は、-1。
-1の場合は、最後の軸でソートする。 -
kind: {‘quicksort’, ‘mergesort’, ‘heapsort’, ‘stable’}, optional
ソートアルゴリズム。
デフォルト値は、"quicksort" -
order: str or list of str, optional
配列のdtypeがstructured data typeの場合、どの項目を優先してソートするか指定する。
昇順・降順の指定ではない。
解説
コード
自分なりに分かりやすく表示するために補助関数を定義しておきます。
import math
import re
import numpy as np
def string1dArray(a, maxL, c=" "):
result = "["
for x in a:
result += f"{x:{maxL}}{c}"
return result[:-1] + "]"
def stringNdArray(a, maxL, c=" "):
if a.ndim == 1:
return string1dArray(a, maxL, c)
result = "["
for x in a:
result += f"{stringNdArray(x, maxL, c)}{c}"
return result[:-1] + "]"
def insertNewline(string, ndim, depth, c):
pat = rf"([^\[\]\s]{']' * ndim}{c})"
repl = rf"\1\n{' ' * (depth)}"
return re.sub(pat, repl, string)
def stringArray(a, axis=0, c=" "):
maxL = len(str(max(a.flatten())))
string = stringNdArray(a, maxL, c)
s = a.ndim - 1
e = max(s - axis, -1)
for i in range(s, e, -1):
string = insertNewline(string, i, a.ndim - i, c)
return string
shape = [5, 4, 3, 2]
a = np.arange(math.prod(shape))
rng = np.random.default_rng(0)
rng.shuffle(a)
a = a.reshape(shape)
print(f"ndim = {a.ndim}, shape = {a.shape}")
print(stringArray(a, 1))
print("***")
それを使って配列を表示すると以下のようになります。
この配列をソートしつつ軸について解説していきたいと思います。
ndim = 4, shape = (5, 4, 3, 2)
[[[[101 67] [ 36 114] [ 1 23]] [[ 13 16] [107 27] [ 5 50]] [[ 84 44] [ 98 100] [111 117]] [[ 99 87] [110 97] [ 20 34]]]
[[[ 88 10] [ 39 105] [113 72]] [[ 52 11] [ 83 8] [ 9 37]] [[ 22 19] [ 66 81] [ 55 4]] [[ 25 92] [109 53] [ 74 75]]]
[[[ 71 3] [ 93 82] [ 42 15]] [[ 30 2] [ 35 94] [ 80 43]] [[ 17 90] [ 28 64] [ 18 6]] [[ 65 68] [ 24 86] [ 57 21]]]
[[[ 91 85] [104 47] [ 0 60]] [[ 70 62] [ 45 102] [ 26 51]] [[ 49 96] [112 115] [ 40 61]] [[ 12 118] [ 32 116] [ 89 46]]]
[[[ 58 14] [ 73 38] [106 31]] [[108 48] [ 77 76] [ 7 103]] [[ 63 119] [ 69 78] [ 59 54]] [[ 29 41] [ 56 33] [ 79 95]]]]
axis = None
axis = Noneの場合は、配列を平坦化してからソートします。
def testNone(a):
r = np.sort(a, None)
print(stringArray(r))
testNone()
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119]
axis = 0
def test0(a):
print(stringArray(a, 1))
print("---")
r = np.sort(a, 0)
print(stringArray(r, 1))
print("---")
test0(a)
[[[[101 67] [ 36 114] [ 1 23]] [[ 13 16] [107 27] [ 5 50]] [[ 84 44] [ 98 100] [111 117]] [[ 99 87] [110 97] [ 20 34]]]
[[[ 88 10] [ 39 105] [113 72]] [[ 52 11] [ 83 8] [ 9 37]] [[ 22 19] [ 66 81] [ 55 4]] [[ 25 92] [109 53] [ 74 75]]]
[[[ 71 3] [ 93 82] [ 42 15]] [[ 30 2] [ 35 94] [ 80 43]] [[ 17 90] [ 28 64] [ 18 6]] [[ 65 68] [ 24 86] [ 57 21]]]
[[[ 91 85] [104 47] [ 0 60]] [[ 70 62] [ 45 102] [ 26 51]] [[ 49 96] [112 115] [ 40 61]] [[ 12 118] [ 32 116] [ 89 46]]]
[[[ 58 14] [ 73 38] [106 31]] [[108 48] [ 77 76] [ 7 103]] [[ 63 119] [ 69 78] [ 59 54]] [[ 29 41] [ 56 33] [ 79 95]]]]
---
[[[[ 58 3] [ 36 38] [ 0 15]] [[ 13 2] [ 35 8] [ 5 37]] [[ 17 19] [ 28 64] [ 18 4]] [[ 12 41] [ 24 33] [ 20 21]]]
[[[ 71 10] [ 39 47] [ 1 23]] [[ 30 11] [ 45 27] [ 7 43]] [[ 22 44] [ 66 78] [ 40 6]] [[ 25 68] [ 32 53] [ 57 34]]]
[[[ 88 14] [ 73 82] [ 42 31]] [[ 52 16] [ 77 76] [ 9 50]] [[ 49 90] [ 69 81] [ 55 54]] [[ 29 87] [ 56 86] [ 74 46]]]
[[[ 91 67] [ 93 105] [106 60]] [[ 70 48] [ 83 94] [ 26 51]] [[ 63 96] [ 98 100] [ 59 61]] [[ 65 92] [109 97] [ 79 75]]]
[[[101 85] [104 114] [113 72]] [[108 62] [107 102] [ 80 103]] [[ 84 119] [112 115] [111 117]] [[ 99 118] [110 116] [ 89 95]]]]
---
軸0の要素は、以下の5つになります。
[[[101 67] [ 36 114] [ 1 23]] [[ 13 16] [107 27] [ 5 50]] [[ 84 44] [ 98 100] [111 117]] [[ 99 87] [110 97] [ 20 34]]][[[ 88 10] [ 39 105] [113 72]] [[ 52 11] [ 83 8] [ 9 37]] [[ 22 19] [ 66 81] [ 55 4]] [[ 25 92] [109 53] [ 74 75]]][[[ 71 3] [ 93 82] [ 42 15]] [[ 30 2] [ 35 94] [ 80 43]] [[ 17 90] [ 28 64] [ 18 6]] [[ 65 68] [ 24 86] [ 57 21]]][[[ 91 85] [104 47] [ 0 60]] [[ 70 62] [ 45 102] [ 26 51]] [[ 49 96] [112 115] [ 40 61]] [[ 12 118] [ 32 116] [ 89 46]]][[[ 58 14] [ 73 38] [106 31]] [[108 48] [ 77 76] [ 7 103]] [[ 63 119] [ 69 78] [ 59 54]] [[ 29 41] [ 56 33] [ 79 95]]]
この要素同士の同じ位置の値を比較してソートしていきます。
軸0に対応するshapeが5なので、比較する組み合わせは、5個で1組になります。
具体的には、(101, 88, 71, 91, 58), (36, 39, 93, 104, 73), ...となります。
これらをソートして結果が、(58, 71, 88, 91, 101), (36, 39, 73, 93, 104), ...となります。
axis = 1
def test1(a):
print(stringArray(a, 2))
print("---")
r = np.sort(a, 1)
print(stringArray(r, 2))
print("---")
test1(a)
[[[[101 67] [ 36 114] [ 1 23]]
[[ 13 16] [107 27] [ 5 50]]
[[ 84 44] [ 98 100] [111 117]]
[[ 99 87] [110 97] [ 20 34]]]
[[[ 88 10] [ 39 105] [113 72]]
[[ 52 11] [ 83 8] [ 9 37]]
[[ 22 19] [ 66 81] [ 55 4]]
[[ 25 92] [109 53] [ 74 75]]]
[[[ 71 3] [ 93 82] [ 42 15]]
[[ 30 2] [ 35 94] [ 80 43]]
[[ 17 90] [ 28 64] [ 18 6]]
[[ 65 68] [ 24 86] [ 57 21]]]
[[[ 91 85] [104 47] [ 0 60]]
[[ 70 62] [ 45 102] [ 26 51]]
[[ 49 96] [112 115] [ 40 61]]
[[ 12 118] [ 32 116] [ 89 46]]]
[[[ 58 14] [ 73 38] [106 31]]
[[108 48] [ 77 76] [ 7 103]]
[[ 63 119] [ 69 78] [ 59 54]]
[[ 29 41] [ 56 33] [ 79 95]]]]
---
[[[[ 13 16] [ 36 27] [ 1 23]]
[[ 84 44] [ 98 97] [ 5 34]]
[[ 99 67] [107 100] [ 20 50]]
[[101 87] [110 114] [111 117]]]
[[[ 22 10] [ 39 8] [ 9 4]]
[[ 25 11] [ 66 53] [ 55 37]]
[[ 52 19] [ 83 81] [ 74 72]]
[[ 88 92] [109 105] [113 75]]]
[[[ 17 2] [ 24 64] [ 18 6]]
[[ 30 3] [ 28 82] [ 42 15]]
[[ 65 68] [ 35 86] [ 57 21]]
[[ 71 90] [ 93 94] [ 80 43]]]
[[[ 12 62] [ 32 47] [ 0 46]]
[[ 49 85] [ 45 102] [ 26 51]]
[[ 70 96] [104 115] [ 40 60]]
[[ 91 118] [112 116] [ 89 61]]]
[[[ 29 14] [ 56 33] [ 7 31]]
[[ 58 41] [ 69 38] [ 59 54]]
[[ 63 48] [ 73 76] [ 79 95]]
[[108 119] [ 77 78] [106 103]]]]
---
軸1の要素は、20個(軸0のshape * 軸1のshape = 5 * 4 = 20)になります。
具体的には、以下のようになります。
[[101 67] [ 36 114] [ 1 23]][[ 13 16] [107 27] [ 5 50]][[ 84 44] [ 98 100] [111 117]]-
[[ 99 87] [110 97] [ 20 34]]
-- [[ 88 10] [ 39 105] [113 72]][[ 52 11] [ 83 8] [ 9 37]][[ 22 19] [ 66 81] [ 55 4]]-
[[ 25 92] [109 53] [ 74 75]]
...
軸1に対応するshapeが4なので、比較する組み合わせは、4個で1組になります。
具体的には、(101, 13, 84, 99), (67, 16, 44, 87), ...となります。
これらをソートして結果が、(13, 84, 99, 101), (16, 44, 67, 87), ...となります。
axis = 2
def test2(a):
print(stringArray(a, 3))
print("---")
r = np.sort(a, 2)
print(stringArray(r, 3))
print("---")
test2(a)
[[[[101 67]
[ 36 114]
[ 1 23]]
[[ 13 16]
[107 27]
[ 5 50]]
[[ 84 44]
[ 98 100]
[111 117]]
[[ 99 87]
[110 97]
[ 20 34]]]
[[[ 88 10]
[ 39 105]
[113 72]]
[[ 52 11]
[ 83 8]
[ 9 37]]
[[ 22 19]
[ 66 81]
[ 55 4]]
[[ 25 92]
[109 53]
[ 74 75]]]
[[[ 71 3]
[ 93 82]
[ 42 15]]
[[ 30 2]
[ 35 94]
[ 80 43]]
[[ 17 90]
[ 28 64]
[ 18 6]]
[[ 65 68]
[ 24 86]
[ 57 21]]]
[[[ 91 85]
[104 47]
[ 0 60]]
[[ 70 62]
[ 45 102]
[ 26 51]]
[[ 49 96]
[112 115]
[ 40 61]]
[[ 12 118]
[ 32 116]
[ 89 46]]]
[[[ 58 14]
[ 73 38]
[106 31]]
[[108 48]
[ 77 76]
[ 7 103]]
[[ 63 119]
[ 69 78]
[ 59 54]]
[[ 29 41]
[ 56 33]
[ 79 95]]]]
---
[[[[ 1 23]
[ 36 67]
[101 114]]
[[ 5 16]
[ 13 27]
[107 50]]
[[ 84 44]
[ 98 100]
[111 117]]
[[ 20 34]
[ 99 87]
[110 97]]]
[[[ 39 10]
[ 88 72]
[113 105]]
[[ 9 8]
[ 52 11]
[ 83 37]]
[[ 22 4]
[ 55 19]
[ 66 81]]
[[ 25 53]
[ 74 75]
[109 92]]]
[[[ 42 3]
[ 71 15]
[ 93 82]]
[[ 30 2]
[ 35 43]
[ 80 94]]
[[ 17 6]
[ 18 64]
[ 28 90]]
[[ 24 21]
[ 57 68]
[ 65 86]]]
[[[ 0 47]
[ 91 60]
[104 85]]
[[ 26 51]
[ 45 62]
[ 70 102]]
[[ 40 61]
[ 49 96]
[112 115]]
[[ 12 46]
[ 32 116]
[ 89 118]]]
[[[ 58 14]
[ 73 31]
[106 38]]
[[ 7 48]
[ 77 76]
[108 103]]
[[ 59 54]
[ 63 78]
[ 69 119]]
[[ 29 33]
[ 56 41]
[ 79 95]]]]
---
軸2の要素は、60個(軸0のshape * 軸1のshape * 軸2のshape = 5 * 4 * 3 = 60)になります。
具体的には、以下のようになります。
[101 67][ 36 114]-
[ 1 23]
-- [ 13 16][107 27]-
[ 5 50]
...
軸2に対応するshapeが3なので、比較する組み合わせは、3個で1組になります。
具体的には、(101, 36, 1), (67, 114, 23), ...となります。
これらをソートして結果が、(1, 36, 101), (23, 67, 114), ...となります。
axis = 3 および axis = -1
デフォルト値でソートするとaxis = -1となります。
この場合は、axis = 3と同じ結果になります。
def test3(a):
print(stringArray(a, 4))
print("---")
r = np.sort(a, 3)
print(stringArray(r, 4))
print("---")
r = np.sort(a)
print(stringArray(r))
print("---")
r = np.sort(a, -1)
print(stringArray(r))
print("---")
test3(a)
[[[[101
67]
[ 36
114]
[ 1
23]]
[[ 13
16]
[107
27]
[ 5
50]]
[[ 84
44]
[ 98
100]
[111
117]]
[[ 99
87]
[110
97]
[ 20
34]]]
[[[ 88
10]
[ 39
105]
[113
72]]
[[ 52
11]
[ 83
8]
[ 9
37]]
[[ 22
19]
[ 66
81]
[ 55
4]]
[[ 25
92]
[109
53]
[ 74
75]]]
[[[ 71
3]
[ 93
82]
[ 42
15]]
[[ 30
2]
[ 35
94]
[ 80
43]]
[[ 17
90]
[ 28
64]
[ 18
6]]
[[ 65
68]
[ 24
86]
[ 57
21]]]
[[[ 91
85]
[104
47]
[ 0
60]]
[[ 70
62]
[ 45
102]
[ 26
51]]
[[ 49
96]
[112
115]
[ 40
61]]
[[ 12
118]
[ 32
116]
[ 89
46]]]
[[[ 58
14]
[ 73
38]
[106
31]]
[[108
48]
[ 77
76]
[ 7
103]]
[[ 63
119]
[ 69
78]
[ 59
54]]
[[ 29
41]
[ 56
33]
[ 79
95]]]]
---
[[[[ 67
101]
[ 36
114]
[ 1
23]]
[[ 13
16]
[ 27
107]
[ 5
50]]
[[ 44
84]
[ 98
100]
[111
117]]
[[ 87
99]
[ 97
110]
[ 20
34]]]
[[[ 10
88]
[ 39
105]
[ 72
113]]
[[ 11
52]
[ 8
83]
[ 9
37]]
[[ 19
22]
[ 66
81]
[ 4
55]]
[[ 25
92]
[ 53
109]
[ 74
75]]]
[[[ 3
71]
[ 82
93]
[ 15
42]]
[[ 2
30]
[ 35
94]
[ 43
80]]
[[ 17
90]
[ 28
64]
[ 6
18]]
[[ 65
68]
[ 24
86]
[ 21
57]]]
[[[ 85
91]
[ 47
104]
[ 0
60]]
[[ 62
70]
[ 45
102]
[ 26
51]]
[[ 49
96]
[112
115]
[ 40
61]]
[[ 12
118]
[ 32
116]
[ 46
89]]]
[[[ 14
58]
[ 38
73]
[ 31
106]]
[[ 48
108]
[ 76
77]
[ 7
103]]
[[ 63
119]
[ 69
78]
[ 54
59]]
[[ 29
41]
[ 33
56]
[ 79
95]]]]
---
[[[[ 67 101] [ 36 114] [ 1 23]] [[ 13 16] [ 27 107] [ 5 50]] [[ 44 84] [ 98 100] [111 117]] [[ 87 99] [ 97 110] [ 20 34]]] [[[ 10 88] [ 39 105] [ 72 113]] [[ 11 52] [ 8 83] [ 9 37]] [[ 19 22] [ 66 81] [ 4 55]] [[ 25 92] [ 53 109] [ 74 75]]] [[[ 3 71] [ 82 93] [ 15 42]] [[ 2 30] [ 35 94] [ 43 80]] [[ 17 90] [ 28 64] [ 6 18]] [[ 65 68] [ 24 86] [ 21 57]]] [[[ 85 91] [ 47 104] [ 0 60]] [[ 62 70] [ 45 102] [ 26 51]] [[ 49 96] [112 115] [ 40 61]] [[ 12 118] [ 32 116] [ 46 89]]] [[[ 14 58] [ 38 73] [ 31 106]] [[ 48 108] [ 76 77] [ 7 103]] [[ 63 119] [ 69 78] [ 54 59]] [[ 29 41] [ 33 56] [ 79 95]]]]
---
[[[[ 67 101] [ 36 114] [ 1 23]] [[ 13 16] [ 27 107] [ 5 50]] [[ 44 84] [ 98 100] [111 117]] [[ 87 99] [ 97 110] [ 20 34]]] [[[ 10 88] [ 39 105] [ 72 113]] [[ 11 52] [ 8 83] [ 9 37]] [[ 19 22] [ 66 81] [ 4 55]] [[ 25 92] [ 53 109] [ 74 75]]] [[[ 3 71] [ 82 93] [ 15 42]] [[ 2 30] [ 35 94] [ 43 80]] [[ 17 90] [ 28 64] [ 6 18]] [[ 65 68] [ 24 86] [ 21 57]]] [[[ 85 91] [ 47 104] [ 0 60]] [[ 62 70] [ 45 102] [ 26 51]] [[ 49 96] [112 115] [ 40 61]] [[ 12 118] [ 32 116] [ 46 89]]] [[[ 14 58] [ 38 73] [ 31 106]] [[ 48 108] [ 76 77] [ 7 103]] [[ 63 119] [ 69 78] [ 54 59]] [[ 29 41] [ 33 56] [ 79 95]]]]
---
軸3の要素は、120個(軸0のshape * 軸1のshape * 軸2のshape * 軸3のshape = 5 * 4 * 3 * 2 = 120)になります。
具体的には、以下のようになります。
101-
67
-- 36-
114
...
軸3に対応するshapeが2なので、比較する組み合わせは、2個で1組になります。
これは、一番最深の配列の要素となります。
つまり、最深の配列の要素をそれぞれソートする事になります。
具体的には、(101, 67), (36, 114), ...となります。
これらをソートして結果が、(67, 101), (36, 114), ...となります。