前回クラスを作ってマハラノビス分類をできるようにしましたので今回はScikit-Learn風にマハラノビスkNNに改良してみました。
マハラノビスk-NN
from collections import Counter
import numpy as np
import pandas as pd
from scipy.spatial import distance
class MahalanobisKNN:
def __init__(self, k=3):
self.k = k
self.x_train = None
self.y_train = None
self.cov_i = None
def fit(self, x_train, y_train):
if isinstance(x_train, pd.DataFrame):
x_train = x_train.values
if isinstance(y_train, pd.Series):
y_train = y_train.values
self.x_train = x_train
self.y_train = y_train
cov = np.cov(x_train.T)
self.cov_i = np.linalg.inv(cov)
def predict(self, x_test):
if isinstance(x_test, pd.DataFrame):
x_test = x_test.values
predictions = []
for i in range(len(x_test)):
distances = []
for j in range(len(self.x_train)):
d = distance.mahalanobis(x_test[i], self.x_train[j], self.cov_i)
distances.append((self.y_train[j], d))
distances = sorted(distances, key=lambda x: x[1])[:self.k]
class_counts = Counter([label for label, _ in distances])
predictions.append(class_counts.most_common(1)[0][0])
return np.array(predictions)
で思いついたのが品質の教科書だったのでそのデータを使ってみます。
ロットをインデックスにして判定を目的変数にします。
(データが1レコード見切れています)
df.index = df["ロット"]
df = df.drop("ロット", axis=1)
df = df.replace("不適合", 0)
df = df.replace("適合", 1)
x = df.drop("判定", axis=1)
y = df["判定"]
hold-outで精度を測定します。
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1, test_size=0.2)
マハラノビスkNN
model = MahalanobisKNN(k=3)
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
print(classification_report(y_test, y_pred))
precision recall f1-score support
0 0.67 1.00 0.80 2
1 1.00 0.50 0.67 2
accuracy 0.75 4
macro avg 0.83 0.75 0.73 4
weighted avg 0.83 0.75 0.73 4
比較対象をLightGBMにしてみます。
from lightgbm import LGBMClassifier as LGBMC
model = LGBMC()
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
print(classification_report(y_test, y_pred))
precision recall f1-score support
0 0.50 1.00 0.67 2
1 0.00 0.00 0.00 2
accuracy 0.50 4
macro avg 0.25 0.50 0.33 4
weighted avg 0.25 0.50 0.33 4
意外と精度で勝った!
では今度はオープンデータでやってみます。
Wineデータではこうなりました。
LightGBM
precision recall f1-score support
1 1.00 0.93 0.96 14
2 0.93 1.00 0.96 13
3 1.00 1.00 1.00 9
accuracy 0.97 36
macro avg 0.98 0.98 0.98 36
weighted avg 0.97 0.97 0.97 36
マハラノビスkNN
precision recall f1-score support
1 0.93 1.00 0.97 14
2 0.92 0.92 0.92 13
3 0.88 0.78 0.82 9
accuracy 0.92 36
macro avg 0.91 0.90 0.90 36
weighted avg 0.92 0.92 0.91 36
オープンデータでは精度で負けました。