0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

教師あり学習:分析 -1次元入力2クラス分類-

Last updated at Posted at 2024-09-10

1次元入力2クラス分類

入力情報が1元、分類するクラスが2つの場合のこと。
体重データXに対して、性別Tが生成されるような分類を決定する。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(seed=0)
X_min = 0
X_max = 2.5
X_n = 30
X_col = ['cornflowerblue', 'gray']

X = np.zeros(X_n)
T = np.zeros(X_n, dtype=np.uint8)

Dist_s = [0.4, 0.8]
Dist_w = [0.8, 1.6]
Pi = 0.5

for n in range(X_n):
    wk = np.random.rand()
    T[n] = 0 * (wk < Pi) + 1 * (wk >= Pi)
    X[n] = np.random.rand() * Dist_w[T[n]] + Dist_s[T[n]]

print('X=' + str(np.round(X, 2)))
print('T=' + str(T))

def show_data1(x, t):
    K = np.max(t) + 1
    for k in range(K):
        plt.plot(x[t == k], t[t == k], X_col[k], alpha=0.5, linestyle='none', marker='o')
    plt.grid(True)
    plt.ylim(-.5, 1.5)
    plt.xlim(X_min, X_max)
    plt.yticks([0, 1])

fig = plt.figure(figsize=(3, 3))
show_data1(X, T)
plt.show()

オスとメスを確率的に決定し、メスになる確率をPi=0.5として、ランダムに決定する。
1をオス、0をメスとすると、以下のように分布される。

実行結果

image.png

1.0を超えたあたりからオスとメスを分ける境界線があるように思える。
この境界線を決めるような線のことを決定境界線と呼ぶ。

この決定境界線を決める方法は線形回帰モデルを使い、直線にフィッティングする。
この際にオスのデータが大きな質量を取ってしまうと、決定境界線がオス側に引っ張られてしまう。
外れ値が大きくなるとオスとメスの識別に誤差が多くなってしまう。

image.png

ロジスティック回帰モデル

例に沿って、体重xに対するt=1(オス)である確率を条件付き確率といい、

P(t=1|x)

と表す。

データを一様分布(T=1, 0に決まるようなデータ)から生成されたもので考えてきたため、扱いやすい問題になっていた。
例えば、身長と体重のように体重xに対して身長tが一様でない場合が存在する。
(体重75kgの人の身長が170cmや175cmとばらつきがある場合)

ばらつきがある場合はガウス関数を使うと実際の分布を表すことが出来る。
このためガウス分布に従っているとして、P(t=1|x)はロジスティック回帰モデルで表せる。

ロジスティック回帰モデルは直線の式をシグモイド関数に入れたものになる。

直線の式
y = w_{0} x + w_{1}
シグモイド関数
\displaylines{ y = \frac{1}{1 + e^{-x}} }
ロジスティック回帰モデル
\displaylines{ y = δ(w_{0} x + w_{1}) = \frac{1}{1 + e^{\{-(w_{0} x + w_{1})\}}} }

以下はロジスティック回帰モデルと決定境界線をともに表示するコードになっている。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(seed=0)
X_min = 0
X_max = 2.5
X_n = 30
X_col = ['cornflowerblue', 'gray']

X = np.zeros(X_n)
T = np.zeros(X_n, dtype=np.uint8)

Dist_s = [0.4, 0.8]
Dist_w = [0.8, 1.6]
Pi = 0.5

for n in range(X_n):
    wk = np.random.rand()
    T[n] = 0 * (wk < Pi) + 1 * (wk >= Pi)
    X[n] = np.random.rand() * Dist_w[T[n]] + Dist_s[T[n]]

print('X=' + str(np.round(X, 2)))
print('T=' + str(T))

def logistic(x, w):
    y = 1 / (1 + np.exp(-(w[0] * x + w[1])))
    return y

def show_logistic(w):
    xb = np.linspace(X_min, X_max, 100)
    y = logistic(xb, w)
    plt.plot(xb, y, color='gray', linewidth=4)
    # 決定境界
    i = np.min(np.where(y > 0.5))
    B = (xb[i - 2] + xb[i]) / 2
    plt.plot([B, B], [-.5, 1.5], color='k', linestyle='-')
    plt.grid(True)
    return B

