0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

[Python] NumpyにおけるAxis(軸)の考え方

Posted at

はじめに

numpyには、axisを指定する関数が多数あります。
2次元以下ならば、縦軸、横軸と覚えれば問題ありませんが、それ以上の次元になると分かりづらくなります。
多次元の場合のaxisの考え方が自分なりに分かったので記事にしてみます。
この記事ではsort()の挙動を例にとって考えます。

環境

py -VV
Python 3.13.9 (tags/v3.13.9:8183fa5, Oct 14 2025, 14:09:13) [MSC v.1944 64 bit (AMD64)]
py -m pip list
numpy                  2.2.4

numpy.sort

numpy.sort(a, axis=-1, kind=None, order=None, *, stable=None)
  • a: array_like
    ソートしたい配列

  • axis: int or None, optional
    どの軸(axis)でソートしたいか指定する。
    デフォルト値は、-1。
    Noneの場合は、配列を平坦化してからソートする。
    -1の場合は、最後の軸でソートする。

  • kind: {‘quicksort’, ‘mergesort’, ‘heapsort’, ‘stable’}, optional
    ソートアルゴリズム。
    デフォルト値は、"quicksort"

  • order: str or list of str, optional
    配列のdtypestructured data typeの場合、どの項目を優先してソートするか指定する。
    昇順・降順の指定ではない。

  • stable: bool, optional
    安定ソートにするかどうか。
    デフォルト値は、None
    None, Falseの場合は、安定かどうかは保証されない。
    Trueにすると、kind="stable"に設定される。

ndarray.sort(axis=-1, kind=None, order=None)
  • axis: int, optional
    どの軸(axis)でソートしたいか指定する。
    デフォルト値は、-1。
    -1の場合は、最後の軸でソートする。

  • kind: {‘quicksort’, ‘mergesort’, ‘heapsort’, ‘stable’}, optional
    ソートアルゴリズム。
    デフォルト値は、"quicksort"

  • order: str or list of str, optional
    配列のdtypestructured data typeの場合、どの項目を優先してソートするか指定する。
    昇順・降順の指定ではない。

解説

コード

自分なりに分かりやすく表示するために補助関数を定義しておきます。

python
import math
import re
import numpy as np

def string1dArray(a, maxL, c=" "):
  result = "["
  for x in a:
    result += f"{x:{maxL}}{c}"
  return result[:-1] + "]"

def stringNdArray(a, maxL, c=" "):
  if a.ndim == 1:
    return string1dArray(a, maxL, c)
  result = "["
  for x in a:
    result += f"{stringNdArray(x, maxL, c)}{c}"
  return result[:-1] + "]"

def insertNewline(string, ndim, depth, c):
  pat = rf"([^\[\]\s]{']' * ndim}{c})"
  repl = rf"\1\n{' ' * (depth)}"
  return re.sub(pat, repl, string)

def stringArray(a, axis=0, c=" "):
  maxL = len(str(max(a.flatten())))
  string = stringNdArray(a, maxL, c)
  s = a.ndim - 1
  e = max(s - axis, -1)
  for i in range(s, e, -1):
    string = insertNewline(string, i, a.ndim - i, c)
  return string

shape = [5, 4, 3, 2]
a = np.arange(math.prod(shape))
rng = np.random.default_rng(0)
rng.shuffle(a)
a = a.reshape(shape)
print(f"ndim = {a.ndim}, shape = {a.shape}")
print(stringArray(a, 1))
print("***")

それを使って配列を表示すると以下のようになります。
この配列をソートしつつ軸について解説していきたいと思います。

配列
ndim = 4, shape = (5, 4, 3, 2)
[[[[101  67] [ 36 114] [  1  23]] [[ 13  16] [107  27] [  5  50]] [[ 84  44] [ 98 100] [111 117]] [[ 99  87] [110  97] [ 20  34]]]
 [[[ 88  10] [ 39 105] [113  72]] [[ 52  11] [ 83   8] [  9  37]] [[ 22  19] [ 66  81] [ 55   4]] [[ 25  92] [109  53] [ 74  75]]]
 [[[ 71   3] [ 93  82] [ 42  15]] [[ 30   2] [ 35  94] [ 80  43]] [[ 17  90] [ 28  64] [ 18   6]] [[ 65  68] [ 24  86] [ 57  21]]]
 [[[ 91  85] [104  47] [  0  60]] [[ 70  62] [ 45 102] [ 26  51]] [[ 49  96] [112 115] [ 40  61]] [[ 12 118] [ 32 116] [ 89  46]]]
 [[[ 58  14] [ 73  38] [106  31]] [[108  48] [ 77  76] [  7 103]] [[ 63 119] [ 69  78] [ 59  54]] [[ 29  41] [ 56  33] [ 79  95]]]]

