1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Pachinko Allocation Model (PAM) をCythonで高速化した

Last updated at Posted at 2023-11-17

はじめに

トピックモデルとは、文章データの潜在的なトピックを推定する確率モデルです。
本モデルを用いることで、以下のような情報が得られます。

  • トピックに基づく文章のクラスタリング
  • ある文章を構成するトピックの寄与(document-topic distribution)
  • あるトピックを構成する単語の寄与(topic-word distribution)

代表的なトピックモデルとして、Latent Dirichlet Allocation (LDA; 潜在的ディリクレ配分法) [1] が広く知られており、Qiitaにおいても盛んに議論されています。

今回はあえてPachinko Allocation Model (PAM; パチンコ配分法?) [2] というマイナー手法に焦点を当て、実装に取り組みました。

※ 本記事ではPAMアルゴリズムの数学的な解釈について深入りしません。あくまでも実装の紹介という立ち位置です。

ソースコードはGitHubで公開しております(https://github.com/groovy-phazuma/PAM_Cython)。

Pachinko Allocation Model (PAM)ってなんだ?

2003年に発表されたLDAでは、トピックは単語の共起表現を捉える一方で、トピック間の相関を明示的にモデル化しません。そのため、多数の密にまとまったトピックを見出す能力が低いことが懸念されます。これらの課題を踏まえ、2006年ごろには、トピック間の相関を考慮したモデルがいくつか考案されました。今回紹介するPAMもその一種です。(同時期に、CTM [3]という比較的有名な手法も考案されています。)

PAMのグラフィカルモデルは以下のようになります。

image.png
             文献[2] Fig.2より抜粋

2層のトピックから構成されており、上位トピックが下位トピックの相関を表現している点がポイントです。ここでは、$Z_{w2}$が単語wの上位トピックを、$Z_{w3}$が下位トピックを表しています。もう少し簡略化して図にすると、以下のようになります。

image.png
  文献[2] Fig.1より抜粋

パチンコ玉を落とすようにトピックや単語を生成するのが名前の由来です。

PAMの生成過程の近似

原著でも採用されている、Gibbsサンプリングを用いた推論は以下の通りです。

image.png
              文献[2]より抜粋

これは、文章dの単語wは、「上位トピック$t_k$が生成される確率」×「上位トピック$t_k$を介して下位トピック$t_p$が生成される確率」×「トピック$t_p$から単語wが生成される確率」で近似できると解釈できます。

なお、詳細は原著を参照してください。NTT コミュニケーション科学基礎研究所が公開している講座も非常に参考になりました[4]。

Pythonでの実装

Pythonによる実装は先人の功績を参考にしました[5]。上記のGibbsサンプリングによる推論は以下のようになります。

def infer_z(self):
    for d, words in enumerate(self.bw):
        for w, word in enumerate(words):
            if word < 0:
                pass
            else:
                # remove the target word
                self.D_S_K[d][self.z_s_k[d][w][0]][self.z_s_k[d][w][1]] -= 1
                self.K_V[self.z_s_k[d][w][1]][word] -= 1
                # calculate p(z = s, k)
                probs = np.zeros(self.S*self.K)
                for s in range(self.S):
                    for k in range(self.K):
                        N_ds = np.sum(self.D_S_K[d][s])
                        N_dsk = self.D_S_K[d][s][k]
                        N_k = np.sum(self.K_V[k])
                        N_kw = self.K_V[k][word]
                        prob = (N_ds + self.alpha0[s]) \
                            * ((N_dsk + self.alpha1[s][k]) / (N_ds + np.sum(self.alpha1[s]))) \
                            * ((N_kw + self.beta) / (N_k + self.beta * self.V))
                        probs[self.K*s + k] = prob
            probs /= np.sum(probs)
            # sampling
            s_k = np.random.multinomial(1, probs).argmax()
            # update count
            self.z_s_k[d][w][0] = s_k // self.K
            self.z_s_k[d][w][1] = s_k % self.K
            self.D_S_K[d][self.z_s_k[d][w][0]][self.z_s_k[d][w][1]] += 1
            self.K_V[self.z_s_k[d][w][1]][word] += 1

ここで、全文章の全単語数を$N$, 上位トピック数$S$, 下位トピック数$K$とした場合、計算量は$O(NSK)$となります。そのため、そもそものPAMの開発目的の一つである、「多数の密なトピックを検出する」シナリオにおいて、現実的な実行時間とならない可能性があります。

Cythonでの実装

さて、Cythonを用いて高速化してみましょう。

#cython: language_level=3
#cython: boundscheck=False
#cython: wraparound=False
#cython: cdivision=True

from cython.operator cimport preincrement as inc, predecrement as dec
from libc.stdlib cimport malloc, free
import numpy as np

cdef extern from "gamma.h":
    cdef double lda_lgamma(double x) nogil


cdef double lgamma(double x) nogil:
    if x <= 0:
        with gil:
            raise ValueError("x must be strictly positive")
    return lda_lgamma(x)


cdef int searchsorted(double* arr, int length, double value) nogil:
    """Bisection search (c.f. numpy.searchsorted)

    Find the index into sorted array `arr` of length `length` such that, if
    `value` were inserted before the index, the order of `arr` would be
    preserved.
    """
    cdef int imin, imax, imid
    imin = 0
    imax = length
    while imin < imax:
        imid = imin + ((imax - imin) >> 1)
        if value > arr[imid]:
            imin = imid + 1
        else:
            imax = imid
    return imin
        
cdef int calc_int_sum(int[:] target, int length) nogil:
    cdef int i
    cdef int target_sum = 0
    
    for i in range(length):
        target_sum += target[i]
    return target_sum
    
cdef double calc_double_sum(double[:] target, int length) nogil:
    cdef int i
    cdef double target_sum = 0
    
    for i in range(length):
        target_sum += target[i]
    return target_sum
        
def _infer_z(int[:, :] bw, int[:, :, :] dsk, int[:, :] kv, int[:, :, :] zsk, double[:] probs,
             double[:] alpha0, double[:,:] alpha1, double beta):
    cdef int i, j, w, u, l, p, s, k, v, Nds, Nk, Ndsk, Nkw, s_k
    cdef int D = bw.shape[0]
    cdef int W = bw.shape[1]
    cdef int V = kv.shape[1]
    cdef int S = dsk.shape[1]
    cdef int K = dsk.shape[2]
    cdef double prob, prob_sum, alpha_sum
    
    with nogil:
        for i in range(D):
            for j in range(W):
                w = bw[i][j] # value of jth word in ith document
                u = zsk[i][j][0] # upper topic idx
                l = zsk[i][j][1] # lower topic idx
                if w < 0:
                    pass
                else:
                    dec(dsk[i, u, l])
                    dec(kv[l, w])
                    # initialize
                    for p in range(S*K):
                        probs[p] = 0
                    # calculate p(z=s,k)
                    prob_sum = 0
                    for s in range(S):
                        for k in range(K):
                            Nds = calc_int_sum(dsk[i][s],K)
                            Nk = calc_int_sum(kv[k],kv.shape[1])
                            Ndsk = dsk[i][s][k]
                            Nkw = kv[k][w]
                            alpha_sum = calc_double_sum(alpha1[s],K)
                            prob = ((Nds + alpha0[s]) * (Ndsk + alpha1[s][k]) * (Nkw + beta)) / ((Nk + beta * V) * (Nds + alpha_sum))

                            probs[K*s + k] = prob # update probs
                            prob_sum += prob
                    
                    for p in range(S*K):
                        probs[p] /= prob_sum # normalize
                    # sampling
                    with gil:
                        try:
                            s_k = np.random.multinomial(1, probs).argmax()
                        except:
                            print(probs)
                            pass
                    zsk[i][j][0] = s_k / K
                    zsk[i][j][1] = s_k % K
                    inc(dsk[i, zsk[i][j][0], zsk[i][j][1]])
                    inc(kv[zsk[i][j][1], w])

サンプリング箇所のs_k = np.random.multinomial(1, probs).argmax()では、numpyを用いている点に留意してください。こちらはGlobal Interpreter Lock (GIL)による制約を受けるため処理速度が落ちます。他にもいくつか高速化の余地はありそうなので、余力があれば改良していきます。

実行時間の比較(Cython vs Python)

それでは実行時間を比較してみます。ベンチマークデータとして、4258種類の単語から成る395の文章を使用しました。上位トピック数$S=5$, 下位トピック数$K=5$の条件下での実験結果は以下の通りです。

PAM_Cython PAM_Python
8.90 sec 82.86 sec

Cythonによる実装で実行速度が9倍程度に増加していることが分かります。大規模なデータを入力とした際にも、現実的な時間で推論が可能になりそうですね。

おわりに

今回はPachinko Allocation Mode (PAM) をCythonで実装することで高速化を実現しました。コードはGitHubで公開しています(https://github.com/groovy-phazuma/PAM_Cython)。

PAMの後続研究はそこまで活気があるわけでもないため、どこまで需要があるのかは不明ですが、本記事がPAMの活用を試みる方のお役に立てれば幸いです。

References

[1] David M. Blei, Andrew Y. Ng, and Michael I. Jordan. 2003. Latent dirichlet allocation. J. Mach. Learn. Res. 3, null (3/1/2003), 993–1022.
[2] Wei Li and Andrew McCallum. 2006. Pachinko allocation: DAG-structured mixture models of topic correlations. In Proceedings of the 23rd international conference on Machine learning (ICML '06). Association for Computing Machinery, New York, NY, USA, 577–584. https://doi.org/10.1145/1143844.1143917
[3] David M. Blei and John D. Lafferty. 2005. Correlated topic models. In Proceedings of the 18th International Conference on Neural Information Processing Systems (NIPS'05). MIT Press, Cambridge, MA, USA, 147–154.
[4] https://www.ism.ac.jp/~daichi/lectures/H24-TopicModel/ISM-2012-TopicModels_day2_2_structured.pdf
[5] https://github.com/kenchin110100/machine_learning/blob/master/samplePAM.py

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?