# test
W = [8, -10]
show_logistic(W)
plt.show()

image.png

交差エントロピー誤差

ロジスティック回帰モデルを使って、xがt=1になる確率を

y = δ(w_{0} x + w_{1}) = P(t=1|x)

と表す。

このときデータxから最もありえるパラメータw0, w1をもとめるために最尤推定という方法を使う。
P(t=1|x)となるような確率をyとするので、そうならない確率は(1-y)となり、最尤推定は以下のようになる。

P(t|x) = y^t (1-y)^{1-t}

tの個数がN個あるとき、データX=x0, x1…, xN-1に対して、クラスデータT=t0, t1…, tN-1の生成確率は、

P(T|X) = \prod_{k=0}^{N-1} P(t_{n}|x_{n}) = \prod_{k=0}^{N-1} y_{n}^{t_{n}} (1-y_{n})^{1-t_{n}}

といい、尤度という。

これを対数に取ると、

\log P(T|X) = \sum_{k=0}^{N-1} \{t_{n} \log y_{n} + (1-t_{n}) \log (1-y_{n})\}

元の式の最小を求めるために対数を取ると、最小は対数を取った式に-1をかけたものになる。これを交差エントロピー誤差という。

平均二乗誤差と同様に平均を取ることを平均交差エントロピー誤差といい、E(w)を以下に定義する。

E(w) = - \frac{1}{N} \log P(T|X) = - \frac{1}{N} \sum_{k=0}^{N-1} \{t_{n} \log y_{n} + (1-t_{n}) \log (1-y_{n})\}
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

np.random.seed(seed=0)
X_min = 0
X_max = 2.5
X_n = 30
X_col = ['cornflowerblue', 'gray']

X = np.zeros(X_n)
T = np.zeros(X_n, dtype=np.uint8)

Dist_s = [0.4, 0.8]
Dist_w = [0.8, 1.6]
Pi = 0.5

for n in range(X_n):
    wk = np.random.rand()
    T[n] = 0 * (wk < Pi) + 1 * (wk >= Pi)
    X[n] = np.random.rand() * Dist_w[T[n]] + Dist_s[T[n]]

print('X=' + str(np.round(X, 2)))
print('T=' + str(T))

# ロジスティック回帰モデル
def logistic(x, w):
    y = 1 / (1 + np.exp(-(w[0] * x + w[1])))
    return y

# ロジスティック回帰モデル 表示
def show_logistic(w):
    xb = np.linspace(X_min, X_max, 100)
    y = logistic(xb, w)
    plt.plot(xb, y, color='gray', linewidth=4)
    # 決定境界
    i = np.min(np.where(y > 0.5))
    B = (xb[i - 2] + xb[i]) / 2
    plt.plot([B, B], [-.5, 1.5], color='k', linestyle='-')
    plt.grid(True)
    return B

# 平均交差エントロピー誤差
def cee_logistic(w, x, t):
    y = logistic(x, w)
    cee = 0
    for n in range(len(y)):
        cee = cee - (t[n] * np.log(y[n]) + (1 - t[n]) * np.log(1 - y[n]))
    cee = cee / X_n
    return cee

# test
xn = 80         # 等高線表示の解像度
w_range = np.array([[0, 15], [-15, 0]])
x0 = np.linspace(w_range[0, 0], w_range[0, 1], xn)
x1 = np.linspace(w_range[1, 0], w_range[1, 1], xn)
xx0, xx1 = np.meshgrid(x0, x1)
C = np.zeros((len(x1), len(x0)))
w = np.zeros(2)

for i0 in range(xn):
    for i1 in range(xn):
        w[0] = x0[i0]
        w[1] = x1[i1]
        C[i1, i0] = cee_logistic(w, X, T)

# 表示
plt.figure(figsize=(12, 5))
plt.subplots_adjust(wspace=0.5)

