はじめに
LDAはトピックモデルと呼ばれる自然言語処理の分野で広く使われている数理モデルです。モデルの詳細は後ほど説明しますが、ざっくり言うと書かれた単語からその文章のトピックを複数推定するモデルです。
LDAのパラメータ推定には一般的にはRのtopicmodelsパッケージやPythonのgensimモジュールが使われることが多いですが、今回はモデルの拡張や仕組みの理解を目的として、stanでモデルをゼロから構築して推定することを試みます。
モデルの定義
コードは松浦さんのブログから借りてきました。ただし、stanのバージョンアップに合わせて、一部コードを修正しています。
data {
int<lower=1> K;
int<lower=1> M;
int<lower=1> V;
int<lower=1> N;
int<lower=1,upper=V> W[N];
int<lower=1,upper=N> Offset[M,2]; // range of word index per doc
vector<lower=0>[K] Alpha; // topic prior
vector<lower=0>[V] Beta; // word prior
}
parameters {
simplex[K] theta[M];
simplex[V] phi[K];
}
model {
// prior
for (m in 1:M)
theta[m] ~ dirichlet(Alpha);
for (k in 1:K)
phi[k] ~ dirichlet(Beta);
// likelihood
for (m in 1:M) {
for (n in Offset[m,1]:Offset[m,2]) {
real gamma[K];
for (k in 1:K)
gamma[k] = log(theta[m,k]) + log(phi[k,W[n]]);
target += log_sum_exp(gamma);
}
}
}
データと下処理
今回は20_newsコーパスを使いました。これは20個のジャンルからなる20,000本のニュース記事の集合で、トピックモデルの評価にしばしば使われるコーパスです。
予め単語にインデックスを付与し、文章は1 10 8 9
などの整数のリストに変換しています。したがって別ファイルvocab.txt
にインデックスと単語の対応関係1 --> "news"
などを記録しています。
library(tidyverse)
library(tidytext)
corpus = read_csv("./corpus_20news.txt", col_names = "words")
corpus = corpus %>% sample_n(1000)
corpus = corpus %>% mutate(doc = seq(1:nrow(corpus)))
corpus.tidy = corpus %>% unnest_tokens(word, words)
corpus.tidy = corpus.tidy %>% mutate(word = as.integer(word))
上のコードでコーパスをtidytext形式に変換しています。
doc | word |
---|---|
1 | 178 |
1 | 878 |
1 | 866 |
1 | 3579 |
1 | 2045 |
1 | 522 |
ここから、頻出語の削除とそれに伴うインデックスの振り直しを行っていきます。
stop_words = corpus.tidy %>% count(word, sort=T) %>% filter(n > 1000 | n < 15)
corpus.tidy = corpus.tidy %>% anti_join(stop_words)
appeared_word = unique(corpus.tidy$word)
vocab = read_csv("./vocab_20news.txt", col_names = "word")
vocab = vocab %>% mutate(index = seq(nrow(vocab)))
vocab.new = vocab[appeared_word,]
vocab.new = vocab.new %>% mutate(new_ind = seq(nrow(vocab.new)))
corpus.lda = corpus.tidy %>% left_join(vocab.new, by = c("word" = "index"))
あとは、文書iに属する単語がtidy形式でインデックス何番目から何番目にあるのかを表す情報をoffset変数に格納していきます。
M = nrow(corpus)
N = nrow(corpus.lda)
V = nrow(vocab.new)
K = 20
offset = matrix(nrow = M, ncol = 2)
last_ind = 0
for (m in 1:M) {
n_words = as.integer(corpus.lda %>% filter(doc == m) %>% count())
offset[m,1] = last_ind + 1
offset[m,2] = offset[m,1] + (n_words - 1)
last_ind = offset[m,2]
}
パラメータ推定
モデルとデータがそろったのでパラメータ推定を行っていきましょう。アルゴリズムにはver.2.9からstanに搭載されたADVI(自動微分変分推論)を使いました。stanといえばNUTSというアルゴリズムが有名ですが、MCMCだとLDAのようなパラメータの多いアルゴリズムでは推定にとてつもない時間が必要になってしまいます。ADVIであれば、MCMCより精度は劣るものの、高速に推論を行うことができます。
library(rstan)
load("./20news.RData")
data = list(
K = K,
M = M,
V = V,
N = N,
W = corpus.lda$new_ind,
Offset = offset,
Alpha = rep(1, K),
Beta = rep(0.5, V)
)
sm = stan_model(file = "./lda.stan")
fit.vb = vb(
sm,
data = data,
output_samples = 2000,
adapt_engaged = FALSE,
eta = .1
)
etaは自動的に調整してくれる機能もありますが、時間がかかるのでここでは天下り的に適当な値を指定しています。
結果
結果は以下のようになりました。
topic 1 | university | ca | time | apr | problem | made | people | fact | law | israel |
topic 2 | image | years | time | make | things | people | set | god | file | color |
topic 3 | ftp | information | windows | system | file | version | files | jpeg | scsi | gif |
topic 4 | image | list | system | graphics | bit | format | line | run | images | jpeg |
topic 5 | time | apr | make | message | problem | car | back | point | god | drug |
topic 6 | man | day | time | apr | work | people | car | called | ve | drive |
topic 7 | image | pc | people | find | price | software | file | files | openwindows | |
topic 8 | good | ftp | state | space | apr | made | system | point | year | team |
topic 9 | good | ca | time | support | make | people | jews | religion | ve | god |
topic 10 | good | image | university | key | ca | work | people | system | bit | jpeg |
topic 11 | cs | good | state | time | make | ve | government | god | part | give |
topic 12 | good | time | bill | made | people | person | place | life | book | government |
topic 13 | good | ca | time | apr | make | work | people | thing | file | jpeg |
topic 14 | man | years | information | uk | problem | things | people | car | thing | god |
topic 15 | good | university | years | ed | time | support | make | work | ve | god |
topic 16 | good | makes | apr | make | people | back | ve | ll | point | god |
topic 17 | good | question | time | make | windows | work | people | god | run | file |
topic 18 | cs | good | question | time | make | work | read | system | ve | data |
topic 19 | good | information | time | apr | make | work | problem | ve | game | |
topic 20 | question | apr | david | make | windows | problem | put | people | back | god |
なんとなくトピックらしきものが抽出されてる気もしますが、重なるものが多かったりあんまり品質はよくないようです。 |
感想
ある程度予想はしていたのですが、やはり既存のパッケージに比べると推定時間・精度ともに悪いです。ただ、LDAには無数の派生系があり、そうしたモデルの中にはどうしても自分で実装するしかないものも多いので、その後の拡張のしやすさという意味ではstanを使う意義もあるのではないでしょうか。僕はとりあえずgaussian-LDAをstanでやってみようと思っています。
コードおよびデータはここにも保存しています。