Continuous Space Topic Model Implementation

0. Continuous Space Topic Model (CSTM)

The other day I learned about the Continuous Space Topic Model at a seminar at the Institute of Statistical Mathematics.
Whereas the conventional Latent Dirichlet Allocation is a generative model built on a mixture model (a model of sums),
the Continuous Space Topic Model is a generative model built on an RBM-style product model.
According to the paper, it performs better than LDA.

For details, see the following paper:
http://chasen.org/~daiti-m/paper/nl213cstm.pdf

It looked interesting, so I wanted to implement it and verify it myself.
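
To fix notation for the rest of this post (this is my own reading of the paper and of the code below; the symbols u_w, u_d and n_dw are names I introduce here, corresponding to w_emb, d_emb and wordcntbydoc in the code), each document d and word w get a Dirichlet parameter

$$
\alpha_{dw} = \alpha_0 \, G_0(w) \, \exp(u_w \cdot u_d)
$$

where G_0(w) is the empirical unigram distribution. The word counts n_dw of a document then follow a Polya (Dirichlet compound multinomial) distribution; dropping the multinomial coefficient, which is constant and cancels in the Metropolis-Hastings ratio, its log is

$$
\log p(\{n_{dw}\} \mid \{\alpha_{dw}\}) = \log \frac{\Gamma\!\left(\sum_w \alpha_{dw}\right)}{\Gamma\!\left(\sum_w \alpha_{dw} + n_d\right)} + \sum_w \log \frac{\Gamma(\alpha_{dw} + n_{dw})}{\Gamma(\alpha_{dw})}, \qquad n_d = \sum_w n_{dw}.
$$

This is what calculate_prob_log below is meant to compute, and why ratios of gamma functions show up everywhere in the code.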

1. Implementation

I first implemented it on the CPU, but with 1,000 documents and 8,000 words one epoch took more than 10 hours,
so I redid it on the GPU (cupy), which brought one epoch down to about an hour.
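
The port itself is mostly a matter of swapping the array module; here is a minimal sketch of the pattern the code below relies on (nothing CSTM-specific, just chainer's cuda helpers):

import numpy as np
from chainer import cuda

xp = cuda.cupy                 # cupy exposes (almost) the same API as numpy

a_cpu = np.random.randn(3, 4)  # host (CPU) array
a_gpu = cuda.to_gpu(a_cpu)     # copy to the GPU
b_gpu = xp.exp(a_gpu)          # computed on the GPU
b_cpu = cuda.to_cpu(b_gpu)     # copy back when a numpy array is needed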

cstm_gpu.py
import numpy as np
from chainer import cuda

xp = cuda.cupy

# Γ(α + n) / Γ(α) = (α + (n-1) ) * (α + (n-2)) * ... * (α + 1) * α
# The values get large, so take the log of the product.
# alpha_arr  : on the GPU
# n_arr      : on the CPU
# n_arr_gpu  : on the GPU
def log_gamma_div_gpu(alpha_arr, n_arr, n_arr_gpu):
    n_max = np.max(n_arr)
    alpha_arr_tmp = xp.copy(alpha_arr)
    # entries with count 0 contribute log(1) = 0
    alpha_arr_tmp[np.where(n_arr==0)] = 1
    log_sum_arr = xp.log(alpha_arr_tmp)
    # accumulate log(α + i) for i = 1 .. n-1 of every entry, in lockstep over the whole array
    for i in range(1,int(n_max)):
        n_tmp = xp.where(n_arr_gpu>i+0.5,i,0)
        alpha_tmp = xp.where(n_tmp < 0.5, 1 , alpha_arr_tmp + n_tmp)
        log_sum_arr += xp.log(alpha_tmp)
    return log_sum_arr

# CPU version; handles both numpy arrays and np.float64 scalars
def log_gamma_div_cpu(alpha_arr, n_arr):
    # if a cupy array was passed in, move it to the CPU (numpy arrays pass through unchanged)
    if not isinstance(alpha_arr, np.float64):
        alpha_arr_cpu = cuda.to_cpu(alpha_arr)
    else:
        alpha_arr_cpu = alpha_arr
    n_max = np.max(n_arr)
    # entries whose α was zeroed out (count 0) contribute log(1) = 0
    alpha_arr_tmp_cpu = np.where(alpha_arr_cpu!=0,alpha_arr_cpu,1)
    log_sum_arr_cpu = np.log(alpha_arr_tmp_cpu)
    for i in range(1,int(n_max)):
        n_tmp = np.where(n_arr>i,i,0)
        alpha_tmp_cpu = np.where(n_tmp==0, 1 , alpha_arr_tmp_cpu + n_tmp)
        log_sum_arr_cpu += np.log(alpha_tmp_cpu)
    return log_sum_arr_cpu

class AlphaClass:
    # keep the base measure G0 (the empirical unigram distribution) on the GPU
    def __init__(self,G0):
        self.G0 = cuda.to_gpu(G0)

    # compute α = alpha0 * G0 * exp(d_emb · w_emb^T)
    # return value: shape (number of documents, vocabulary size)
    def calculate_alpha(self,alpha0,w_emb,d_emb):
        tmp = d_emb.dot(w_emb.transpose())
        alpha = alpha0 * self.G0 * xp.exp(tmp)
        return alpha

class ProbClass:
    def __init__(self, n):
        self.n = n
        self.n_gpu = cuda.to_gpu(n)

    # compute the log joint probability of every document
    # return value: shape (number of documents,)
    def calculate_prob_log(self,alpha):
        alpha2 = xp.copy(alpha)
        # zero out α for words that do not occur in the document
        alpha2[np.where(self.n==0)] = 0
        # total-count term and per-word terms of the document log-likelihood
        tmp1 = log_gamma_div_cpu(xp.sum(alpha2,axis=1),np.sum(self.n,axis=1))
        tmp2 = xp.sum(log_gamma_div_gpu(alpha2,self.n,self.n_gpu),axis=1)
        prob = xp.asarray(-tmp1) + tmp2
        return prob

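# CPU version of calculate_alpha for a single document embedding d
# return value: shape (vocabulary size,)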
def calculate_alpha_index(alpha0,G0,w_emb,d):
    tmp = d.dot(w_emb.transpose())
    alpha = alpha0 * G0 * np.exp(tmp)
    return alpha


# compute the log joint probability of a single document (CPU)
# return value: scalar
def calculate_prob_log_index(alpha,n):
    alpha2 = np.copy(alpha)
    alpha2[np.where(n==0)] = 0
    tmp1 = log_gamma_div_cpu(np.sum(alpha2),np.sum(n))
    tmp2 = np.sum(log_gamma_div_cpu(alpha2,n))
    prob = - tmp1 + tmp2
    return prob

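# Metropolis-Hastings acceptance: always accept when the log-likelihood improves,
# otherwise accept with probability exp(log_ll_new - log_ll_old)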
def mh_accept(log_ll_old,log_ll_new):
    if log_ll_old < log_ll_new:
        return 1
    else:
        p = np.exp(cuda.to_cpu(log_ll_new-log_ll_old))
        return np.random.binomial(1,p)
train_cstm_gpu.py
import time

import numpy as np
import six
from chainer import cuda

import cstm_gpu as cstm  # the module defined above

xp = cuda.cupy

# args (command-line options), logwrite (logging helper), resultdir,
# n_vocab (vocabulary size) and n_wid (number of documents) are set up
# earlier in the full script and omitted here.

alpha0 = 1.0
w_emb = np.random.randn(n_vocab, args.unit)  # word embeddings
d_emb = np.random.randn(n_wid, args.unit)    # document embeddings
wordcntbydoc = ...  # word frequency counts per document, shape (n_wid, n_vocab)
G0 = ...            # corpus-wide word counts divided by their sum (empirical unigram distribution)

sigma_doc = 0.01
sigma_word = 0.02
sigma_alpha = 0.2

alphaclass = cstm.AlphaClass(G0)
probclass = cstm.ProbClass(wordcntbydoc)

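# per-document log-likelihood on the GPU for the current embeddings and alpha0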
def calculate_prob_log_all(alpha0,w,d):
    alpha = alphaclass.calculate_alpha(alpha0,w,d)
    prob = probclass.calculate_prob_log(alpha)
    return prob

prob = calculate_prob_log_all(alpha0,xp.asarray(w_emb),xp.asarray(d_emb))
print("prob={}".format(xp.sum(prob)))


begin_time = time.time()
for epoch in range(0,args.epoch):
    logwrite('epoch=' + str(epoch))
    w_perm = np.random.permutation(n_vocab)
    d_perm = np.random.permutation(n_wid)

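    # MH update of each document embedding (proposal and likelihood evaluated on the CPU, one document at a time)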
    for i in d_perm:
        d_i = d_emb[i] + sigma_doc * np.random.randn(args.unit)
        alpha_new = cstm.calculate_alpha_index(alpha0,G0,w_emb,d_i)
        prob_new = cstm.calculate_prob_log_index(alpha_new, wordcntbydoc[i])
        prob_old = prob[i]  # current log-likelihood of document i
        new_flag = cstm.mh_accept(prob_old, prob_new)
        if new_flag == 1:
            d_emb[i] = d_i
            prob[i] = prob_new
        logwrite("  doc i={}:ll={}".format(i,xp.sum(prob)))


    end_time = time.time()
    duration = end_time - begin_time

    prob = calculate_prob_log_all(xp.asarray(alpha0),xp.asarray(w_emb),xp.asarray(d_emb))
    logwrite("doc result={}".format(xp.sum(prob)))
    logwrite('doc: {:.2f} sec'.format(duration))
    begin_time = time.time()

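    # MH update of each word embedding; every proposal re-evaluates the likelihood of all documents on the GPU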
    for i in w_perm:
        w_emb_new = xp.copy(w_emb)
        w_emb_new[i] = w_emb[i] + sigma_word * np.random.randn(args.unit)
        prob_new = calculate_prob_log_all(xp.asarray(alpha0),xp.asarray(w_emb_new),xp.asarray(d_emb))

        new_flag = cstm.mh_accept(xp.sum(prob),xp.sum(prob_new))
        if new_flag == 1:
            prob = xp.copy(prob_new)
            w_emb = np.copy(w_emb_new)
        logwrite("  word i={}:ll={}".format(i,xp.sum(prob)))

    end_time = time.time()
    duration = end_time - begin_time

    prob = calculate_prob_log_all(xp.asarray(alpha0),xp.asarray(w_emb),xp.asarray(d_emb))
    logwrite("word result={}".format(xp.sum(prob)))
    logwrite('word: {:.2f} sec'.format(duration))
    begin_time = time.time()

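    # MH update of alpha0 with a multiplicative (log-space) random-walk proposal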
    z = sigma_alpha * np.random.randn(1)
    alpha0_new = alpha0 * np.exp(z)
    prob_new = calculate_prob_log_all(xp.asarray(alpha0_new),xp.asarray(w_emb),xp.asarray(d_emb))
    new_flag = cstm.mh_accept(xp.sum(prob), xp.sum(prob_new))
    if new_flag == 1:
        alpha0 = alpha0_new
        prob = xp.copy(prob_new)
    logwrite("  alpha0 : ll={} alpha0={}".format(xp.sum(prob),alpha0))

    end_time = time.time()
    duration = end_time - begin_time

    prob = calculate_prob_log_all(xp.asarray(alpha0),xp.asarray(w_emb),xp.asarray(d_emb))
    logwrite("alpha0 result={}".format(xp.sum(prob)))
    logwrite('alpha0: {:.2f} sec'.format(duration))
    begin_time = time.time()

    logwrite('Save to pkl epoch=' + str(epoch) )

    result_all = [alpha0,w_emb,d_emb]
    file1 = resultdir + "/result_epoch_" + "{0:0>3}".format(epoch) + ".pkl"
    with open(file1, 'wb') as output1:
        six.moves.cPickle.dump(result_all, output1, -1)

Making the gamma-function computation fast on the GPU forced it into a rather brute-force form.
Now that a first version is working, I would like to verify it by writing tests that check
whether the values produced by the code match the formulas in the paper.
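
A minimal sketch of the kind of check I have in mind, assuming scipy is installed (scipy.special.gammaln computes log Γ directly) and that the module above is saved as cstm_gpu.py in the same chainer + cupy environment:

check_gamma.py
import numpy as np
from scipy.special import gammaln

import cstm_gpu as cstm

# random Dirichlet parameters and positive integer counts
alpha = np.random.gamma(1.0, 1.0, size=1000)
n = np.random.randint(1, 50, size=1000)

# log( Γ(α + n) / Γ(α) ) computed directly with the log-gamma function
expected = gammaln(alpha + n) - gammaln(alpha)

# the loop-based helper used by the sampler (CPU variant)
actual = cstm.log_gamma_div_cpu(alpha, n)

print(np.allclose(expected, actual))  # True if the helper matches

The GPU version can be checked the same way by calling log_gamma_div_gpu(cuda.to_gpu(alpha), n, cuda.to_gpu(n)) and moving the result back with cuda.to_cpu.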
