numpyの3次元配列を頭でイメージするのが苦手で、axis=1ってどこを指すんだっけ?みたいに毎回なるので自分なりの理解をまとめます。以下、「配列」はnp.arrayのことを指します。
きっかけ
『機械学習のエッセンス』(著: 加藤公一) で、二次元ベクトルを7つ並べた二次元配列Xと3つ並べた二次元配列cluster_centersがあったときに、 両者のベクトル間の距離の二乗を総当たりで求める計算を以下のようにしていました(第三版 pp.354-356から抜粋)。
>>> X = np.array([[1, 2],
[2, 3],
[3, 4],
[4, 5],
[5, 6],
[7, 9]])
>>> cluster_centers = np.array([[1, 1],
[2, 2],
[3, 3]])
>>> ((X[:, :, np.newaxis]
- cluster_centers.T[np.newaxis, :, :])**2).sum(axis=1)
array([[ 1, 1, 5],
[ 5, 1, 1],
[ 13, 5, 1],
[ 25, 13, 5],
[ 41, 25, 13],
[ 61, 41, 25],
[100, 74, 52]])
エレガントすぎて、ちょっと何やってるのかわかりませんでした。泣
np.newaxisとは何ぞやというところから始まり、Xとcluster_centersでnp.newaxisの位置が違ったり、転置したりsumしたりで何が何だか……(本では補足説明をしてくれてますが、それでも私にはハードルが高すぎました)
という訳でこれを私が理解するまでの道のりを記しておきます。長くなったので前後半に記事を分けます。本記事では配列のイメージとnp.newaxisが何ぞやっていうところまで。後半はこちら。
STEP 1. 配列の基本イメージ
おさらい。
一次元配列
>>> x = np.array([1, 2, 3])
ベクトルです。ひとまず、表示の通りの横ベクトル $(1, 2, 3)$ をイメージします。
>>> y = np.array([1])
また、要素が一つだけでも一次元「配列」です。要素が増える可能性があるので、横にずらっと並んだマス目に一つだけ数字が入っているイメージ。$(1, \ldots)$ みたいな。
二次元配列
>>> A = np.array([[1, 2, 3],
[4, 5, 6]])
見た目通り、横ベクトルが縦に連なっているイメージをします。行列としては
A = \left(
\begin{array}{ccc}
1 & 2 & 3 \\
4 & 5 & 6 \\
\end{array}
\right).
>>> B = np.array([[1, 2, 3]])
>>> C = np.array([[1],
[2],
[3]])
これも二次元配列です。Bは縦に連なるはずの横ベクトルが一つしかない状態で、縦横に並んだマス目の一列目だけに数字が入っているイメージ。Cは要素が1つのみの横ベクトルが縦に3つ並んでる。
配列の要素の取り出し方やaxisのイメージは以下のような感じ。
# インデックスは最初の数字に二次元目の軸(縦方向)を、2番目に横方向を指定する。
>>> A[1, 2]
6
# 数字のところをコロンにするとその軸全部とってくる。次元が一つ下がってベクトルになる。
>>> A[0, :]
array([1, 2, 3])
# 縦の軸をとってきても1次元になる。
>>> A[:, 1]
array([2, 5])
# 縦軸(二次元目)がaxis=0、横軸(一次元目)がaxis=1 (A.sumは軸ごとの合計)
>>> A.sum(axis=0)
array([5, 7, 9])
>>> A.sum(1) # "axis="は省略可
array([6, 15])
三次元配列
三次元配列は、二次元のマス目が奥行き方向に何枚もある感じ
>>> X = np.array([[[1, 2, 3],
[4, 5, 6]],
[[7, 8, 9],
[10, 11, 12]]])
>>> Y = np.array([[[1]]])
Yみたいなのも三次元配列。奥行き、縦、横方向にマス目は用意してあるが、一つしか埋まっていない状態。
上で二次元配列から一次元配列を取り出せたように、コロンを2つ用いると三次元配列から二次元配列を取り出すことができる。ただ、ちょっとイメージが掴みにくい。
# 最初の数字が三次元目の軸(奥行き方向)、2番目が二次元目の軸(縦)、3番目が一次元目(横)
>>> X[0, 1, 2]
6
# 奥行き軸(axis=0)の0番目を取り出す
>>> X[0, :, :]
array([[1, 2, 3],
[4, 5, 6]])
# 縦軸(axis=1)の1番目を取り出す
>>> X[:, 1, :]
array([[ 4, 5, 6],
[10, 11, 12]])
# 横軸(axis=2)の2番目を取り出す
>>> X[:, :, 2]
array([[ 3, 6],
[ 9, 12]])
STEP 2. 配列に軸を追加して次元を上げる
上でやったスライシングで配列を取り出す操作とは逆に、配列に新たに軸を追加して次元を上げることもできます。
一次元から二次元へ
>>> x = np.array([1, 2, 3])
# 軸の追加にはnp.newaxisを使う。
# axis=0(縦の軸)を追加
>>> x[np.newaxis, :]
array([[1, 2, 3]])
# axis=1の追加。xの要素を縦(axis=0)に並べて、横軸(axis=1)を作る。
>>> x[:, np.newaxis]
array([[1],
[2],
[3]])
# np.newaxisはNoneで代用可
>>> x[None, :]
array([[1, 2, 3]])
二次元から三次元へ
>>> A = np.array([[1, 2, 3],
[4, 5, 6]])
# axis=0(奥行きの軸)を追加
>>> A[np.newaxis, :, :]
array([[[1, 2, 3],
[4, 5, 6]]])
# axis=1(縦軸)を追加。xの要素を3次元格子の上の面に並べる。
>>> A[:, np.newaxis, :]
array([[[1, 2, 3]],
[[4, 5, 6]]])
# axis=2(横軸)を追加。xの要素を3次元格子のサイドの面に並べる。
>>> A[:, :, np.newaxis]
array([[[1],
[2],
[3]],
[[4],
[5],
[6]]])