a = np.arange(20).reshape(4,5)
=> array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
jcol = [1, 1, 2, 3]
第0,1,2,3行からそれぞれ第 1 1 2 3 列の数値を取り出したいとする。つまり a[i, jcol[i]]
からなる [1 6 12 18]
を得たい。
Python でループをまわすことなく高速にこれをするには:
np.choose(jcol, a.T)
=> array([ 1, 6, 12, 18])
応用
想定している応用は y[i, j]
が i番目のデータの答えが j である確率で、 t[i]
が正解の index だとすると、交差エントロピー誤差
$$
E = - \sum_i \log y[i, t[i]]
$$
は
-np.sum(np.log(np.choose(t, y.T) + 1.0e-7))
と書ける。 +1.0e-7
は log が発散しないように下駄を履かせてる。
参考