axis = None

axis = Noneの場合は、配列を平坦化してからソートします。

testNone
def testNone(a):
  r = np.sort(a, None)
  print(stringArray(r))

testNone()
出力
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119]

axis = 0

test0
def test0(a):
  print(stringArray(a, 1))
  print("---")
  r = np.sort(a, 0)
  print(stringArray(r, 1))
  print("---")

test0(a)
出力
[[[[101  67] [ 36 114] [  1  23]] [[ 13  16] [107  27] [  5  50]] [[ 84  44] [ 98 100] [111 117]] [[ 99  87] [110  97] [ 20  34]]]
 [[[ 88  10] [ 39 105] [113  72]] [[ 52  11] [ 83   8] [  9  37]] [[ 22  19] [ 66  81] [ 55   4]] [[ 25  92] [109  53] [ 74  75]]]
 [[[ 71   3] [ 93  82] [ 42  15]] [[ 30   2] [ 35  94] [ 80  43]] [[ 17  90] [ 28  64] [ 18   6]] [[ 65  68] [ 24  86] [ 57  21]]]
 [[[ 91  85] [104  47] [  0  60]] [[ 70  62] [ 45 102] [ 26  51]] [[ 49  96] [112 115] [ 40  61]] [[ 12 118] [ 32 116] [ 89  46]]]
 [[[ 58  14] [ 73  38] [106  31]] [[108  48] [ 77  76] [  7 103]] [[ 63 119] [ 69  78] [ 59  54]] [[ 29  41] [ 56  33] [ 79  95]]]]
---
[[[[ 58   3] [ 36  38] [  0  15]] [[ 13   2] [ 35   8] [  5  37]] [[ 17  19] [ 28  64] [ 18   4]] [[ 12  41] [ 24  33] [ 20  21]]]
 [[[ 71  10] [ 39  47] [  1  23]] [[ 30  11] [ 45  27] [  7  43]] [[ 22  44] [ 66  78] [ 40   6]] [[ 25  68] [ 32  53] [ 57  34]]]
 [[[ 88  14] [ 73  82] [ 42  31]] [[ 52  16] [ 77  76] [  9  50]] [[ 49  90] [ 69  81] [ 55  54]] [[ 29  87] [ 56  86] [ 74  46]]]
 [[[ 91  67] [ 93 105] [106  60]] [[ 70  48] [ 83  94] [ 26  51]] [[ 63  96] [ 98 100] [ 59  61]] [[ 65  92] [109  97] [ 79  75]]]
 [[[101  85] [104 114] [113  72]] [[108  62] [107 102] [ 80 103]] [[ 84 119] [112 115] [111 117]] [[ 99 118] [110 116] [ 89  95]]]]
---

軸0の要素は、以下の5つになります。

  1. [[[101 67] [ 36 114] [ 1 23]] [[ 13 16] [107 27] [ 5 50]] [[ 84 44] [ 98 100] [111 117]] [[ 99 87] [110 97] [ 20 34]]]
  2. [[[ 88 10] [ 39 105] [113 72]] [[ 52 11] [ 83 8] [ 9 37]] [[ 22 19] [ 66 81] [ 55 4]] [[ 25 92] [109 53] [ 74 75]]]
  3. [[[ 71 3] [ 93 82] [ 42 15]] [[ 30 2] [ 35 94] [ 80 43]] [[ 17 90] [ 28 64] [ 18 6]] [[ 65 68] [ 24 86] [ 57 21]]]
  4. [[[ 91 85] [104 47] [ 0 60]] [[ 70 62] [ 45 102] [ 26 51]] [[ 49 96] [112 115] [ 40 61]] [[ 12 118] [ 32 116] [ 89 46]]]
  5. [[[ 58 14] [ 73 38] [106 31]] [[108 48] [ 77 76] [ 7 103]] [[ 63 119] [ 69 78] [ 59 54]] [[ 29 41] [ 56 33] [ 79 95]]]

この要素同士の同じ位置の値を比較してソートしていきます。
軸0に対応するshapeが5なので、比較する組み合わせは、5個で1組になります。
具体的には、(101, 88, 71, 91, 58), (36, 39, 93, 104, 73), ...となります。
これらをソートして結果が、(58, 71, 88, 91, 101), (36, 39, 73, 93, 104), ...となります。

axis = 1

