3
1

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Python - 二つのデータセットからマハラノビス距離が等しい曲線を描く

Posted at

ソース

MahalanobisDistance.py
import numpy as np
import matplotlib.pyplot as plt
from sympy import *

# Offsets of the two synthetic point clouds. The __main__ block adds a
# non-negative uniform sample (np.random.rand, range [0, 1)) scaled by a
# constant to each, so these points act as the lower-left corner of each
# cloud rather than its centre.
CENTER1 = np.array([2.0, 3.0])
CENTER2 = np.array([-4.0, -1.0])

def calc_sum_of_square(data_list1, data_list2):
    """Return the unbiased sample covariance of two paired sequences.

    Despite its name, the result is already normalised: it evaluates
    (n * sum(x*y) - sum(x) * sum(y)) / (n * (n - 1)) with n = len(data_list1).
    Pairs are formed with zip, so the sequences are expected to have equal
    length. Fewer than two observations divide by zero.
    """
    count = len(data_list1)
    cross_total = sum([lhs * rhs for lhs, rhs in zip(data_list1, data_list2)])
    numerator = count * cross_total - sum(data_list1) * sum(data_list2)
    return numerator / (count * (count - 1))

# Variance-covariance matrix computation
def calc_variance_covariance_matrix(data_list):
    """Return the unbiased sample variance-covariance matrix of ``data_list``.

    Args:
        data_list: sequence of equal-length numeric sequences, one per
            variable (e.g. ``[X, Y]``).

    Returns:
        list[list[float]]: square matrix whose (i, j) entry is the sample
        covariance (denominator n - 1) between variables i and j.

    Raises:
        ValueError: if there are fewer than two observations per variable,
            since the n - 1 denominator would then be zero.
    """
    matrix = np.asarray(data_list, dtype=float)
    if matrix.ndim != 2 or matrix.shape[1] < 2:
        raise ValueError("need at least two observations per variable")
    # np.cov treats each row as one variable and divides by n - 1 by default.
    # The unbiased covariance must be divided by (n - 1) exactly once.
    # atleast_2d keeps the single-variable case as a 1x1 matrix.
    return np.atleast_2d(np.cov(matrix)).tolist()

def calc_average(data):
    """Return the arithmetic mean of ``data``.

    For an empty sequence this divides by zero.
    """
    total, count = sum(data), len(data)
    return total / count
            
def calc_mahalanobis_coefficients(vc, x_ave, y_ave):
    """Expand a squared Mahalanobis distance into polynomial coefficients.

    For a symmetric 2x2 covariance matrix ``vc`` = [[s11, s12], [s12, s22]]
    and mean (x_ave, y_ave), the squared distance of a point (x, y) is

        D^2 = c11*x^2 + c22*y^2 + c12*x*y + c1*x + c2*y + c0

    Args:
        vc: 2x2 indexable covariance matrix (vc[0][1] is assumed == vc[1][0]).
        x_ave: mean of the x coordinates.
        y_ave: mean of the y coordinates.

    Returns:
        tuple: (c11, c22, c12, c1, c2, c0)

    Raises:
        ValueError: if the covariance matrix is singular.
    """
    s11, s12, s22 = vc[0][0], vc[0][1], vc[1][1]
    det = s11 * s22 - s12 ** 2
    if det == 0:
        raise ValueError("covariance matrix is singular")
    # Inverse of [[s11, s12], [s12, s22]] is [[s22, -s12], [-s12, s11]] / det.
    # The coefficients come from expanding (p - mean)^T inv (p - mean).
    c11 = s22 / det
    c22 = s11 / det
    c12 = -2 * s12 / det
    c1 = 2 * (s12 * y_ave - s22 * x_ave) / det
    c2 = 2 * (s12 * x_ave - s11 * y_ave) / det
    # The cross term of the constant uses the covariance s12, not s11.
    # This makes D^2 == 0 exactly at the mean.
    c0 = (s22 * x_ave ** 2 + s11 * y_ave ** 2 - 2 * s12 * x_ave * y_ave) / det
    return c11, c22, c12, c1, c2, c0


if __name__ == "__main__":
    # 50 random points uniformly distributed in the square
    # [CENTER1, CENTER1 + 3.0). np.random.rand samples [0, 1), so CENTER1 is
    # the lower-left corner of the cloud, not its centre.
    data_set1 = CENTER1 + 3.0 * np.random.rand(50, 2)
    # 50 random points uniformly distributed in [CENTER2, CENTER2 + 4.7).
    data_set2 = CENTER2 + 4.7 * np.random.rand(50, 2)

    # Split each (50, 2) array into plain lists of x and y coordinates.
    X1, Y1 = data_set1[:, 0].tolist(), data_set1[:, 1].tolist()
    X2, Y2 = data_set2[:, 0].tolist(), data_set2[:, 1].tolist()

    # Variance-covariance matrices of (x, y) for each data set.
    vc1 = np.array(calc_variance_covariance_matrix([X1, Y1]))
    vc2 = np.array(calc_variance_covariance_matrix([X2, Y2]))

    # data_set1: D1^2 = a11*x^2 + a22*y^2 + a12*x*y + a1*x + a2*y + a0
    a11, a22, a12, a1, a2, a0 = calc_mahalanobis_coefficients(
        vc1, calc_average(X1), calc_average(Y1))
    # data_set2: D2^2 = b11*x^2 + b22*y^2 + b12*x*y + b1*x + b2*y + b0
    b11, b22, b12, b1, b2, b0 = calc_mahalanobis_coefficients(
        vc2, calc_average(X2), calc_average(Y2))

    # Draw the curve where D1^2 == D2^2, i.e. where the difference is zero.
    x, y = symbols("x y")
    f = ((a11 - b11) * x**2 + (a22 - b22) * y**2 + (a12 - b12) * x * y
         + (a1 - b1) * x + (a2 - b2) * y + a0 - b0)
    plot_implicit(f, (x, -6, 6), (y, -7, 7))

    # Scatter plot of both data sets (opens in a separate figure).
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(X1, Y1, c='red')
    ax.scatter(X2, Y2, c='blue')

    plt.show()

結果

Mahalanobis-2.png

Mahalanobis-1.png

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do by signing up
3
1

Delete article

Deleted articles cannot be recovered.

The draft of this article will also be deleted.

Are you sure you want to delete this article?