R
NLP
LDA
統計学
Stan

StanでLDAをまわしてみた (自動微分変分推論)

はじめに

LDAはトピックモデルと呼ばれる自然言語処理の分野で広く使われている数理モデルです。モデルの詳細は後ほど説明しますが、ざっくり言うと書かれた単語からその文章のトピックを複数推定するモデルです。

LDAのパラメータ推定には一般的にはRのtopicmodelsパッケージやPythonのgensimモジュールが使われることが多いですが、今回はモデルの拡張や仕組みの理解を目的として、stanでモデルをゼロから構築して推定することを試みます。

モデルの定義

コードは松浦さんのブログから借りてきました。ただし、stanのバージョンアップに合わせて、一部コードを修正しています。

data {
  int<lower=1> K;
  int<lower=1> M;
  int<lower=1> V;
  int<lower=1> N;
  int<lower=1,upper=V> W[N];
  int<lower=1,upper=N> Offset[M,2]; // range of word index per doc

  vector<lower=0>[K] Alpha; // topic prior
  vector<lower=0>[V] Beta; // word prior
}
parameters {
  simplex[K] theta[M];
  simplex[V] phi[K];
}
model {
  // prior
  for (m in 1:M)
    theta[m] ~ dirichlet(Alpha);
  for (k in 1:K)
    phi[k] ~ dirichlet(Beta);

  // likelihood
  for (m in 1:M) {
    for (n in Offset[m,1]:Offset[m,2]) {
      real gamma[K];
      for (k in 1:K)
        gamma[k] = log(theta[m,k]) + log(phi[k,W[n]]);
      target += log_sum_exp(gamma);
    }
  }
}

データと下処理

今回は20_newsコーパスを使いました。これは20個のジャンルからなる20,000本のニュース記事の集合で、トピックモデルの評価にしばしば使われるコーパスです。

予め単語にインデックスを付与し、文章は1 10 8 9などの整数のリストに変換しています。したがって別ファイルvocab.txtにインデックスと単語の対応関係1 --> "news"などを記録しています。

library(tidyverse)
library(tidytext)

corpus = read_csv("./corpus_20news.txt", col_names = "words")
corpus = corpus %>% sample_n(1000)
corpus = corpus %>% mutate(doc = seq(1:nrow(corpus)))
corpus.tidy = corpus %>% unnest_tokens(word, words)
corpus.tidy = corpus.tidy %>% mutate(word = as.integer(word))

上のコードでコーパスをtidytext形式に変換しています。

doc word
1 178
1 878
1 866
1 3579
1 2045
1 522

ここから、頻出語の削除とそれに伴うインデックスの振り直しを行っていきます。

stop_words = corpus.tidy %>% count(word, sort=T) %>% filter(n > 1000 | n < 15)
corpus.tidy = corpus.tidy %>% anti_join(stop_words)

appeared_word = unique(corpus.tidy$word)

vocab = read_csv("./vocab_20news.txt", col_names = "word")
vocab = vocab %>% mutate(index = seq(nrow(vocab)))
vocab.new = vocab[appeared_word,]

vocab.new = vocab.new %>% mutate(new_ind = seq(nrow(vocab.new)))

corpus.lda = corpus.tidy %>% left_join(vocab.new, by = c("word" = "index"))

あとは、文書iに属する単語がtidy形式でインデックス何番目から何番目にあるのかを表す情報をoffset変数に格納していきます。

M = nrow(corpus)
N = nrow(corpus.lda)
V = nrow(vocab.new)
K = 20

offset = matrix(nrow = M, ncol = 2)

last_ind = 0
for (m in 1:M) {
  n_words = as.integer(corpus.lda %>% filter(doc == m) %>% count())
  offset[m,1] = last_ind + 1
  offset[m,2] = offset[m,1] + (n_words - 1)
  last_ind = offset[m,2]
}

パラメータ推定

モデルとデータがそろったのでパラメータ推定を行っていきましょう。アルゴリズムにはver.2.9からstanに搭載されたADVI(自動微分変分推論)を使いました。stanといえばNUTSというアルゴリズムが有名ですが、MCMCだとLDAのようなパラメータの多いアルゴリズムでは推定にとてつもない時間が必要になってしまいます。ADVIであれば、MCMCより精度は劣るものの、高速に推論を行うことができます。

library(rstan)

load("./20news.RData")

data = list(
  K = K,
  M = M,
  V = V,
  N = N,
  W = corpus.lda$new_ind,
  Offset = offset,
  Alpha = rep(1, K),
  Beta = rep(0.5, V)
)

sm = stan_model(file = "./lda.stan")

fit.vb = vb(
  sm,
  data = data,
  output_samples = 2000,
  adapt_engaged = FALSE,
  eta = .1
)

etaは自動的に調整してくれる機能もありますが、時間がかかるのでここでは天下り的に適当な値を指定しています。

結果

結果は以下のようになりました。

topic 1 university ca time apr problem made people fact law israel
topic 2 image years time make things people set god file color
topic 3 ftp information windows system file version files jpeg scsi gif
topic 4 image list system graphics bit format line run images jpeg
topic 5 time apr make message problem car back point god drug
topic 6 man day time apr work people car called ve drive
topic 7 image mail pc people find price software file files openwindows
topic 8 good ftp state space apr made system point year team
topic 9 good ca time support make people jews religion ve god
topic 10 good image university key ca work people system bit jpeg
topic 11 cs good state time make ve government god part give
topic 12 good time bill made people person place life book government
topic 13 good ca time apr make work people thing file jpeg
topic 14 man years information uk problem things people car thing god
topic 15 good university years ed time support make work ve god
topic 16 good makes apr make people back ve ll point god
topic 17 good question time make windows work people god run file
topic 18 cs good question time make work read system ve data
topic 19 good information mail time apr make work problem ve game
topic 20 question apr david make windows problem put people back god

なんとなくトピックらしきものが抽出されてる気もしますが、重なるものが多かったりあんまり品質はよくないようです。

感想

ある程度予想はしていたのですが、やはり既存のパッケージに比べると推定時間・精度ともに悪いです。ただ、LDAには無数の派生系があり、そうしたモデルの中にはどうしても自分で実装するしかないものも多いので、その後の拡張のしやすさという意味ではstanを使う意義もあるのではないでしょうか。僕はとりあえずgaussian-LDAをstanでやってみようと思っています。
コードおよびデータはここにも保存しています。