25
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Python「numpy.sum(...)」のaxisオプション指定まとめ

Last updated at Posted at 2020-02-24

はじめに

numpy.ndarray.sum(...)numpy.ndarray.mean(...) などで指定可能な axis オプション についてのまとめです。2次元および3次元のnumpy配列を対象に、図を使って解説します。

2次元のnumpy配列の場合

まずは、2次元(ndim=2)のnumpy配列を対象にします。例として、次のような3行4列2次元のnumpy配列 x を考えます。

m1.png

この配列 x は次のコードで生成できます。

shape=(3,4)のnumpy配列を生成
import numpy as np
x = np.arange(1,12+1).reshape(3,4)
print(x)
#print(x.ndim)  # -> 2
#print(x.shape) # -> (3, 4)
実行結果
[ [ 1  2  3  4]
  [ 5  6  7  8]
  [ 9 10 11 12] ]

axis=None を指定

axis=None がデフォルト引数になっているので、x.sum() でも x.sum(axis=None) でも同じ動作をします。numpy.sum(...) / numpy.ndarray.sum(...) のリファレンス

sum(axis=None)
s = x.sum(axis=None)
# print(type(s)) # -> <class 'numpy.int64'>
# print(s.ndim)  # -> 0
# print(s.shape) # -> ()
print(s)

axis=None を指定すると配列を構成する全ての要素についての合計が求められます。具体的には $1+2+3+\cdots+11+12=$$\bf{78}$ が計算されます。

実行結果
78

axis=0 を指定

sum(axis=0)
s = x.sum(axis=0)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim)  # -> 1
# print(s.shape) # -> (4,)
print(s)
実行結果
[15 18 21 24]

axis=0 を指定すると、行方向に要素が合計されます。具体的には、$1+5+9=$$\bf{15}$、$2+6+10=$$\bf{18}$、$3+7+11=$$\bf{21}$、$4+8+12=$$\bf{24}$ が計算され [15 18 21 24] となっています。

ma0.png
ma0s.png

【注意】:各行の要素合計ではなくて、行方向(行が 0, 1, 2, 3, ... と大きくなる方向)に要素を合計していることに注意してください(ここを勘違いすると一気に???に陥ります(経験談))。

axis=1 を指定

sum(axis=1)
s = x.sum(axis=1)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim)  # -> 1
# print(s.shape) # -> (3,)
print(s)
実行結果
[10 26 42]

axis=1 を指定すると、列方向(列が 0, 1, 2, 3, ... と大きくなる方向)に要素が合計されています。具体的には、$1+2+3+4=$$\bf{10}$、$5+6+7+8=$$\bf{26}$、$9+10+11+12=$$\bf{42}$ が計算されて [10 26 42] となっています。

ma1.png
ma1s.png

3次元のnumpy配列の場合

つづいて、3次元のnumpy配列を対象にします。例として、shape が (3,4,2) のnumpy配列 x を扱います。

shape=(3,4,2)のnumpy配列を生成
import numpy as np
x = np.arange(1,24+1).reshape(3,4,2)
print(x)
# print(x.ndim)  # -> 3
# print(x.shape) # -> (3, 4, 2)
実行結果(読みやすくするために改行位置などの整形をしています)
[  [  [ 1  2]  [ 3  4] [ 5  6]  [ 7  8]  ]
   [  [ 9 10]  [11 12]  [13 14]  [15 16]  ] 
   [  [17 18]  [19 20]  [21 22]  [23 24]  ]  ]

axis=2 で分離して描くと次のようになります。

m2.png

axis=None を指定

sum(axis=None)
s = x.sum(axis=None)
# print(type(s)) # -> <class 'numpy.int64'>
# print(s.ndim)  # -> 0
# print(s.shape) # -> ()
print(s)

すべての要素についての合計です。具体的には $1+2+3+4+\cdots+22+23+24=300$ が計算されます。

実行結果
300

axis=0 を指定

sum(axis=0)
s = x.sum(axis=0)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim)  # -> 2
# print(s.shape) # -> (4, 2)
print(s)
実行結果
[ [27 30]
  [33 36]
  [39 42]
  [45 48] ]

