sumやmeanについてくるaxisのパラメーター。頭がこんがらがってしまったので、for文で置き換えて整理してみました。
import numpy as np
A = np.arange(30).reshape(2,3,5)
print(A)
# axis = 0の場合
S0 = np.zeros(A.shape[1:], dtype="int32") #=(3,5)
for i in range(A.shape[0]):
S0 += A[i, :, :]
print(S0)
print(A.sum(axis=0))
# axis = 1の場合
S1 = np.zeros(A.shape[::2], dtype="int32") #=(2,5)
for j in range(A.shape[1]):
S1 += A[:, j, :]
print(S1)
print(A.sum(axis=1))
# axis = 2の場合
S2 = np.zeros(A.shape[:-1], dtype="int32") #=(2,3)
for k in range(A.shape[2]):
S2 += A[:, :, k]
print(S2)
print(A.sum(axis=2))
結果:
>>> print(A)
[[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]]
[[15 16 17 18 19]
[20 21 22 23 24]
[25 26 27 28 29]]]
>>> print(S0)
[[15 17 19 21 23]
[25 27 29 31 33]
[35 37 39 41 43]]
>>> print(A.sum(axis=0))
[[15 17 19 21 23]
[25 27 29 31 33]
[35 37 39 41 43]]
>>> print(S1)
[[15 18 21 24 27]
[60 63 66 69 72]]
>>> print(A.sum(axis=1))
[[15 18 21 24 27]
[60 63 66 69 72]]
>>> print(S2)
[[ 10 35 60]
[ 85 110 135]]
>>> print(A.sum(axis=2))
[[ 10 35 60]
[ 85 110 135]]
理解のためにfor文で書き換えただけで、for文を使った書き方は遅いだけなのでやめたほうがいいです。
行、列で集計と覚えていると3階以上のときにわけわからなくなるので、インデックスで覚えたほうがわかりやすいかと思います。