LoginSignup
14
12

More than 5 years have passed since last update.

PythonでPLSAを実装してみた2

Last updated at Posted at 2015-12-06

概要

前回の実装に手を加えました。内容は以下の通りです。

  • メモリ使用量の削減

  • 一部エラー処理導入

高速化はまた後日。

メモリ使用量削減

原因と方針

まず、前回の実装はメモリを非常に使います。原因はEMアルゴリズムのEステップでP(z|x,y)を求めていることです。単純に3次元配列を作ってしまっているので、これを作らないようにすればメモリの使用量がかなり少なくなります。

具体的に、MステップのP(x|z)の更新式を見てみましょう。分母は正規化してるだけなので、分子だけ考えます。

P\left( x | z \right) の分子 = \sum_{y} N_{x, y} P \left( z | x, y \right)

これにEステップのP(z|x,y)の更新式を代入します。

\begin{equation}
P\left( x | z \right) の分子 = \sum_{y} N_{x, y} \frac{P\left( z \right)P\left( x | z \right)P\left( y | z \right)}{\sum_{z} P\left( z \right)P\left( x | z \right)P\left( y | z \right)} \tag{1}
\end{equation}

この式を実装することになるのですが、for文をグルグル回すことはしたくないので、numpyのeinsumを利用します。

numpy.einsum()

einsum関数はアインシュタインの縮約です。と言っても分かりにくいので、例を一つ。

P(x,y) = \sum_{z}P(z)P(x|z)P(y|z)

を実装すると

Pxy = numpy.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)

となります。

このeinsum関数を使って式(1)を実装するのですが、そのままでは実装しにくいので以下のように式を変形します。

P\left( x | z \right) の分子 = \sum_{y} \frac{N_{x, y}}{\sum_{z} P\left( z \right)P\left( x | z \right)P\left( y | z \right)} P\left( z \right)P\left( x | z \right)P\left( y | z \right)

これを実装するとこうなります。

tmp = N / numpu.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)
Px_z = numpy.einsum('ij,k,ki,kj->ki', tmp, Pz, Px_z, Py_z)

これでメモリ使用量がどれくらい減ったかは後述します。

エラー処理

einsum関数を使って実装しましたが、

tmp = N / numpu.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)

において0で割ってしまった場合のエラー処理を考えます。
この分母が0になるということは、あるxとyについて、

\sum_{z}P(z)P(x|z)P(y|z) = 0

ということで、負の値は出てきませんので、あるxとyと全てのzについて、

P(z)P(x|z)P(y|z) = 0

が成り立ちます。
つまり、EMアルゴリズムのEステップは以下のようになります。

\begin{eqnarray}
P(z|x,y) & = & \frac{P\left( z \right)P\left( x | z \right)P\left( y | z \right)}{\sum_{z} P\left( z \right)P\left( x | z \right)P\left( y | z \right)}
& = & 0
\end{eqnarray}

よって、0で割ってしまった要素は0にするということになります。

ここで、numpyでは

1 / 0 = inf
0 / 0 = nan

ですので、それぞれnumpy.isinf、numpy.isnanを使って

tmp = N / numpu.einsum('k,ki,kj->ij', Pz, Px_z, Py_z)
tmp[numpy.isinf(tmp)] = 0
tmp[numpy.isnan(tmp)] = 0

Px_z = numpy.einsum('ij,k,ki,kj->ki', tmp, Pz, Px_z, Py_z)

となります。

実装

以上まとめまして、全体の実装は以下の通りです。

plsa.py
import numpy as np


class PLSA(object):
    def __init__(self, N, Z):
        self.N = N
        self.X = N.shape[0]
        self.Y = N.shape[1]
        self.Z = Z

        # P(z)
        self.Pz = np.random.rand(self.Z)
        # P(x|z)
        self.Px_z = np.random.rand(self.Z, self.X)
        # P(y|z)
        self.Py_z = np.random.rand(self.Z, self.Y)

        # 正規化
        self.Pz /= np.sum(self.Pz)
        self.Px_z /= np.sum(self.Px_z, axis=1)[:, None]
        self.Py_z /= np.sum(self.Py_z, axis=1)[:, None]

    def train(self, k=200, t=1.0e-7):
        '''
        対数尤度が収束するまでEステップとMステップを繰り返す
        '''
        prev_llh = 100000
        for i in xrange(k):
            self.em_algorithm()
            llh = self.llh()

            if abs((llh - prev_llh) / prev_llh) < t:
                break

            prev_llh = llh

    def em_algorithm(self):
        '''
        EMアルゴリズム
        P(z), P(x|z), P(y|z)の更新
        '''
        tmp = self.N / np.einsum('k,ki,kj->ij', self.Pz, self.Px_z, self.Py_z)
        tmp[np.isnan(tmp)] = 0
        tmp[np.isinf(tmp)] = 0

        Pz = np.einsum('ij,k,ki,kj->k', tmp, self.Pz, self.Px_z, self.Py_z)
        Px_z = np.einsum('ij,k,ki,kj->ki', tmp, self.Pz, self.Px_z, self.Py_z)
        Py_z = np.einsum('ij,k,ki,kj->kj', tmp, self.Pz, self.Px_z, self.Py_z)

        self.Pz = Pz / np.sum(Pz)
        self.Px_z = Px_z / np.sum(Px_z, axis=1)[:, None]
        self.Py_z = Py_z / np.sum(Py_z, axis=1)[:, None]

    def llh(self):
        '''
        対数尤度
        '''
        Pxy = np.einsum('k,ki,kj->ij', self.Pz, self.Px_z, self.Py_z)
        Pxy /= np.sum(Pxy)
        lPxy = np.log(Pxy)
        lPxy[np.isinf(lPxy)] = -1000

        return np.sum(self.N * lPxy)