axis=0 の指定により、行方向に要素が合計されます。

上記の実行結果の最初の要素 [27 30] は、$1+9+17=$$\bf{27}$、$2+10+18=$$\bf{30}$ のように計算された結果です。

m2a0.png
m2as.png

shape=(3,4,2) であった x が、x.sum(axis=0) で shape=(4,2) になります。

axis=1 を指定

sum(axis=1)
s = x.sum(axis=1)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim)  # -> 2
# print(s.shape) # -> (3, 2)
print(s)
実行結果
[ [16 20]
  [48 52]
  [80 84] ]

axis=1 の指定により、列方向に要素が合計されます。上記実行結果の最初の要素 [16 20] は、$1+3+5+7=$$\bf{16}$、$2+4+6+8=$$\bf{20}$ のように計算された結果です。

m2a1.png
m2a1s.png

shape=(3,4,2) であった x が、x.sum(axis=1) で shape=(3,2) になります。

axis=2 を指定

sum(axis=2)
s = x.sum(axis=2)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim)  # -> 2
# print(s.shape) # -> (3, 4)
print(s)
実行結果
[ [ 3  7 11 15]
  [19 23 27 31]
  [35 39 43 47] ]

axis=2 の指定により、画像であればチャンネル方向に要素が合計されます。上記実行結果の最初の左上の要素 3 は $1+2=$$\bf{3}$、右下の要素 47 は$23+24=$$\bf{47}$ と計算された結果です。

m2a2.png
m2a2s.png

shape=(3,4,2) であった x が、x.sum(axis=2) で shape=(3,4) になります。

axis=(0,1) を指定

NumPy 1.7 以降では、axis をタプルで指定することができます。

axis=(0,1) の指定により、行と列についての合計が求められます。なお、axis=(1,0) としても結果は同じです(順番は関係ありません)。

sum(axis=(0,1))
s = x.sum(axis=(0,1))
#print(type(s)) # -> <class 'numpy.ndarray'>
#print(s.ndim)  # -> 1
#print(s.shape) # -> (2,)
print(s)
実行結果
[144 156]

実行結果の最初の要素 144 は、x[:,:,0] の全要素の合計、つまり $1+3+5+7+\cdots + 19+21+23$ を計算した結果になります。

m2a01.png
m2a01s.png

shape=(3,4,2) であった x が、x.sum(axis=(0,1)) で shape=(2,) になります。

axis=(1,2) を指定

axis=(2,1) としても同じです。

sum(axis=(1,2))
s = x.sum(axis=(1,2))
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim)  # -> 1
# print(s.shape) # -> (3,)
print(s)
実行結果
[ 36 100 164]

実行結果の最初の要素 36 は、下図の青枠の要素の合計、つまり、$(1+3+5+7)+(2+4+6+8)$ を計算した結果になります。

m2a12.png
m2a12s.png

shape=(3,4,2) であった x が、x.sum(axis=(1,2)) で shape=(3,) になります。

axis=(0,2) を指定

axis=(2,0) でも同じです。

sum(axis=(0,2))
s = x.sum(axis=(0,2))
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim)  # -> 1
# print(s.shape) # -> (4,)
print(s)
実行結果
[57 69 81 93]

実行結果の最初の要素 57 は、下図の青枠の要素の合計、つまり、$(1+9+17)+(2+10+18)$ を計算した結果になります。

m2a02.png
m2a02s.png

shape=(3,4,2) であった x が、x.sum(axis=(0,2)) で shape=(4,) になります。

axis=(0,1,2) を指定

sum(axis=(0,1,2)) は、sum(axis=None) または sum() と同じで全要素の合計が計算されます。

sum(axis=(0,1,2))
s = x.sum(axis=(0,1,2))
#print(type(s)) # -> <class 'numpy.int64'>
#print(s.ndim)  # -> 0
#print(s.shape) # -> ()
print(s)
実行結果
300

shape=(3,4,2) であった x が、x.sum(axis=(0,1,2)) で shape=(0) になります。

25
24
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
25
24

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?