Posted at

StanでLDAをまわしてみた (自動微分変分推論)


はじめに

LDAはトピックモデルと呼ばれる自然言語処理の分野で広く使われている数理モデルです。モデルの詳細は後ほど説明しますが、ざっくり言うと書かれた単語からその文章のトピックを複数推定するモデルです。

LDAのパラメータ推定には一般的にはRのtopicmodelsパッケージやPythonのgensimモジュールが使われることが多いですが、今回はモデルの拡張や仕組みの理解を目的として、stanでモデルをゼロから構築して推定することを試みます。


モデルの定義

コードは松浦さんのブログから借りてきました。ただし、stanのバージョンアップに合わせて、一部コードを修正しています。

data {

int<lower=1> K;
int<lower=1> M;
int<lower=1> V;
int<lower=1> N;
int<lower=1,upper=V> W[N];
int<lower=1,upper=N> Offset[M,2]; // range of word index per doc

vector<lower=0>[K] Alpha; // topic prior
vector<lower=0>[V] Beta; // word prior
}
parameters {
simplex[K] theta[M];
simplex[V] phi[K];
}
model {
// prior
for (m in 1:M)
theta[m] ~ dirichlet(Alpha);
for (k in 1:K)
phi[k] ~ dirichlet(Beta);

// likelihood
for (m in 1:M) {
for (n in Offset[m,1]:Offset[m,2]) {
real gamma[K];
for (k in 1:K)
gamma[k] = log(theta[m,k]) + log(phi[k,W[n]]);
target += log_sum_exp(gamma);
}
}
}


データと下処理

今回は20_newsコーパスを使いました。これは20個のジャンルからなる20,000本のニュース記事の集合で、トピックモデルの評価にしばしば使われるコーパスです。

予め単語にインデックスを付与し、文章は1 10 8 9などの整数のリストに変換しています。したがって別ファイルvocab.txtにインデックスと単語の対応関係1 --> "news"などを記録しています。

library(tidyverse)

library(tidytext)

corpus = read_csv("./corpus_20news.txt", col_names = "words")
corpus = corpus %>% sample_n(1000)
corpus = corpus %>% mutate(doc = seq(1:nrow(corpus)))
corpus.tidy = corpus %>% unnest_tokens(word, words)
corpus.tidy = corpus.tidy %>% mutate(word = as.integer(word))

上のコードでコーパスをtidytext形式に変換しています。



doc
word


1
178


1
878


1
866


1
3579


1
2045


1
522

ここから、頻出語の削除とそれに伴うインデックスの振り直しを行っていきます。

stop_words = corpus.tidy %>% count(word, sort=T) %>% filter(n > 1000 | n < 15)

corpus.tidy = corpus.tidy %>% anti_join(stop_words)

appeared_word = unique(corpus.tidy$word)

vocab = read_csv("./vocab_20news.txt", col_names = "word")
vocab = vocab %>% mutate(index = seq(nrow(vocab)))
vocab.new = vocab[appeared_word,]

vocab.new = vocab.new %>% mutate(new_ind = seq(nrow(vocab.new)))

corpus.lda = corpus.tidy %>% left_join(vocab.new, by = c("word" = "index"))

あとは、文書iに属する単語がtidy形式でインデックス何番目から何番目にあるのかを表す情報をoffset変数に格納していきます。

M = nrow(corpus)

N = nrow(corpus.lda)
V = nrow(vocab.new)
K = 20

offset = matrix(nrow = M, ncol = 2)

last_ind = 0
for (m in 1:M) {
n_words = as.integer(corpus.lda %>% filter(doc == m) %>% count())
offset[m,1] = last_ind + 1
offset[m,2] = offset[m,1] + (n_words - 1)
last_ind = offset[m,2]
}


パラメータ推定

モデルとデータがそろったのでパラメータ推定を行っていきましょう。アルゴリズムにはver.2.9からstanに搭載されたADVI(自動微分変分推論)を使いました。stanといえばNUTSというアルゴリズムが有名ですが、MCMCだとLDAのようなパラメータの多いアルゴリズムでは推定にとてつもない時間が必要になってしまいます。ADVIであれば、MCMCより精度は劣るものの、高速に推論を行うことができます。

library(rstan)

load("./20news.RData")

data = list(
K = K,
M = M,
V = V,
N = N,
W = corpus.lda$new_ind,
Offset = offset,
Alpha = rep(1, K),
Beta = rep(0.5, V)
)

sm = stan_model(file = "./lda.stan")

fit.vb = vb(
sm,
data = data,
output_samples = 2000,
adapt_engaged = FALSE,
eta = .1
)

etaは自動的に調整してくれる機能もありますが、時間がかかるのでここでは天下り的に適当な値を指定しています。


結果

結果は以下のようになりました。

topic 1
university
ca
time
apr
problem
made
people
fact
law
israel

topic 2
image
years
time
make
things
people
set
god
file
color

topic 3
ftp
information
windows
system
file
version
files
jpeg
scsi
gif

topic 4
image
list
system
graphics
bit
format
line
run
images
jpeg

topic 5
time
apr
make
message
problem
car
back
point
god
drug

topic 6
man
day
time
apr
work
people
car
called
ve
drive

topic 7
image
mail
pc
people
find
price
software
file
files
openwindows

topic 8
good
ftp
state
space
apr
made
system
point
year
team

topic 9
good
ca
time
support
make
people
jews
religion
ve
god

topic 10
good
image
university
key
ca
work
people
system
bit
jpeg

topic 11
cs
good
state
time
make
ve
government
god
part
give

topic 12
good
time
bill
made
people
person
place
life
book
government

topic 13
good
ca
time
apr
make
work
people
thing
file
jpeg

topic 14
man
years
information
uk
problem
things
people
car
thing
god

topic 15
good
university
years
ed
time
support
make
work
ve
god

topic 16
good
makes
apr
make
people
back
ve
ll
point
god

topic 17
good
question
time
make
windows
work
people
god
run
file

topic 18
cs
good
question
time
make
work
read
system
ve
data

topic 19
good
information
mail
time
apr
make
work
problem
ve
game

topic 20
question
apr
david
make
windows
problem
put
people
back
god

なんとなくトピックらしきものが抽出されてる気もしますが、重なるものが多かったりあんまり品質はよくないようです。


感想

ある程度予想はしていたのですが、やはり既存のパッケージに比べると推定時間・精度ともに悪いです。ただ、LDAには無数の派生系があり、そうしたモデルの中にはどうしても自分で実装するしかないものも多いので、その後の拡張のしやすさという意味ではstanを使う意義もあるのではないでしょうか。僕はとりあえずgaussian-LDAをstanでやってみようと思っています。

コードおよびデータはここにも保存しています。