test1
def test1(a):
  print(stringArray(a, 2))
  print("---")
  r = np.sort(a, 1)
  print(stringArray(r, 2))
  print("---")

test1(a)
出力
[[[[101  67] [ 36 114] [  1  23]]
  [[ 13  16] [107  27] [  5  50]]
  [[ 84  44] [ 98 100] [111 117]]
  [[ 99  87] [110  97] [ 20  34]]]
 [[[ 88  10] [ 39 105] [113  72]]
  [[ 52  11] [ 83   8] [  9  37]]
  [[ 22  19] [ 66  81] [ 55   4]]
  [[ 25  92] [109  53] [ 74  75]]]
 [[[ 71   3] [ 93  82] [ 42  15]]
  [[ 30   2] [ 35  94] [ 80  43]]
  [[ 17  90] [ 28  64] [ 18   6]]
  [[ 65  68] [ 24  86] [ 57  21]]]
 [[[ 91  85] [104  47] [  0  60]]
  [[ 70  62] [ 45 102] [ 26  51]]
  [[ 49  96] [112 115] [ 40  61]]
  [[ 12 118] [ 32 116] [ 89  46]]]
 [[[ 58  14] [ 73  38] [106  31]]
  [[108  48] [ 77  76] [  7 103]]
  [[ 63 119] [ 69  78] [ 59  54]]
  [[ 29  41] [ 56  33] [ 79  95]]]]
---
[[[[ 13  16] [ 36  27] [  1  23]]
  [[ 84  44] [ 98  97] [  5  34]]
  [[ 99  67] [107 100] [ 20  50]]
  [[101  87] [110 114] [111 117]]]
 [[[ 22  10] [ 39   8] [  9   4]]
  [[ 25  11] [ 66  53] [ 55  37]]
  [[ 52  19] [ 83  81] [ 74  72]]
  [[ 88  92] [109 105] [113  75]]]
 [[[ 17   2] [ 24  64] [ 18   6]]
  [[ 30   3] [ 28  82] [ 42  15]]
  [[ 65  68] [ 35  86] [ 57  21]]
  [[ 71  90] [ 93  94] [ 80  43]]]
 [[[ 12  62] [ 32  47] [  0  46]]
  [[ 49  85] [ 45 102] [ 26  51]]
  [[ 70  96] [104 115] [ 40  60]]
  [[ 91 118] [112 116] [ 89  61]]]
 [[[ 29  14] [ 56  33] [  7  31]]
  [[ 58  41] [ 69  38] [ 59  54]]
  [[ 63  48] [ 73  76] [ 79  95]]
  [[108 119] [ 77  78] [106 103]]]]
---

軸1の要素は、20個(軸0のshape * 軸1のshape = 5 * 4 = 20)になります。
具体的には、以下のようになります。

  1. [[101 67] [ 36 114] [ 1 23]]
  2. [[ 13 16] [107 27] [ 5 50]]
  3. [[ 84 44] [ 98 100] [111 117]]
  4. [[ 99 87] [110 97] [ 20 34]]
    --
  5. [[ 88 10] [ 39 105] [113 72]]
  6. [[ 52 11] [ 83 8] [ 9 37]]
  7. [[ 22 19] [ 66 81] [ 55 4]]
  8. [[ 25 92] [109 53] [ 74 75]]
    ...

軸1に対応するshapeが4なので、比較する組み合わせは、4個で1組になります。
具体的には、(101, 13, 84, 99), (67, 16, 44, 87), ...となります。
これらをソートして結果が、(13, 84, 99, 101), (16, 44, 67, 87), ...となります。

axis = 2

test2
def test2(a):
  print(stringArray(a, 3))
  print("---")
  r = np.sort(a, 2)
  print(stringArray(r, 3))
  print("---")

test2(a)
出力
[[[[101  67]
   [ 36 114]
   [  1  23]]
  [[ 13  16]
   [107  27]
   [  5  50]]
  [[ 84  44]
   [ 98 100]
   [111 117]]
  [[ 99  87]
   [110  97]
   [ 20  34]]]
 [[[ 88  10]
   [ 39 105]
   [113  72]]
  [[ 52  11]
   [ 83   8]
   [  9  37]]
  [[ 22  19]
   [ 66  81]
   [ 55   4]]
  [[ 25  92]
   [109  53]
   [ 74  75]]]
 [[[ 71   3]
   [ 93  82]
   [ 42  15]]
  [[ 30   2]
   [ 35  94]
   [ 80  43]]
  [[ 17  90]
   [ 28  64]
   [ 18   6]]
  [[ 65  68]
   [ 24  86]
   [ 57  21]]]
 [[[ 91  85]
   [104  47]
   [  0  60]]
  [[ 70  62]
   [ 45 102]
   [ 26  51]]
  [[ 49  96]
   [112 115]
   [ 40  61]]
  [[ 12 118]
   [ 32 116]
   [ 89  46]]]
 [[[ 58  14]
   [ 73  38]
   [106  31]]
  [[108  48]
   [ 77  76]
   [  7 103]]
  [[ 63 119]
   [ 69  78]
   [ 59  54]]
  [[ 29  41]
   [ 56  33]
   [ 79  95]]]]
