LoginSignup
5
2

More than 5 years have passed since last update.

混合ガウス分布 vs k-means法 (塩尻MLもくもく会#4)

Last updated at Posted at 2018-06-17

クラスタリングの代表的な手法である
「混合ガウス分布」と「k-means法」を対決させてみました。

発表資料はこちら↓
https://docs.google.com/presentation/d/e/2PACX-1vRIAVFIFXnMyefy8cVjwI2FueU5NnCtJNvi4-SB61ynR6i9-C6IlgEiEz6X07OzxslsxeqmoftdrJhN/pub?start=false&loop=false&delayms=3000

一応コードも載せます。ぐちゃぐちゃですみません。

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.mixture
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import StandardScaler

#Draw a confusion matrix
def plotmatrix(matrix,log=False):
    """Show a square matrix as a heat map with each cell's value printed on it.

    Args:
        matrix: square 2-D array-like; rows are drawn on the y axis
            ("True label") and columns on the x axis ("Predict label").
        log: when truthy, color the cells on a logarithmic scale.

    Opens a new figure and blocks in ``plt.show()``; returns nothing.
    """
    matrix = np.asarray(matrix)  # allow nested lists; matrix[i, j] needs an ndarray
    n = len(matrix)
    ticks = np.arange(n)

    plt.figure()
    # Configure ticks on the Axes itself: passing keyword arguments to
    # plt.gca() is no longer supported by current Matplotlib releases.
    ax = plt.gca()
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)
    plt.xlabel(u'Predict label')
    plt.ylabel(u'True label')

    # Write every cell value at its (column, row) position.
    for i in range(n):
        for j in range(n):
            plt.text(j,i,matrix[i,j],ha='center',va='center',size=14)

    # norm=None keeps imshow's default linear color scaling.
    norm = colors.LogNorm() if log else None
    plt.imshow(matrix,cmap='Wistia',interpolation='nearest',norm=norm)
    plt.colorbar(pad=0.01)
    plt.show()

#Gaussian mixture model
def gmm(x,cluster_num,random,true):
    """Cluster ``x`` with a Gaussian mixture and compare against ``true``.

    Args:
        x: samples to fit and then label.
        cluster_num: number of mixture components.
        random: seed passed as ``random_state``.
        true: reference labels, one per sample in ``x``.

    Returns:
        The result of ``confusion_matrix(true, predicted_component)``.
        Component indices are arbitrary, so the columns generally need to
        be reordered before the diagonal is meaningful.
    """
    # sklearn.mixture.GMM no longer exists in current scikit-learn;
    # GaussianMixture is its replacement (component count is n_components).
    model = sklearn.mixture.GaussianMixture(
        n_components=cluster_num, covariance_type='full', random_state=random)
    model.fit(x)
    gmm_label = model.predict(x)
    return confusion_matrix(true,gmm_label)

#k-means
def kmeans(x,cluster_num,random,true):
    """Cluster ``x`` with k-means and compare against ``true``.

    Args:
        x: samples to fit and then label.
        cluster_num: number of clusters.
        random: seed passed as ``random_state``.
        true: reference labels, one per sample in ``x``.

    Returns:
        The result of ``confusion_matrix(true, predicted_cluster)``.
    """
    estimator = KMeans(cluster_num, random_state=random).fit(x)
    predicted = estimator.predict(x)
    return confusion_matrix(true, predicted)

#Load the iris data
iris = load_iris()
true_label = np.array(iris.target)
#columns: sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
data = pd.DataFrame(data=iris.data, columns=iris.feature_names)  # 150 samples
X = np.array(data)

#Training data: standardized copy of X
sc = StandardScaler()
x_train = sc.fit_transform(X)

#First, a quick look at the iris data colored by the true label
species_colors = ['orange', 'm', 'c']
point_colors = [species_colors[label] for label in true_label]

fig, ax = plt.subplots()
ax.set_title("True label")
# Columns 2 and 3 of the standardized features are plotted against each other.
ax.scatter(x_train[:,2], x_train[:,3], c=point_colors)
ax.set_xlabel("petal length")
ax.set_ylabel("petal width")
plt.show()

