はじめに
書籍「最短コースでわかるディープラーニングの数学」の著者です。
書籍のリンクはこちら
読者から9章の多値分類モデルに関して、決定境界表示はできないのかとの質問があり、試しに実装した結果を公開します。
Notebook全体のリンクは、こちらです。
考え方
決定境界とは、2つの確率値が等しくなる点の集まりです。つまり、 softmax関数の入力の時点で考えると、「2つの1次関数の計算結果が等しい場所」ということになります。
これを、NumPyの重み行列の言葉でいいかえると、重み行列を一列ずらした行列と元の行列の差をとり、その結果の行列と(1, x, y)の内積の結果がゼロになる場所です。
※ この考え方が成り立つのは正確にいうと例題で扱っている3値分類の場合に限定されます。4値以上の分類の場合は、決定境界の組み合わせがもっと複雑なので、計算もそれに対応して複雑になります。
これから紹介するコードはこの考えに基づいて実装されています。
実装コード紹介
これから紹介する実装コードは、書籍でいうとp.263で学習が終わり、三次元のグラフの描画までできた、その後に追加したものです。
x, yの描画領域の計算
決定境界の直線を元の散布図と同時に表示しようとすると。決定境界の直線に引っ張られて、散布図の点と関係ない領域まで表示されてしまいます。そのことを防ぐために、x, yの描画領域をあらかじめ計算し、 axis関数で明示的に描画領域を指定するようにします。
※ この考え方は一般論としては間違っていないのですが、今回の場合は、一番最後の図を見ると、こういう問題は起きなかったみたいです。つまり、結果論としては、y_min, y_maxの計算と、描画のときのaxis関数呼び出しは不要でした。
# x, yの描画領域計算
x_min = x_select[:,0].min()
x_max = x_select[:,0].max()
y_min = x_select[:,1].min()
y_max = x_select[:,1].max()
x_bound = np.array([x_min, x_max])
決定境界のyの値計算用関数
# 決定境界用の1次関数定義
def d_bound(x, i, W):
W1 = W[:,[2,0,1]]
W2 = W - W1
w = W2[:,i]
v = -1/w[2]*(w[1]*x + w[0])
return v
行列W2は上の説明に基づいて実装しています。
その後の計算は、
$ w_0 + w_1 x + w_2 y = 0$
をyについて解いた結果です。
決定境界のyの値を計算
今、定義した関数を使って、決定境界のyの値を計算していきます。決定境界は全部で3つあります。
# 決定境界のyの値を計算
y0_bound = d_bound(x_bound, 0, W)
y1_bound = d_bound(x_bound, 1, W)
y2_bound = d_bound(x_bound, 2, W)
# 結果確認
print(y0_bound)
print(y1_bound)
print(y2_bound)
[2.90696266 5.10075533]
[2.13912805 3.9149779 ]
[3.58457746 6.14720506]
グラフ描画
これで準備は全部整いました。あとは、今計算で出したyの値を使ってplot関数で直線を書けばいいだけとなります。
# 散布図と決定境界の標示
#グラフサイズ指定
plt.figure(figsize=(8,8))
# 元データをグループ分け
x_t0 = x_select[y_org == 0]
x_t1 = x_select[y_org == 1]
x_t2 = x_select[y_org == 2]
# xとyの範囲を明示的に指定
plt.axis([x_min, x_max, y_min, y_max])
# 散布図の表示
plt.scatter(x_t0[:,0], x_t0[:,1], marker='x', c='k', s=50, label='0 (setosa)')
plt.scatter(x_t1[:,0], x_t1[:,1], marker='o', c='b', s=50, label='1 (versicolour)')
plt.scatter(x_t2[:,0], x_t2[:,1], marker='+', c='k', s=50, label='2 (virginica)')
# 決定境界の標示
plt.plot(x_bound, y0_bound, label='2_0')
plt.plot(x_bound, y1_bound, linestyle=':',label='0_1')
plt.plot(x_bound, y2_bound,linestyle='-.',label='1_2')
# 軸ラベル設定
plt.xlabel('sepal_length', fontsize=14)
plt.ylabel('petal_length', fontsize=14)
plt.xticks(size=14)
plt.yticks(size=14)
plt.legend(fontsize=14)
plt.show()
下のようなグラフが表示されれば成功です!