前回マハラノビス距離を使った分類をしたのですが
実際どのような形で分類されているのか今回は機械学習っぽく実装してみました。
ライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import distance
from sklearn.datasets import make_blobs
データの作成
x_train, y_train = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)
マハラノビス分類器
class MahalanobisClassifier:
def __init__(self):
self.x_train = None
self.y_train = None
self.cov_i = None
def fit(self, x_train, y_train):
self.x_train = x_train
self.y_train = y_train
cov = np.cov(x_train.T)
self.cov_i = np.linalg.inv(cov)
def predict(self, x_test):
predictions = []
for i in range(len(x_test)):
tmp_mrv = []
for j in range(len(self.x_train)):
tmp_mrv.append([self.y_train[j], distance.mahalanobis(x_test[i], self.x_train[j], self.cov_i)])
tmp_mrv = sorted(tmp_mrv, key=lambda x: x[1])
predictions.append(tmp_mrv[0][0]) # 最も近い点のクラスを返す
return np.array(predictions)
描画関数
def visualize_classifier(model, X, y, ax=None, cmap='brg'):
ax = ax or plt.gca()
# トレーニングデータのプロット
ax.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=cmap,
clim=(y.min(), y.max()), zorder=3)
ax.axis('tight')
ax.axis('off')
xlim = ax.get_xlim()
ylim = ax.get_ylim()
# 決定境界の計算
xx, yy = np.meshgrid(np.linspace(*xlim, num=200),
np.linspace(*ylim, num=200))
grid_points = np.c_[xx.ravel(), yy.ravel()]
Z = model.predict(grid_points).reshape(xx.shape)
# 結果を色分けして描画
n_classes = len(np.unique(y))
ax.contourf(xx, yy, Z, alpha=0.3,
levels=np.arange(n_classes + 1) - 0.5,
cmap=cmap, clim=(y.min(), y.max()),
zorder=1)
ax.set(xlim=xlim, ylim=ylim)
では描画してみましょう
model = MahalanobisClassifier()
model.fit(x_train, y_train)
plt.figure(figsize=(8, 6))
visualize_classifier(model, x_train, y_train)
plt.show()
これだけ見るとかなり柔軟なモデルになっていると思いますが、まあでも過学習も心配ですね。
ちなみにですが今回はChatGPTを使って作りました。