対数尤度の計算の際にアンダーフローが発生し、log(0)=-infとなることがあります。
倍精度浮動小数点数の最小値は約4.94e-324ですので、log(4.94e-324)=-744.4より小さい値ということで、雑に-1000を入れています。

測定

どれくらいメモリ使用量が減ったか、memory_profilerを使って測ってみます。
以下のようなスクリプトで測定しました。

memory_profile.py
import numpy as np
from memory_profiler import profile

X = 1000
Y = 1000
Z = 10


@profile
def main():
    from plsa import PLSA
    plsa = PLSA(np.random.rand(X, Y), Z)
    plsa.em_algorithm()
    llh = plsa.llh()


if __name__ == '__main__':
    main()

X=1000, Y=1000, Z=10の場合、前回の実装では、

$ python profile_memory_element_wise_product.py 
Filename: profile_memory_element_wise_product.py

Line #    Mem usage    Increment   Line Contents
================================================
    10     15.9 MiB      0.0 MiB   @profile
    11                             def main():
    12     15.9 MiB      0.0 MiB       from plsa_element_wise_product import PLSA
    13     23.9 MiB      8.0 MiB       plsa = PLSA(np.random.rand(X, Y), Z)
    14    108.0 MiB     84.1 MiB       plsa.e_step()
    15    184.5 MiB     76.5 MiB       plsa.m_step()
    16    199.8 MiB     15.3 MiB       llh = plsa.llh()

今回の実装では、

$ python profile_memory_einsum.py 
Filename: profile_memory_einsum.py

Line #    Mem usage    Increment   Line Contents
================================================
    10     15.8 MiB      0.0 MiB   @profile
    11                             def main():
    12     15.8 MiB      0.0 MiB       from plsa_einsum import PLSA
    13     23.7 MiB      7.9 MiB       plsa = PLSA(np.random.rand(X, Y), Z)
    14     40.7 MiB     16.9 MiB       plsa.em_algorithm()
    15     48.4 MiB      7.8 MiB       llh = plsa.llh()

X=5000, Y=5000, Z=10の場合、前回の実装では、

$ python profile_memory_element_wise_product.py 
Filename: profile_memory_element_wise_product.py

Line #    Mem usage    Increment   Line Contents
================================================
    10     15.9 MiB      0.0 MiB   @profile
    11                             def main():
    12     15.9 MiB      0.0 MiB       from plsa_element_wise_product import PLSA
    13    207.6 MiB    191.7 MiB       plsa = PLSA(np.random.rand(X, Y), Z)
    14   2115.4 MiB   1907.8 MiB       plsa.e_step()
    15   2115.5 MiB      0.1 MiB       plsa.m_step()
    16   2115.5 MiB      0.0 MiB       llh = plsa.llh()

今回の実装では、

$ python profile_memory_einsum.py 
Filename: profile_memory_einsum.py

Line #    Mem usage    Increment   Line Contents
================================================
    10     15.7 MiB      0.0 MiB   @profile
    11                             def main():
    12     15.7 MiB      0.0 MiB       from plsa_einsum import PLSA
    13    207.5 MiB    191.7 MiB       plsa = PLSA(np.random.rand(X, Y), Z)
    14    233.0 MiB     25.6 MiB       plsa.em_algorithm()
    15    233.1 MiB      0.0 MiB       llh = plsa.llh()

まとめると、全体のメモリ使用量は、

実装 X=1000,Y=1000,Z=10 X=5000,Y=5000,Z=10
前回の実装 199.8 MiB 2115.5 MiB
今回の実装 48.4 MiB 233.1 MiB

とかなり減っていることが分かりますが、これをEMアルゴリズムと対数尤度の計算部分のみに限定すると、

実装 X=1000,Y=1000,Z=10 X=5000,Y=5000,Z=10
前回の実装 175.9 MiB 1907.9 MiB
今回の実装 24.7 MiB 25.6 MiB

となり、今回の実装では使用されたメモリ量がほとんど増えていないことが分かります。
これでデータ数が増えてもメモリに関しては怖くありません。

計算速度

einsum関数は便利なのですが、今回のように3つも4つもまとめて計算させると処理に時間がかかります。
私の手元のMacBookの場合、X=5000,Y=5000,Z=10で一度のEMアルゴリズムと対数尤度の計算に約8.7秒かかりました。
これは現実的ではないので改善が必要ですが、そこらへんはまた後日。

感想

  • 思ってた以上に長くなってしまいました。

  • einsum関数も便利ですね。

14
12
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
12