LoginSignup
4
3

More than 5 years have passed since last update.

Numpyのsumのaxisについて理解する

Posted at

sumやmeanについてくるaxisのパラメーター。頭がこんがらがってしまったので、for文で置き換えて整理してみました。

import numpy as np

A = np.arange(30).reshape(2,3,5)
print(A)

# axis = 0の場合
S0 = np.zeros(A.shape[1:], dtype="int32") #=(3,5)
for i in range(A.shape[0]):
    S0 += A[i, :, :]

print(S0)
print(A.sum(axis=0))


# axis = 1の場合
S1 = np.zeros(A.shape[::2], dtype="int32") #=(2,5)
for j in range(A.shape[1]):
    S1 += A[:, j, :]

print(S1)
print(A.sum(axis=1))


# axis = 2の場合
S2 = np.zeros(A.shape[:-1], dtype="int32") #=(2,3)
for k in range(A.shape[2]):
    S2 += A[:, :, k]

print(S2)
print(A.sum(axis=2))

結果:

>>> print(A)
[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]
  [10 11 12 13 14]]

 [[15 16 17 18 19]
  [20 21 22 23 24]
  [25 26 27 28 29]]]

>>> print(S0)
[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]
>>> print(A.sum(axis=0))
[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]

>>> print(S1)
[[15 18 21 24 27]
 [60 63 66 69 72]]
>>> print(A.sum(axis=1))
[[15 18 21 24 27]
 [60 63 66 69 72]]

>>> print(S2)
[[ 10  35  60]
 [ 85 110 135]]
>>> print(A.sum(axis=2))
[[ 10  35  60]
 [ 85 110 135]]

理解のためにfor文で書き換えただけで、for文を使った書き方は遅いだけなのでやめたほうがいいです。
行、列で集計と覚えていると3階以上のときにわけわからなくなるので、インデックスで覚えたほうがわかりやすいかと思います。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3