0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

クラスタリング結果をビジュアライゼーション/評価する関数を作成

Posted at

クラスタリングした結果をビジュアライゼーションし、評価する

vaeなどでクラスタリングした結果をビジュアライゼーションし、評価値を表示する関数を実装した。

正解ラベルとクラスタリング結果であるクラスタ番号を多数決でリラベルし、
擬似的なconfusion matrixを描画、accuracyを計算する。
また、NMIとARIによる評価値も表示する。
どれだけうまくクラスタリングしたかを評価できる関数を作ったつもり。

# 必要なライブラリのインポート
import numpy as np
import pandas as pd
import sklearn
# Jupyter notebookを利用している際に、notebook内にplot結果を表示するようにする
import matplotlib.pyplot as plt
%matplotlib inline
df_result_dense = pd.read_csv('result-dense.csv')
df_result_dense
Unnamed: 0 labels k-means
0 0 7 2
1 1 2 5
2 2 1 9
3 3 0 3
4 4 4 7
... ... ... ...
9995 9995 2 5
9996 9996 3 0
9997 9997 4 7
9998 9998 5 4
9999 9999 6 6

10000 rows × 3 columns

def relabel(ans, labels):
    df = pd.DataFrame()
    df['ans'] = ans
    df['labels'] = labels
    relabel(df, 'ans', 'labels')

def relabel(df, ans, label):
    # ansに一番近いラベリングをし直す
    # df[ans]に正答、df[labels]にクラスタラベルが入っていることを期待している
    labels = df[label].unique()
    label_dic = {}
    for i in labels:
        counts = df[df[label] == i][ans].value_counts()
        label_dic[i] = counts.index[0]
    display(label_dic)
    return list(pd.Series(df[label]).replace(label_dic))
relabel_k_means = relabel(df_result_dense, 'labels', 'k-means')
df_result_dense['relabel_k_means'] = relabel_k_means
{2: 7, 5: 2, 9: 1, 3: 0, 7: 4, 1: 9, 4: 5, 8: 8, 6: 6, 0: 3}
from sklearn.metrics import accuracy_score
print(accuracy_score(df_result_dense['labels'],df_result_dense['k-means']))
print(accuracy_score(df_result_dense['labels'],df_result_dense['relabel_k_means']))
0.1841
0.9309
ans = df_result_dense['labels']
labels = df_result_dense['k-means']
relabels = df_result_dense['relabel_k_means']
def eval_cluster(ans, labels, relabels):
    import seaborn as sns
    from sklearn.metrics import confusion_matrix

    plt.title('no-relabel')
    sns.heatmap(confusion_matrix(ans, labels), annot=True, fmt='d')
    plt.show()

    from sklearn.metrics import normalized_mutual_info_score
    print("nmi: " + str(normalized_mutual_info_score(ans, labels)))
    from sklearn.metrics.cluster import adjusted_rand_score
    print("ari: " + str(adjusted_rand_score(ans, labels)))

    plt.title('relabel')
    sns.heatmap(confusion_matrix(ans, relabels), annot=True, fmt='d')
    plt.show()
    print("acc: " + str(accuracy_score(ans, relabels)))

eval_cluster(ans, labels, relabels)

output_7_0.png

nmi: 0.8804532777228216
ari: 0.8405114317316403

output_7_2.png

acc: 0.9309

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?