はじめに
トピックモデルとは、文章データの潜在的なトピックを推定する確率モデルです。
本モデルを用いることで、以下のような情報が得られます。
- トピックに基づく文章のクラスタリング
- ある文章を構成するトピックの寄与(document-topic distribution)
- あるトピックを構成する単語の寄与(topic-word distribution)
代表的なトピックモデルとして、Latent Dirichlet Allocation (LDA; 潜在的ディリクレ配分法) [1] が広く知られており、Qiitaにおいても盛んに議論されています。
今回はあえてPachinko Allocation Model (PAM; パチンコ配分法?) [2] というマイナー手法に焦点を当て、実装に取り組みました。
※ 本記事ではPAMアルゴリズムの数学的な解釈について深入りしません。あくまでも実装の紹介という立ち位置です。
ソースコードはGitHubで公開しております(https://github.com/groovy-phazuma/PAM_Cython)。
Pachinko Allocation Model (PAM)ってなんだ?
2003年に発表されたLDAでは、トピックは単語の共起表現を捉える一方で、トピック間の相関を明示的にモデル化しません。そのため、多数の密にまとまったトピックを見出す能力が低いことが懸念されます。これらの課題を踏まえ、2006年ごろには、トピック間の相関を考慮したモデルがいくつか考案されました。今回紹介するPAMもその一種です。(同時期に、CTM [3]という比較的有名な手法も考案されています。)
PAMのグラフィカルモデルは以下のようになります。
2層のトピックから構成されており、上位トピックが下位トピックの相関を表現している点がポイントです。ここでは、$Z_{w2}$が単語wの上位トピックを、$Z_{w3}$が下位トピックを表しています。もう少し簡略化して図にすると、以下のようになります。
パチンコ玉を落とすようにトピックや単語を生成するのが名前の由来です。
PAMの生成過程の近似
原著でも採用されている、Gibbsサンプリングを用いた推論は以下の通りです。
これは、文章dの単語wは、「上位トピック$t_k$が生成される確率」×「上位トピック$t_k$を介して下位トピック$t_p$が生成される確率」×「トピック$t_p$から単語wが生成される確率」で近似できると解釈できます。
なお、詳細は原著を参照してください。NTT コミュニケーション科学基礎研究所が公開している講座も非常に参考になりました[4]。
Pythonでの実装
Pythonによる実装は先人の功績を参考にしました[5]。上記のGibbsサンプリングによる推論は以下のようになります。
def infer_z(self):
for d, words in enumerate(self.bw):
for w, word in enumerate(words):
if word < 0:
pass
else:
# remove the target word
self.D_S_K[d][self.z_s_k[d][w][0]][self.z_s_k[d][w][1]] -= 1
self.K_V[self.z_s_k[d][w][1]][word] -= 1
# calculate p(z = s, k)
probs = np.zeros(self.S*self.K)
for s in range(self.S):
for k in range(self.K):
N_ds = np.sum(self.D_S_K[d][s])
N_dsk = self.D_S_K[d][s][k]
N_k = np.sum(self.K_V[k])
N_kw = self.K_V[k][word]
prob = (N_ds + self.alpha0[s]) \
* ((N_dsk + self.alpha1[s][k]) / (N_ds + np.sum(self.alpha1[s]))) \
* ((N_kw + self.beta) / (N_k + self.beta * self.V))
probs[self.K*s + k] = prob
probs /= np.sum(probs)
# sampling
s_k = np.random.multinomial(1, probs).argmax()
# update count
self.z_s_k[d][w][0] = s_k // self.K
self.z_s_k[d][w][1] = s_k % self.K
self.D_S_K[d][self.z_s_k[d][w][0]][self.z_s_k[d][w][1]] += 1
self.K_V[self.z_s_k[d][w][1]][word] += 1
ここで、全文章の全単語数を$N$, 上位トピック数$S$, 下位トピック数$K$とした場合、計算量は$O(NSK)$となります。そのため、そもそものPAMの開発目的の一つである、「多数の密なトピックを検出する」シナリオにおいて、現実的な実行時間とならない可能性があります。
Cythonでの実装
さて、Cythonを用いて高速化してみましょう。
#cython: language_level=3
#cython: boundscheck=False
#cython: wraparound=False
#cython: cdivision=True
from cython.operator cimport preincrement as inc, predecrement as dec
from libc.stdlib cimport malloc, free
import numpy as np
cdef extern from "gamma.h":
cdef double lda_lgamma(double x) nogil
cdef double lgamma(double x) nogil:
if x <= 0:
with gil:
raise ValueError("x must be strictly positive")
return lda_lgamma(x)
cdef int searchsorted(double* arr, int length, double value) nogil:
"""Bisection search (c.f. numpy.searchsorted)
Find the index into sorted array `arr` of length `length` such that, if
`value` were inserted before the index, the order of `arr` would be
preserved.
"""
cdef int imin, imax, imid
imin = 0
imax = length
while imin < imax:
imid = imin + ((imax - imin) >> 1)
if value > arr[imid]:
imin = imid + 1
else:
imax = imid
return imin
cdef int calc_int_sum(int[:] target, int length) nogil:
cdef int i
cdef int target_sum = 0
for i in range(length):
target_sum += target[i]
return target_sum
cdef double calc_double_sum(double[:] target, int length) nogil:
cdef int i
cdef double target_sum = 0
for i in range(length):
target_sum += target[i]
return target_sum
def _infer_z(int[:, :] bw, int[:, :, :] dsk, int[:, :] kv, int[:, :, :] zsk, double[:] probs,
double[:] alpha0, double[:,:] alpha1, double beta):
cdef int i, j, w, u, l, p, s, k, v, Nds, Nk, Ndsk, Nkw, s_k
cdef int D = bw.shape[0]
cdef int W = bw.shape[1]
cdef int V = kv.shape[1]
cdef int S = dsk.shape[1]
cdef int K = dsk.shape[2]
cdef double prob, prob_sum, alpha_sum
with nogil:
for i in range(D):
for j in range(W):
w = bw[i][j] # value of jth word in ith document
u = zsk[i][j][0] # upper topic idx
l = zsk[i][j][1] # lower topic idx
if w < 0:
pass
else:
dec(dsk[i, u, l])
dec(kv[l, w])
# initialize
for p in range(S*K):
probs[p] = 0
# calculate p(z=s,k)
prob_sum = 0
for s in range(S):
for k in range(K):
Nds = calc_int_sum(dsk[i][s],K)
Nk = calc_int_sum(kv[k],kv.shape[1])
Ndsk = dsk[i][s][k]
Nkw = kv[k][w]
alpha_sum = calc_double_sum(alpha1[s],K)
prob = ((Nds + alpha0[s]) * (Ndsk + alpha1[s][k]) * (Nkw + beta)) / ((Nk + beta * V) * (Nds + alpha_sum))
probs[K*s + k] = prob # update probs
prob_sum += prob
for p in range(S*K):
probs[p] /= prob_sum # normalize
# sampling
with gil:
try:
s_k = np.random.multinomial(1, probs).argmax()
except:
print(probs)
pass
zsk[i][j][0] = s_k / K
zsk[i][j][1] = s_k % K
inc(dsk[i, zsk[i][j][0], zsk[i][j][1]])
inc(kv[zsk[i][j][1], w])
サンプリング箇所のs_k = np.random.multinomial(1, probs).argmax()
では、numpyを用いている点に留意してください。こちらはGlobal Interpreter Lock (GIL)による制約を受けるため処理速度が落ちます。他にもいくつか高速化の余地はありそうなので、余力があれば改良していきます。
実行時間の比較(Cython vs Python)
それでは実行時間を比較してみます。ベンチマークデータとして、4258種類の単語から成る395の文章を使用しました。上位トピック数$S=5$, 下位トピック数$K=5$の条件下での実験結果は以下の通りです。
PAM_Cython | PAM_Python |
---|---|
8.90 sec | 82.86 sec |
Cythonによる実装で実行速度が9倍程度に増加していることが分かります。大規模なデータを入力とした際にも、現実的な時間で推論が可能になりそうですね。
おわりに
今回はPachinko Allocation Mode (PAM) をCythonで実装することで高速化を実現しました。コードはGitHubで公開しています(https://github.com/groovy-phazuma/PAM_Cython)。
PAMの後続研究はそこまで活気があるわけでもないため、どこまで需要があるのかは不明ですが、本記事がPAMの活用を試みる方のお役に立てれば幸いです。
References
[1] David M. Blei, Andrew Y. Ng, and Michael I. Jordan. 2003. Latent dirichlet allocation. J. Mach. Learn. Res. 3, null (3/1/2003), 993–1022.
[2] Wei Li and Andrew McCallum. 2006. Pachinko allocation: DAG-structured mixture models of topic correlations. In Proceedings of the 23rd international conference on Machine learning (ICML '06). Association for Computing Machinery, New York, NY, USA, 577–584. https://doi.org/10.1145/1143844.1143917
[3] David M. Blei and John D. Lafferty. 2005. Correlated topic models. In Proceedings of the 18th International Conference on Neural Information Processing Systems (NIPS'05). MIT Press, Cambridge, MA, USA, 147–154.
[4] https://www.ism.ac.jp/~daichi/lectures/H24-TopicModel/ISM-2012-TopicModels_day2_2_structured.pdf
[5] https://github.com/kenchin110100/machine_learning/blob/master/samplePAM.py