# 流行りのDoc2vecを使って文書クラスタリング
みなさん、元気に自然言語処理〜!?
流行りのWord2vecの文書版、Doc2vecでクラスタリングしてみました。ベクトルさえあればクラスタリングは簡単にできます。
なんらかの方法でベクトル、またはベクトルをつくってくれるモデルを取得しましょう。
Doc2vecを行う際はここを参考にすると良いと思います。モデルが簡単に手に入ります。
generate_model.py
#coding: UTF-8
from gensim.models.doc2vec import Doc2Vec
from gensim.models.doc2vec import TaggedDocument
import sys
f = open(sys.argv[1],'r')#空白で単語を区切り、改行で文書を区切っているテキストデータ
#1文書ずつ、単語に分割してリストに入れていく[([単語1,単語2,単語3],文書id),...]こんなイメージ
#words:文書に含まれる単語のリスト(単語の重複あり)
# tags:文書の識別子(リストで指定.1つの文書に複数のタグを付与できる)
trainings = [TaggedDocument(words = data.split(),tags = [i+200]) for i,data in enumerate(f)]
# トレーニング(パラメータについては後日)
m = Doc2Vec(documents= trainings, dm = 1, size=300, window=8, min_count=10, workers=4)
# モデルのセーブ
m.save(sys.argv[2])
モデルが完成したら、あとはモデルから二次元のテンソル(行列)を抽出してリストに格納します。
そしたらk-means実行です。
execute_clustering.py
from gensim.models.doc2vec import Doc2Vec
from gensim.models.doc2vec import TaggedDocument
from sklearn.cluster import KMeans
import sys
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
#モデルを読み込む
#モデルは絶対パスで指定してください
m = Doc2Vec.load('~/hogehoge.model')
#ベクトルをリストに格納
vectors_list=[m.docvecs[n] for n in range(len(m.docvecs))]
#ドキュメント番号のリスト
doc_nums=range(200,200+len(m.docvecs))
#クラスタリング設定
#クラスター数を変えたい場合はn_clustersを変えてください
n_clusters = 8
kmeans_model = KMeans(n_clusters=n_clusters, verbose=1, random_state=1, n_jobs=-1)
#クラスタリング実行
kmeans_model.fit(vectors_list)
#クラスタリングデータにラベル付け
labels=kmeans_model.labels_
#ラベルとドキュメント番号の辞書づくり
cluster_to_docs = defaultdict(list)
for cluster_id, doc_num in zip(labels, doc_nums):
cluster_to_docs[cluster_id].append(doc_num)
#クラスター出力
for docs in cluster_to_docs.values():
print(docs)
#どんなクラスタリングになったか、棒グラフ出力しますよ
import matplotlib.pyplot as plt
#x軸ラベル
x_label_name = []
for i in range(n_clusters):
x_label_name.append("Cluster"+str(i))
#x=left ,y=heightデータ. ここではx=クラスター名、y=クラスター内の文書数
left = range(n_clusters)
height = []
for docs in cluster_to_docs.values():
height.append(len(docs))
print(height,left,x_label_name)
#棒グラフ設定
plt.bar(left,height,color="#FF5B70",tick_label=x_label_name,align="center")
plt.title("Document clusters")
plt.xlabel("cluster name")
plt.ylabel("number of documents")
plt.grid(True)
plt.show()
テキストデータの前処理が済んでいれば、面白い結果が得られるかと思います。(手軽なところだとWikipedia)
でも結局・・・
原理的にWord2vecから発展している感じはないので、そんなにイケてない印象ですた。
自分のツイートが誰のツイートに近いだとか予測する診断サービスなんかに利用できそうですね()