ax = plt.subplot(1, 2, 1, projection='3d')
ax.plot_surface(xx0, xx1, C, color='blue', edgecolor='black', rstride=10, cstride=10, alpha=0.3)
ax.set_xlabel('$w_0$', fontsize=14)
ax.set_ylabel('$w_1$', fontsize=14)
ax.set_xlim(0, 15)
ax.set_ylim(-15, 0)
ax.set_zlim(0, 8)
ax.view_init(30, -95)

plt.subplot(1, 2, 2)
cont = plt.contour(xx0, xx1, C, 20, colors='black', levels=[0.26, 0.4, 0.8, 1.6, 3.2, 6.4])
cont.clabel(fmt='%1.1f', fontsize=8)
plt.xlabel('$w_0$', fontsize=14)
plt.ylabel('$w_1$', fontsize=14)
plt.grid(True)
plt.show()

image.png

上記の左図は平均交差エントロピー誤差関数になり、右図の等高線の真ん中に最小値がありそうだとわかる。

交差エントロピー誤差が最小になるパラメータに解析解を求めることができない。これは非線形のシグモイド関数を含んでいるためである。

数値解を求めるため勾配法を使い、偏微分する必要がある。
勾配法による解は、scipy.optimizeライブラリに含まれるminimizeを使う。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.optimize import minimize

np.random.seed(seed=0)
X_min = 0
X_max = 2.5
X_n = 30
X_col = ['cornflowerblue', 'gray']

X = np.zeros(X_n)
T = np.zeros(X_n, dtype=np.uint8)

Dist_s = [0.4, 0.8]
Dist_w = [0.8, 1.6]
Pi = 0.5

for n in range(X_n):
    wk = np.random.rand()
    T[n] = 0 * (wk < Pi) + 1 * (wk >= Pi)
    X[n] = np.random.rand() * Dist_w[T[n]] + Dist_s[T[n]]

print('X=' + str(np.round(X, 2)))
print('T=' + str(T))

def show_data1(x, t):
    K = np.max(t) + 1
    for k in range(K):
        plt.plot(x[t == k], t[t == k], X_col[k], alpha=0.5, linestyle='none', marker='o')
    plt.grid(True)
    plt.ylim(-.5, 1.5)
    plt.xlim(X_min, X_max)
    plt.yticks([0, 1])

# ロジスティック回帰モデル
def logistic(x, w):
    y = 1 / (1 + np.exp(-(w[0] * x + w[1])))
    return y

# ロジスティック回帰モデル 表示
def show_logistic(w):
    xb = np.linspace(X_min, X_max, 100)
    y = logistic(xb, w)
    plt.plot(xb, y, color='gray', linewidth=4)
    # 決定境界
    i = np.min(np.where(y > 0.5))
    B = (xb[i - 2] + xb[i]) / 2
    plt.plot([B, B], [-.5, 1.5], color='k', linestyle='-')
    plt.grid(True)
    return B

# 平均交差エントロピー誤差
def cee_logistic(w, x, t):
    y = logistic(x, w)
    cee = 0
    for n in range(len(y)):
        cee = cee - (t[n] * np.log(y[n]) + (1 - t[n]) * np.log(1 - y[n]))
    cee = cee / X_n
    return cee

# 平均交差エントロピー誤差
def dcee_logistic(w, x, t):
    y = logistic(x, w)
    dcee = np.zeros(2)
    for n in range(len(y)):
        dcee[0] = dcee[0] + (y[n] - t[n]) * x[n]
        dcee[1] = dcee[1] + (y[n] - t[n])
    dcee = dcee / X_n
    return dcee

# 勾配法によるパラメータサーチ
def fit_logistic(w_init, x, t):
    res1 = minimize(cee_logistic, w_init, args=(x, t), jac=dcee_logistic, method="CG")
    return res1.x

# test
plt.figure(1, figsize=(3, 3))
W_init = [1, -1]
W = fit_logistic(W_init, X, T)
print("w0 = {0:.2f}, w1 = {1:.2f}".format(W[0], W[1]))
B=show_logistic(W)
show_data1(X, T)
plt.ylim(-.5, 1.5)
plt.xlim(X_min, X_max)
cee = cee_logistic(W, X, T)
print("CEE = {0:.2f}".format(cee))
print("Boundary = {0:.2f} g".format(B))
plt.show()

image.png

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?