はじめに
numpy.ndarray.sum(...)
や numpy.ndarray.mean(...)
などで指定可能な axis オプション についてのまとめです。2次元および3次元のnumpy配列を対象に、図を使って解説します。
2次元のnumpy配列の場合
まずは、2次元(ndim=2)のnumpy配列を対象にします。例として、次のような3行4列の2次元のnumpy配列 x
を考えます。
この配列 x
は次のコードで生成できます。
import numpy as np
x = np.arange(1,12+1).reshape(3,4)
print(x)
#print(x.ndim) # -> 2
#print(x.shape) # -> (3, 4)
[ [ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12] ]
axis=None
を指定
axis=None
がデフォルト引数になっているので、x.sum()
でも x.sum(axis=None)
でも同じ動作をします。numpy.sum(...) / numpy.ndarray.sum(...) のリファレンス
s = x.sum(axis=None)
# print(type(s)) # -> <class 'numpy.int64'>
# print(s.ndim) # -> 0
# print(s.shape) # -> ()
print(s)
axis=None
を指定すると配列を構成する全ての要素についての合計が求められます。具体的には $1+2+3+\cdots+11+12=$$\bf{78}$ が計算されます。
78
axis=0
を指定
s = x.sum(axis=0)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim) # -> 1
# print(s.shape) # -> (4,)
print(s)
[15 18 21 24]
axis=0
を指定すると、行方向に要素が合計されます。具体的には、$1+5+9=$$\bf{15}$、$2+6+10=$$\bf{18}$、$3+7+11=$$\bf{21}$、$4+8+12=$$\bf{24}$ が計算され [15 18 21 24]
となっています。
【注意】:各行の要素合計ではなくて、行方向(行が 0, 1, 2, 3, ... と大きくなる方向)に要素を合計していることに注意してください(ここを勘違いすると一気に???に陥ります(経験談))。
axis=1
を指定
s = x.sum(axis=1)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim) # -> 1
# print(s.shape) # -> (3,)
print(s)
[10 26 42]
axis=1
を指定すると、列方向(列が 0, 1, 2, 3, ... と大きくなる方向)に要素が合計されています。具体的には、$1+2+3+4=$$\bf{10}$、$5+6+7+8=$$\bf{26}$、$9+10+11+12=$$\bf{42}$ が計算されて [10 26 42]
となっています。
3次元のnumpy配列の場合
つづいて、3次元のnumpy配列を対象にします。例として、shape が (3,4,2) のnumpy配列 x
を扱います。
import numpy as np
x = np.arange(1,24+1).reshape(3,4,2)
print(x)
# print(x.ndim) # -> 3
# print(x.shape) # -> (3, 4, 2)
[ [ [ 1 2] [ 3 4] [ 5 6] [ 7 8] ]
[ [ 9 10] [11 12] [13 14] [15 16] ]
[ [17 18] [19 20] [21 22] [23 24] ] ]
axis=2
で分離して描くと次のようになります。
axis=None
を指定
s = x.sum(axis=None)
# print(type(s)) # -> <class 'numpy.int64'>
# print(s.ndim) # -> 0
# print(s.shape) # -> ()
print(s)
すべての要素についての合計です。具体的には $1+2+3+4+\cdots+22+23+24=300$ が計算されます。
300
axis=0
を指定
s = x.sum(axis=0)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim) # -> 2
# print(s.shape) # -> (4, 2)
print(s)
[ [27 30]
[33 36]
[39 42]
[45 48] ]
axis=0
の指定により、行方向に要素が合計されます。
上記の実行結果の最初の要素 [27 30]
は、$1+9+17=$$\bf{27}$、$2+10+18=$$\bf{30}$ のように計算された結果です。
shape=(3,4,2) であった x
が、x.sum(axis=0)
で shape=(4,2) になります。
axis=1
を指定
s = x.sum(axis=1)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim) # -> 2
# print(s.shape) # -> (3, 2)
print(s)
[ [16 20]
[48 52]
[80 84] ]
axis=1
の指定により、列方向に要素が合計されます。上記実行結果の最初の要素 [16 20]
は、$1+3+5+7=$$\bf{16}$、$2+4+6+8=$$\bf{20}$ のように計算された結果です。
shape=(3,4,2) であった x
が、x.sum(axis=1)
で shape=(3,2) になります。
axis=2
を指定
s = x.sum(axis=2)
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim) # -> 2
# print(s.shape) # -> (3, 4)
print(s)
[ [ 3 7 11 15]
[19 23 27 31]
[35 39 43 47] ]
axis=2
の指定により、画像であればチャンネル方向に要素が合計されます。上記実行結果の最初の左上の要素 3
は $1+2=$$\bf{3}$、右下の要素 47
は$23+24=$$\bf{47}$ と計算された結果です。
shape=(3,4,2) であった x
が、x.sum(axis=2)
で shape=(3,4) になります。
axis=(0,1)
を指定
NumPy 1.7 以降では、axis をタプルで指定することができます。
axis=(0,1)
の指定により、行と列についての合計が求められます。なお、axis=(1,0)
としても結果は同じです(順番は関係ありません)。
s = x.sum(axis=(0,1))
#print(type(s)) # -> <class 'numpy.ndarray'>
#print(s.ndim) # -> 1
#print(s.shape) # -> (2,)
print(s)
[144 156]
実行結果の最初の要素 144
は、x[:,:,0]
の全要素の合計、つまり $1+3+5+7+\cdots + 19+21+23$ を計算した結果になります。
shape=(3,4,2) であった x
が、x.sum(axis=(0,1))
で shape=(2,) になります。
axis=(1,2)
を指定
axis=(2,1)
としても同じです。
s = x.sum(axis=(1,2))
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim) # -> 1
# print(s.shape) # -> (3,)
print(s)
[ 36 100 164]
実行結果の最初の要素 36
は、下図の青枠の要素の合計、つまり、$(1+3+5+7)+(2+4+6+8)$ を計算した結果になります。
shape=(3,4,2) であった x
が、x.sum(axis=(1,2))
で shape=(3,) になります。
axis=(0,2)
を指定
axis=(2,0)
でも同じです。
s = x.sum(axis=(0,2))
# print(type(s)) # -> <class 'numpy.ndarray'>
# print(s.ndim) # -> 1
# print(s.shape) # -> (4,)
print(s)
[57 69 81 93]
実行結果の最初の要素 57
は、下図の青枠の要素の合計、つまり、$(1+9+17)+(2+10+18)$ を計算した結果になります。
shape=(3,4,2) であった x
が、x.sum(axis=(0,2))
で shape=(4,) になります。
axis=(0,1,2)
を指定
sum(axis=(0,1,2))
は、sum(axis=None)
または sum()
と同じで全要素の合計が計算されます。
s = x.sum(axis=(0,1,2))
#print(type(s)) # -> <class 'numpy.int64'>
#print(s.ndim) # -> 0
#print(s.shape) # -> ()
print(s)
300
shape=(3,4,2) であった x
が、x.sum(axis=(0,1,2))
で shape=(0) になります。