import numpy as np

A = np.random.randint(0, 10, (3, 4, 3, 4))
print(A)
# Solution by flattening the last two dimensions into one, so the reduction
# runs over a single axis. This is useful for functions that don't accept a
# tuple for their `axis` argument. For np.sum itself,
# A.sum(axis=(-2, -1)) gives the same result directly.
print(A.shape)
# Keep all leading dimensions and let reshape infer (-1) the size of the
# merged trailing axis: here (3, 4, 3, 4) -> (3, 4, 12).
AS = A.shape[:-2] + (-1,)
print(AS)
Ar = A.reshape(AS)
print(Ar)
# Named `total` rather than `sum` so the built-in sum() is not shadowed.
total = Ar.sum(axis=-1)
print(total)
- Example output (values differ on each run because `A` is random):
[[[[6 1 1 3]
[2 6 4 0]
[3 7 2 6]]
[[8 0 6 0]
[9 7 7 6]
[1 4 0 3]]
[[7 9 9 7]
[1 4 2 9]
[5 7 4 7]]
[[5 0 3 6]
[3 7 0 3]
[3 1 4 4]]]
[[[7 7 0 9]
[6 1 9 5]
[0 2 2 4]]
[[0 2 3 9]
[4 6 1 2]
[1 2 7 7]]
[[9 1 5 3]
[3 4 2 0]
[5 2 2 6]]
[[0 4 2 2]
[7 5 1 2]
[5 7 9 0]]]
[[[6 6 9 2]
[6 5 3 7]
[9 7 5 4]]
[[6 6 3 8]
[6 7 2 5]
[9 4 4 8]]
[[7 3 1 4]
[0 2 0 9]
[5 0 6 1]]
[[4 3 3 1]
[8 8 0 0]
[1 0 7 3]]]]
(3, 4, 3, 4)
(3, 4, -1)
[[[6 1 1 3 2 6 4 0 3 7 2 6]
[8 0 6 0 9 7 7 6 1 4 0 3]
[7 9 9 7 1 4 2 9 5 7 4 7]
[5 0 3 6 3 7 0 3 3 1 4 4]]
[[7 7 0 9 6 1 9 5 0 2 2 4]
[0 2 3 9 4 6 1 2 1 2 7 7]
[9 1 5 3 3 4 2 0 5 2 2 6]
[0 4 2 2 7 5 1 2 5 7 9 0]]
[[6 6 9 2 6 5 3 7 9 7 5 4]
[6 6 3 8 6 7 2 5 9 4 4 8]
[7 3 1 4 0 2 0 9 5 0 6 1]
[4 3 3 1 8 8 0 0 1 0 7 3]]]
[[41 51 71 39]
[52 44 42 44]
[69 68 38 38]]