#k-means: run with 3 seeds and average the confusion matrices
aligned = []
for seed in (0, 123, 512):
    conma = kmeans(x=x_train,cluster_num=3,random=seed, true=true_label)
    # Cluster ids are arbitrary. Reorder the columns with the assignment that
    # maximizes the diagonal rather than a hand-picked permutation that only
    # fits the numbering produced by one particular run.
    _, col_order = linear_sum_assignment(-conma)
    aligned.append(conma[:, col_order])
conma_mean = np.mean(aligned, axis=0)
# Round only for display; truncating to int would silently drop counts.
conma = np.rint(conma_mean).astype(int)
plotmatrix(conma)
# Accuracy comes from the un-rounded mean so the total stays the sample count.
temp = np.trace(conma_mean)
temp2 = np.sum(conma_mean)
print("k_means iris accuracy",temp/temp2)

#gmm: run with 3 seeds and average the confusion matrices
aligned = []
for seed in (0, 123, 512):
    conma = gmm(x=x_train,cluster_num=3,random=seed, true=true_label)
    # Component ids are arbitrary. Reorder the columns with the assignment
    # that maximizes the diagonal rather than a hand-picked permutation that
    # only fits the numbering produced by one particular run.
    _, col_order = linear_sum_assignment(-conma)
    aligned.append(conma[:, col_order])
conma_mean = np.mean(aligned, axis=0)
# Round only for display; truncating to int would silently drop counts.
conma = np.rint(conma_mean).astype(int)
plotmatrix(conma)
# Accuracy comes from the un-rounded mean so the total stays the sample count.
temp = np.trace(conma_mean)
temp2 = np.sum(conma_mean)
print("gmm iris accuracy",temp/temp2)

#Load the digits data
digits = load_digits(n_class=10)
true_label = digits.target  # array([0, 1, 2, ..., 8, 9, 8])
data = pd.DataFrame(data=digits.data)
X = np.array(data)

#Training data: standardized copy of X
sc = StandardScaler()
x_train = sc.fit_transform(X)

#Project onto the first two principal components and plot
pca = PCA(n_components=2)
transformed = pca.fit_transform(x_train)

fig, ax = plt.subplots()
ax.set_title("True label")
# Points are colored by the numeric true label; the colorbar maps color to label.
points = ax.scatter(transformed[:,0], transformed[:,1], c=true_label)
ax.set_xlabel("pc1")
ax.set_ylabel("pc2")
fig.colorbar(points, ax=ax)
plt.show()

#k-means: run with 3 seeds and average the confusion matrices
aligned = []
for seed in (0, 123, 512):
    conma = kmeans(x=x_train,cluster_num=10,random=seed, true=true_label)
    # Cluster ids are arbitrary. Reorder the columns with the assignment that
    # maximizes the diagonal rather than a hand-picked permutation that only
    # fits the numbering produced by one particular run.
    _, col_order = linear_sum_assignment(-conma)
    aligned.append(conma[:, col_order])
conma_mean = np.mean(aligned, axis=0)
# Round only for display; truncating to int would silently drop counts.
conma = np.rint(conma_mean).astype(int)
plotmatrix(conma)
# Accuracy comes from the un-rounded mean so the total stays the sample count.
temp = np.trace(conma_mean)
temp2 = np.sum(conma_mean)
print("kmeans digits accuracy",temp/temp2)

#gmm: run with 3 seeds and average the confusion matrices
aligned = []
for seed in (0, 123, 512):
    conma = gmm(x=x_train,cluster_num=10,random=seed, true=true_label)
    # Component ids are arbitrary. Reorder the columns with the assignment
    # that maximizes the diagonal rather than a hand-picked permutation that
    # only fits the numbering produced by one particular run.
    _, col_order = linear_sum_assignment(-conma)
    aligned.append(conma[:, col_order])
conma_mean = np.mean(aligned, axis=0)
# Round only for display; truncating to int would silently drop counts.
conma = np.rint(conma_mean).astype(int)
plotmatrix(conma)
# Accuracy comes from the un-rounded mean so the total stays the sample count.
temp = np.trace(conma_mean)
temp2 = np.sum(conma_mean)
print("gmm digits accuracy",temp/temp2)
5
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2