よくnumpyのaxisで迷うのでその備忘録。
[[[1, 2, 3, 4, 5, 6],
[2, 3, 4, 5, 6, 7],
[3, 4, 5, 6, 7, 8]],
[[1, 2, 3, 4, 5, 6],
[2, 3, 4, 5, 6, 7],
[3, 4, 5, 6, 7, 8]]]
とかなっているとaxisが何が何だか解らなくなるので簡単な判別法
。
一番外の[]がaxis=0で内側の[]に入るごとにaxisの値が一つずつ増えていく。
例
C=[[[1, 2, 3, 4, 5, 6],
[2, 3, 4, 5, 6, 7],
[3, 4, 5, 6, 7, 8]],
[[1, 2, 3, 4, 5, 6],
[2, 3, 4, 5, 6, 7],
[3, 4, 5, 6, 7, 8]]]
np.sum(C,axis=0)だと一番外のカッコ内レベルで計算されるので[3行5列]単位で計算され
[3行5列]+[3行5列]
array([[ 2, 4, 6, 8, 10, 12],
[ 4, 6, 8, 10, 12, 14],
[ 6, 8, 10, 12, 14, 16]])
np.sum(C,axis=2)
外側から3番目のカッコ内を足すので
[[[1行目5列分の和], [2行目5列分の和], [3行目5列分の和]]
[4行目5列分の和], [5行目5列分の和], [6行目5列分の和]]]
array([[21, 27, 33],
[21, 27, 33]])
になる
理解のためにもうちょっと話しましょうか。
a = [[[1,2,6,4,6], [9,8,0,-1,-5]],
[[5,89,0.1,56,74],[666,2,15,45555,75]]]
とあるとき
np.ndim(a)
で3が返るので3次元配列と分かる。
のでそれらの最小値を考えていきましょう!
次元数 - 1 = 2をaxisにセット
np.min(a,axis=2)
で一番内側の
[1,2,6,4,6]の最小値1,
[9,8,0,-1,-5]の最小値-5
[5,89,0.1,56,74]の最小値0.1,
[666,2,15,45555,75]の最小値2
となり
array([[ 1. , -5. ],
[ 0.1, 2. ]])
np.min(a,axis=1)で一番内側から一つ次元が外側に行き、
[[1, 2, 6, 4, 6],
[9, 8, 0,-1,-5]]
[1, 2, 0,-1,-5]
[5, 89, 0.1, 56, 74],
[666, 2, 15, 45555, 75]
の列の最小値
[5, 2, 0.1, 56, 74]
に分けて計算され結果として
array([[ 1. , 2. , 0. , -1. , -5. ],
[ 5. , 2. , 0.1, 56. , 74. ]])
がでます。
結局何が言いたいかというとaxis=0にして外側からの認識で計算していくか?
axis=次元数−1
にして最も内側の次元の認識で計算していくか?
お好きな方をどうぞ
という事です。