---
[[[[  1  23]
   [ 36  67]
   [101 114]]
  [[  5  16]
   [ 13  27]
   [107  50]]
  [[ 84  44]
   [ 98 100]
   [111 117]]
  [[ 20  34]
   [ 99  87]
   [110  97]]]
 [[[ 39  10]
   [ 88  72]
   [113 105]]
  [[  9   8]
   [ 52  11]
   [ 83  37]]
  [[ 22   4]
   [ 55  19]
   [ 66  81]]
  [[ 25  53]
   [ 74  75]
   [109  92]]]
 [[[ 42   3]
   [ 71  15]
   [ 93  82]]
  [[ 30   2]
   [ 35  43]
   [ 80  94]]
  [[ 17   6]
   [ 18  64]
   [ 28  90]]
  [[ 24  21]
   [ 57  68]
   [ 65  86]]]
 [[[  0  47]
   [ 91  60]
   [104  85]]
  [[ 26  51]
   [ 45  62]
   [ 70 102]]
  [[ 40  61]
   [ 49  96]
   [112 115]]
  [[ 12  46]
   [ 32 116]
   [ 89 118]]]
 [[[ 58  14]
   [ 73  31]
   [106  38]]
  [[  7  48]
   [ 77  76]
   [108 103]]
  [[ 59  54]
   [ 63  78]
   [ 69 119]]
  [[ 29  33]
   [ 56  41]
   [ 79  95]]]]
---

軸2の要素は、60個(軸0のshape * 軸1のshape * 軸2のshape = 5 * 4 * 3 = 60)になります。
具体的には、以下のようになります。

  1. [101 67]
  2. [ 36 114]
  3. [ 1 23]
    --
  4. [ 13 16]
  5. [107 27]
  6. [ 5 50]
    ...

軸2に対応するshapeが3なので、比較する組み合わせは、3個で1組になります。
具体的には、(101, 36, 1), (67, 114, 23), ...となります。
これらをソートして結果が、(1, 36, 101), (23, 67, 114), ...となります。

axis = 3 および axis = -1

デフォルト値でソートするとaxis = -1となります。
この場合は、axis = 3と同じ結果になります。

test3
def test3(a):
  print(stringArray(a, 4))
  print("---")
  r = np.sort(a, 3)
  print(stringArray(r, 4))
  print("---")
  r = np.sort(a)
  print(stringArray(r))
  print("---")
  r = np.sort(a, -1)
  print(stringArray(r))
  print("---")

test3(a)
出力
[[[[101
     67]
   [ 36
    114]
   [  1
     23]]
  [[ 13
     16]
   [107
     27]
   [  5
     50]]
  [[ 84
     44]
   [ 98
    100]
   [111
    117]]
  [[ 99
     87]
   [110
     97]
   [ 20
     34]]]
 [[[ 88
     10]
   [ 39
    105]
   [113
     72]]
  [[ 52
     11]
   [ 83
      8]
   [  9
     37]]
  [[ 22
     19]
   [ 66
     81]
   [ 55
      4]]
  [[ 25
     92]
   [109
     53]
   [ 74
     75]]]
 [[[ 71
      3]
   [ 93
     82]
   [ 42
     15]]
  [[ 30
      2]
   [ 35
     94]
   [ 80
     43]]
  [[ 17
     90]
   [ 28
     64]
   [ 18
      6]]
  [[ 65
     68]
   [ 24
     86]
   [ 57
     21]]]
 [[[ 91
     85]
   [104
     47]
   [  0
     60]]
  [[ 70
     62]
   [ 45
    102]
   [ 26
     51]]
  [[ 49
     96]
   [112
    115]
   [ 40
     61]]
  [[ 12
    118]
   [ 32
    116]
   [ 89
     46]]]
 [[[ 58
     14]
   [ 73
     38]
   [106
     31]]
  [[108
     48]
   [ 77
     76]
   [  7
    103]]
  [[ 63
    119]
   [ 69
     78]
   [ 59
     54]]
  [[ 29
     41]
   [ 56
     33]
   [ 79
     95]]]]
