44
41

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

PythonでPLSAを実装してみた

Posted at

概要

  • PythonでPLSA(確率的潜在意味解析: Probabilistic Latent Semantic Analysis)を実装してみました。

  • 高速化やエラー処理(log(0)の対策など)はまた後日。

PLSAとは

クラスタリング手法の1つで、

  1. 文書dがP(d)で選ばれる
  2. トピックzがP(z|d)で選ばれる
  3. 単語wがP(w|z)で生成される

というモデルです。
式にすると

P(d,w) = P(d)\sum_{z}P(z|d)P(w|z)

となりますが、少し変形して

P(d,w) = \sum_{z}P(z)P(d|z)P(w|z)

を扱うことが多いです。

このモデルに対し、対数尤度

L = \sum_{d}\sum_{w}N(d,w)\log P(d,w)

が最大になるようなP(z), P(d|z), P(w|z)を、EMアルゴリズムを使って求めます。
N(d,w)は文書d内における単語wの登場回数です。

EMアルゴリズム

対数尤度が収束するまで、以下のEステップとMステップを繰り返します。

  • Eステップ
P(z|d,w) = \frac{P\left( z \right)P\left( d | z \right)P\left( w | z \right)}{\sum_{z} P\left( z \right)P\left( d | z \right)P\left( w | z \right)}
  • Mステップ
\begin{eqnarray}
P\left( z \right) & = & \frac{\sum_{d} \sum_{w} N_{d, w} P\left( z | d, w \right)}{\sum_{d} \sum_{w} N_{d, w}} \\
P\left( d | z \right) & = & \frac{\sum_{w} N_{d, w} P \left( z | d, w \right)}{\sum_{d} \sum_{w} N_{d, w} P \left( z | d, w \right)} \\
P\left( w | z \right) & = & \frac{\sum_{d} N_{d, w} P \left( z | d, w \right)}{\sum_{d} \sum_{w} N_{d, w} P \left( z | d, w \right)}
\end{eqnarray}

実装

文書と単語以外にも使えるよね、という想いを込めて、

P(x,y) = \sum_{z}P(z)P(x|z)P(y|z)

としています。

plsa.py
import numpy as np


class PLSA(object):
    def __init__(self, N, Z):
        self.N = N
        self.X = N.shape[0]
        self.Y = N.shape[1]
        self.Z = Z

        # P(z)
        self.Pz = np.random.rand(self.Z)
        # P(x|z)
        self.Px_z = np.random.rand(self.Z, self.X)
        # P(y|z)
        self.Py_z = np.random.rand(self.Z, self.Y)

        # 正規化
        self.Pz /= np.sum(self.Pz)
        self.Px_z /= np.sum(self.Px_z, axis=1)[:, None]
        self.Py_z /= np.sum(self.Py_z, axis=1)[:, None]

    def train(self, k=200, t=1.0e-7):
        '''
        対数尤度が収束するまでEステップとMステップを繰り返す
        '''
        prev_llh = 100000
        for i in xrange(k):
            self.e_step()
            self.m_step()
            llh = self.llh()

            if abs((llh - prev_llh) / prev_llh) < t:
                break

            prev_llh = llh

    def e_step(self):
        '''
        Eステップ
        P(z|x,y)の更新
        '''
        self.Pz_xy = self.Pz[None, None, :] \
            * self.Px_z.T[:, None, :] \
            * self.Py_z.T[None, :, :]
        self.Pz_xy /= np.sum(self.Pz_xy, axis=2)[:, :, None]

    def m_step(self):
        '''
        Mステップ
        P(z), P(x|z), P(y|z)の更新
        '''
        NP = self.N[:, :, None] * self.Pz_xy

        self.Pz = np.sum(NP, axis=(0, 1))
        self.Px_z = np.sum(NP, axis=1).T
        self.Py_z = np.sum(NP, axis=0).T

        self.Pz /= np.sum(self.Pz)
        self.Px_z /= np.sum(self.Px_z, axis=1)[:, None]
        self.Py_z /= np.sum(self.Py_z, axis=1)[:, None]

    def llh(self):
        '''
        対数尤度
        '''
        Pxy = self.Pz[None, None, :] \
            * self.Px_z.T[:, None, :] \
            * self.Py_z.T[None, :, :]
        Pxy = np.sum(Pxy, axis=2)
        Pxy /= np.sum(Pxy)

        return np.sum(self.N * np.log(Pxy))

こんな感じで使います。
例はX=5, Y=4, Z=2です。

import numpy as np
from plsa import PLSA


N = np.array([
    [20, 23, 1, 4],
    [25, 19, 3, 0],
    [2, 1, 31, 28],
    [0, 1, 22, 17],
    [1, 0, 18, 24]
])

plsa = PLSA(N, 2)
plsa.train()

print 'P(z)'
print plsa.Pz
print 'P(x|z)'
print plsa.Px_z
print 'P(y|z)'
print plsa.Py_z
print 'P(z|x)'
Pz_x = plsa.Px_z.T * plsa.Pz[None, :]
print Pz_x / np.sum(Pz_x, axis=1)[:, None]
print 'P(z|y)'
Pz_y = plsa.Py_z.T * plsa.Pz[None, :]
print Pz_y / np.sum(Pz_y, axis=1)[:, None]
結果
P(z)
[ 0.3904831  0.6095169]
P(x|z)
[[  4.62748437e-01   5.01515512e-01   2.44650515e-02   1.11544339e-02   1.16565226e-04]
 [  3.16718954e-02   1.99625810e-09   4.08159551e-01   2.66294583e-01   2.93873968e-01]]
P(y|z)
[[  4.92518417e-01   4.69494935e-01   3.79855484e-02   1.09954245e-06]
 [  1.25999484e-02   5.73485868e-06   4.88365926e-01   4.99028390e-01]]
P(z|x)
[[  9.03477222e-01   9.65227775e-02]
 [  9.99999994e-01   6.21320706e-09]
 [  3.69800871e-02   9.63019913e-01]
 [  2.61337075e-02   9.73866292e-01]
 [  2.54046982e-04   9.99745953e-01]]
P(z|y)
[[  9.61600593e-01   3.83994074e-02]
 [  9.99980934e-01   1.90663270e-05]
 [  4.74646871e-02   9.52535313e-01]
 [  1.41157067e-06   9.99998588e-01]]

感想

  • Numpyのブロードキャスティングは便利です。
44
41
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
44
41

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?