1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

前回マハラノビス距離を使った分類をしたのですが

実際どのような形で分類されているのか今回は機械学習っぽく実装してみました。

ライブラリのインポート

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import distance
from sklearn.datasets import make_blobs

データの作成

x_train, y_train = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)

マハラノビス分類器

class MahalanobisClassifier:
    def __init__(self):
        self.x_train = None
        self.y_train = None
        self.cov_i = None
    
    def fit(self, x_train, y_train):
        self.x_train = x_train
        self.y_train = y_train
        cov = np.cov(x_train.T)
        self.cov_i = np.linalg.inv(cov)
    
    def predict(self, x_test):
        predictions = []
        for i in range(len(x_test)):
            tmp_mrv = []
            for j in range(len(self.x_train)):
                tmp_mrv.append([self.y_train[j], distance.mahalanobis(x_test[i], self.x_train[j], self.cov_i)])
            tmp_mrv = sorted(tmp_mrv, key=lambda x: x[1])
            predictions.append(tmp_mrv[0][0])  # 最も近い点のクラスを返す
        return np.array(predictions)

描画関数

def visualize_classifier(model, X, y, ax=None, cmap='brg'):
    ax = ax or plt.gca()
    
    # トレーニングデータのプロット
    ax.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=cmap,
               clim=(y.min(), y.max()), zorder=3)
    ax.axis('tight')
    ax.axis('off')
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    
    # 決定境界の計算
    xx, yy = np.meshgrid(np.linspace(*xlim, num=200),
                         np.linspace(*ylim, num=200))
    grid_points = np.c_[xx.ravel(), yy.ravel()]
    Z = model.predict(grid_points).reshape(xx.shape)
    
    # 結果を色分けして描画
    n_classes = len(np.unique(y))
    ax.contourf(xx, yy, Z, alpha=0.3,
                levels=np.arange(n_classes + 1) - 0.5,
                cmap=cmap, clim=(y.min(), y.max()),
                zorder=1)
    ax.set(xlim=xlim, ylim=ylim)

では描画してみましょう

model = MahalanobisClassifier()
model.fit(x_train, y_train)
plt.figure(figsize=(8, 6))
visualize_classifier(model, x_train, y_train)
plt.show()

image.png

これだけ見るとかなり柔軟なモデルになっていると思いますが、まあでも過学習も心配ですね。

ちなみにですが今回はChatGPTを使って作りました。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?