5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

「ディープラーニングの数学」9章補足 3値分類モデルの決定境界の表示

Last updated at Posted at 2022-01-05

はじめに

書籍「最短コースでわかるディープラーニングの数学」の著者です。
書籍のリンクはこちら
読者から9章の多値分類モデルに関して、決定境界表示はできないのかとの質問があり、試しに実装した結果を公開します。

Notebook全体のリンクは、こちらです。

考え方

決定境界とは、2つの確率値が等しくなる点の集まりです。つまり、 softmax関数の入力の時点で考えると、「2つの1次関数の計算結果が等しい場所」ということになります。
これを、NumPyの重み行列の言葉でいいかえると、重み行列を一列ずらした行列と元の行列の差をとり、その結果の行列と(1, x, y)の内積の結果がゼロになる場所です。
※ この考え方が成り立つのは正確にいうと例題で扱っている3値分類の場合に限定されます。4値以上の分類の場合は、決定境界の組み合わせがもっと複雑なので、計算もそれに対応して複雑になります。

これから紹介するコードはこの考えに基づいて実装されています。

実装コード紹介

これから紹介する実装コードは、書籍でいうとp.263で学習が終わり、三次元のグラフの描画までできた、その後に追加したものです。

x, yの描画領域の計算

決定境界の直線を元の散布図と同時に表示しようとすると。決定境界の直線に引っ張られて、散布図の点と関係ない領域まで表示されてしまいます。そのことを防ぐために、x, yの描画領域をあらかじめ計算し、 axis関数で明示的に描画領域を指定するようにします。
※ この考え方は一般論としては間違っていないのですが、今回の場合は、一番最後の図を見ると、こういう問題は起きなかったみたいです。つまり、結果論としては、y_min, y_maxの計算と、描画のときのaxis関数呼び出しは不要でした。

# x, yの描画領域計算
x_min = x_select[:,0].min()
x_max = x_select[:,0].max()
y_min = x_select[:,1].min()
y_max = x_select[:,1].max()
x_bound = np.array([x_min, x_max])

決定境界のyの値計算用関数

# 決定境界用の1次関数定義
def d_bound(x, i, W):
    W1 = W[:,[2,0,1]]
    W2 = W - W1
    w = W2[:,i]
    v = -1/w[2]*(w[1]*x + w[0])
    return v

行列W2は上の説明に基づいて実装しています。
その後の計算は、

$ w_0 + w_1 x + w_2 y = 0$

をyについて解いた結果です。

決定境界のyの値を計算

今、定義した関数を使って、決定境界のyの値を計算していきます。決定境界は全部で3つあります。

# 決定境界のyの値を計算
y0_bound = d_bound(x_bound, 0, W)
y1_bound = d_bound(x_bound, 1, W)
y2_bound = d_bound(x_bound, 2, W)

# 結果確認
print(y0_bound)
print(y1_bound)
print(y2_bound)
[2.90696266 5.10075533]
[2.13912805 3.9149779 ]
[3.58457746 6.14720506]

グラフ描画

これで準備は全部整いました。あとは、今計算で出したyの値を使ってplot関数で直線を書けばいいだけとなります。

# 散布図と決定境界の標示

#グラフサイズ指定
plt.figure(figsize=(8,8))

# 元データをグループ分け
x_t0 = x_select[y_org == 0]
x_t1 = x_select[y_org == 1]
x_t2 = x_select[y_org == 2]

# xとyの範囲を明示的に指定
plt.axis([x_min, x_max, y_min, y_max])

#  散布図の表示
plt.scatter(x_t0[:,0], x_t0[:,1], marker='x', c='k', s=50, label='0 (setosa)')
plt.scatter(x_t1[:,0], x_t1[:,1], marker='o', c='b', s=50, label='1 (versicolour)')
plt.scatter(x_t2[:,0], x_t2[:,1], marker='+', c='k', s=50, label='2 (virginica)')

# 決定境界の標示
plt.plot(x_bound, y0_bound, label='2_0')
plt.plot(x_bound, y1_bound, linestyle=':',label='0_1')
plt.plot(x_bound, y2_bound,linestyle='-.',label='1_2')

# 軸ラベル設定
plt.xlabel('sepal_length', fontsize=14)
plt.ylabel('petal_length', fontsize=14)
plt.xticks(size=14)
plt.yticks(size=14)
plt.legend(fontsize=14)
plt.show()

下のようなグラフが表示されれば成功です!

スクリーンショット 2022-01-05 22.15.42.png

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?