---
[[[[ 67
    101]
   [ 36
    114]
   [  1
     23]]
  [[ 13
     16]
   [ 27
    107]
   [  5
     50]]
  [[ 44
     84]
   [ 98
    100]
   [111
    117]]
  [[ 87
     99]
   [ 97
    110]
   [ 20
     34]]]
 [[[ 10
     88]
   [ 39
    105]
   [ 72
    113]]
  [[ 11
     52]
   [  8
     83]
   [  9
     37]]
  [[ 19
     22]
   [ 66
     81]
   [  4
     55]]
  [[ 25
     92]
   [ 53
    109]
   [ 74
     75]]]
 [[[  3
     71]
   [ 82
     93]
   [ 15
     42]]
  [[  2
     30]
   [ 35
     94]
   [ 43
     80]]
  [[ 17
     90]
   [ 28
     64]
   [  6
     18]]
  [[ 65
     68]
   [ 24
     86]
   [ 21
     57]]]
 [[[ 85
     91]
   [ 47
    104]
   [  0
     60]]
  [[ 62
     70]
   [ 45
    102]
   [ 26
     51]]
  [[ 49
     96]
   [112
    115]
   [ 40
     61]]
  [[ 12
    118]
   [ 32
    116]
   [ 46
     89]]]
 [[[ 14
     58]
   [ 38
     73]
   [ 31
    106]]
  [[ 48
    108]
   [ 76
     77]
   [  7
    103]]
  [[ 63
    119]
   [ 69
     78]
   [ 54
     59]]
  [[ 29
     41]
   [ 33
     56]
   [ 79
     95]]]]
---
[[[[ 67 101] [ 36 114] [  1  23]] [[ 13  16] [ 27 107] [  5  50]] [[ 44  84] [ 98 100] [111 117]] [[ 87  99] [ 97 110] [ 20  34]]] [[[ 10  88] [ 39 105] [ 72 113]] [[ 11  52] [  8  83] [  9  37]] [[ 19  22] [ 66  81] [  4  55]] [[ 25  92] [ 53 109] [ 74  75]]] [[[  3  71] [ 82  93] [ 15  42]] [[  2  30] [ 35  94] [ 43  80]] [[ 17  90] [ 28  64] [  6  18]] [[ 65  68] [ 24  86] [ 21  57]]] [[[ 85  91] [ 47 104] [  0  60]] [[ 62  70] [ 45 102] [ 26  51]] [[ 49  96] [112 115] [ 40  61]] [[ 12 118] [ 32 116] [ 46  89]]] [[[ 14  58] [ 38  73] [ 31 106]] [[ 48 108] [ 76  77] [  7 103]] [[ 63 119] [ 69  78] [ 54  59]] [[ 29  41] [ 33  56] [ 79  95]]]]
---
[[[[ 67 101] [ 36 114] [  1  23]] [[ 13  16] [ 27 107] [  5  50]] [[ 44  84] [ 98 100] [111 117]] [[ 87  99] [ 97 110] [ 20  34]]] [[[ 10  88] [ 39 105] [ 72 113]] [[ 11  52] [  8  83] [  9  37]] [[ 19  22] [ 66  81] [  4  55]] [[ 25  92] [ 53 109] [ 74  75]]] [[[  3  71] [ 82  93] [ 15  42]] [[  2  30] [ 35  94] [ 43  80]] [[ 17  90] [ 28  64] [  6  18]] [[ 65  68] [ 24  86] [ 21  57]]] [[[ 85  91] [ 47 104] [  0  60]] [[ 62  70] [ 45 102] [ 26  51]] [[ 49  96] [112 115] [ 40  61]] [[ 12 118] [ 32 116] [ 46  89]]] [[[ 14  58] [ 38  73] [ 31 106]] [[ 48 108] [ 76  77] [  7 103]] [[ 63 119] [ 69  78] [ 54  59]] [[ 29  41] [ 33  56] [ 79  95]]]]
---

軸3の要素は、120個(軸0のshape * 軸1のshape * 軸2のshape * 軸3のshape = 5 * 4 * 3 * 2 = 120)になります。
具体的には、以下のようになります。

  1. 101
  2. 67
    --
  3. 36
  4. 114
    ...

軸3に対応するshapeが2なので、比較する組み合わせは、2個で1組になります。
これは、一番最深の配列の要素となります。
つまり、最深の配列の要素をそれぞれソートする事になります。
具体的には、(101, 67), (36, 114), ...となります。
これらをソートして結果が、(67, 101), (36, 114), ...となります。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?