sumやmeanについてくるaxisのパラメーター。頭がこんがらがってしまったので、for文で置き換えて整理してみました。

import numpy as np

A = np.arange(30).reshape(2,3,5)
print(A)

# axis = 0の場合
S0 = np.zeros(A.shape[1:], dtype="int32") #=(3,5)
for i in range(A.shape[0]):
    S0 += A[i, :, :]

print(S0)
print(A.sum(axis=0))


# axis = 1の場合
S1 = np.zeros(A.shape[::2], dtype="int32") #=(2,5)
for j in range(A.shape[1]):
    S1 += A[:, j, :]

print(S1)
print(A.sum(axis=1))


# axis = 2の場合
S2 = np.zeros(A.shape[:-1], dtype="int32") #=(2,3)
for k in range(A.shape[2]):
    S2 += A[:, :, k]

print(S2)
print(A.sum(axis=2))

結果:

>>> print(A)
[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]
  [10 11 12 13 14]]

 [[15 16 17 18 19]
  [20 21 22 23 24]
  [25 26 27 28 29]]]

>>> print(S0)
[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]
>>> print(A.sum(axis=0))
[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]

>>> print(S1)
[[15 18 21 24 27]
 [60 63 66 69 72]]
>>> print(A.sum(axis=1))
[[15 18 21 24 27]
 [60 63 66 69 72]]

>>> print(S2)
[[ 10  35  60]
 [ 85 110 135]]
>>> print(A.sum(axis=2))
[[ 10  35  60]
 [ 85 110 135]]

理解のためにfor文で書き換えただけで、for文を使った書き方は遅いだけなのでやめたほうがいいです。
行、列で集計と覚えていると3階以上のときにわけわからなくなるので、インデックスで覚えたほうがわかりやすいかと思います。

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.