LoginSignup
1
1

More than 1 year has passed since last update.

書籍「Pytorch&深層学習プログラミング」6章補足 決定境界表示プログラムの解説

Last updated at Posted at 2022-05-03

はじめに

書籍「Pytorch&深層学習プログラミング」の著者です。

Amazonリンク
サポートサイトリンク

6章p.227-p.228では、学習データの散布図と、学習の結果得られたロジスティック回帰モデルを基にした決定境界のグラフを重ね描きしています。

しかし、このプログラムの解説は省略しています。決定境界のグラフは説明変数が2個の時しか描画することができませんが、実際の分類問題で説明変数が2個しかないことはまずありません。つまり、決定境界グラフは、実用の観点であまり意味がないです。しかし、グラフを見ることがロジスティック回帰の動作イメージを持つにはとても適しているという理由で、このように可視化の結果のみ示す形にしています。
とはいえ、なぜこのプログラムで決定境界が表示できるのか知りたいという質問は何度か受けているので、その質問に答えるための解説記事をアップしました。

数学的な説明

そもそも「決定境界とはなにか」を、書籍p.209の図6-7を使って説明します。

上の図でnn.Linearとは、説明変数$x_1$と$x_2$を入力とする1次関数です。
その出力$u$にシグモイド関数$f(u)$をかけた結果が、予測値ypであり、ypの値が0.5より大きいかどうかで、予測結果が1か0かを判断するのが、ロジスティック回帰の処理方式でした。
シグモイド関数のグラフを見ればわかるとおり、「シグモイド関数の結果が0.5より大きいかどうか」は入力段階の$u$の言葉で言い換えると「uが正か負か」という話と同じです。
つまり、「nn.Linearの結果がちょうどゼロになる点の集合」が、今求めたい決定境界を満たす値の条件ということになります。更に今回は、決定境界はモデルの構造から直線であることがわかっています。

ちなみに、同じ二値分類モデルでも、サポートベクターマシンや、ディープラーニングモデルなどの場合、決定境界は直線にはなりません。その場合、決定境界の描画方法は今回と比較してはるかに複雑になります。著者の別書籍「Pythonで儲かるAIをつくる」の解説用に書いた、下記の記事がその場合の方法を説明したものとなります。
アルゴリズムが一目でわかる! Pythonによる決定境界表示

話が脱線しましたが、今回は目的が直線であることがわかっているので、「nn.Linearの結果がちょうどゼロになる点」 を2つ見つけて、その点を直線で結べばそれが求める決定境界の直線ということになります。

この条件を数式で表現してみましょう。変数$x_1$の重みを$w_1$、変数$x_2$の重みを$w_2$、バイアスを$b$で表します。
すると、線形関数の出力$u$は、次の式で表されます。

$u = w_1 x_1 + w_2 x_2 + b$

決定境界上の点ではこの値がゼロなので

$0 = w_1 x_1 + w_2 x_2 + b$

$w_2$がゼロでないという前提を付けた上で、この式を$x_2$について解くと次の結果になります。

$x_2 = -\dfrac{w_1 x_1 + b}{w_2}$

プログラム実装

この数式が、線形関数の属性によってどう表されるかを考えます。まず、線形関数内の2つの変数weightとbiasを操作しやすいように、外部変数に代入します。

# パラメータの取得
bias = net.l1.bias.data.numpy()
weight = net.l1.weight.data.numpy()

すると、上の数式の$b$が、変数biasになります。
また、$w_1$がweight[0,0] に、$w_2$がweight[0,1]になります。

そこで、上の数式を用いて、「ある$x_1$の座標値を入力に該当する決定境界上の$x_2$の座標値を求める関数」をdecision(x)とすると、この関数は次の形で実装されることになります。

# 決定境界描画用 x1の値から x2の値を計算する
def decision(x):
    return(-(bias + weight[0,0] * x)/ weight[0,1])

最初に描画領域の$x_1$の最小値、最大値を求め、今定義した関数を使って対応する$x_2$の値を求める実装は次のようになります。
ここでは今までの説明と違い、変数名を$x_1$→ $x$、$x_2$→$y$に変えている点に注意して下さい。

# 散布図のx1の最小値と最大値
xl = np.array([x_test[:,0].min(), x_test[:,0].max()])
yl = decision(xl)

# 結果確認
print(f'xl = {xl}  yl = {yl}')

ここで得られた xlylを使い、以下のコードにより、決定境界の直線は表示できることになります。

# 決定境界直線
plt.plot(xl, yl, c='b')

以上が、6.10節の実装コードのうち、書籍で説明していない部分の解説となります。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1