0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

テキストの効果を推定する教師ありインド料理過程モデルをStanで実装してみた

Posted at

はじめに

こんにちは、事業会社で働いているデータサイエンティストです。

今回の記事は、前回の記事の続きとして、教師ありインド料理過程モデルでテキストデータの効果を推定する方法を紹介します。

モデルの構造はFong and Grimmer(2016)を参考にしていますが、少し構造を変えています。

では、早速説明に入ります!

モデル定式化

まず、このモデルでどのようなことを達成したいかを説明します。前回の記事で挙げた履歴書の分析の例を思い出していただきたいです。データには「弁護士歴ありフラグ」や「コンサル歴ありフラグ」はありませんが、弁護士歴のある方の履歴書には、おそらく「訴状」「裁判所」「依頼人」といった単語の出現頻度が高く、コンサル歴のある方の履歴書には「プロジェクト」「分析」「戦略」といった単語の出現頻度が高いと考えられます。したがって、履歴書というテキストデータの内容の中には、目に見えないフラグ情報が隠されていると推測できます。

このモデルは、隠れたフラグが履歴書のテキスト生成と目的変数に同時に影響を与えていると仮定し、各フラグ(トピック)に関連性の高い単語とフラグの効果を同時に推定することを目的としています。

まず、前回の記事で説明したように、棒折り過程で各フラグの出現頻度$\pi$をサンプリングします:

$$
\eta_{k} \sim Beta(\alpha, 1)
$$

$$
\pi_{k} = \prod_{z = 1}^{k}\eta_{k}
$$

各観測値におけるフラグ$z_{j,k}$は、$\pi_k$の確率で1になるベルヌーイ分布からサンプリングされます:

$$
z_{j,k} \sim Bernoulli(\pi_{k})
$$

最後に、例えば観測値$j$のフラグをすべて集めてベクトル$Z_j$にすると、

$$
単語出現回数_{j} \sim Poisson(exp(base + Z_j '* topic \text{_} intencity))
$$

ただし、$単語出現回数_{j}$は観測値の単語文書行列の行で、$base$はトピックに影響されない単語の出現傾向で、要するに切片のようなものです。最後の$topic\text{_}intensity$は各トピックの単語出現傾向を示します。「弁護士歴ありフラグ」が立っている人の「訴状」「裁判所」「依頼人」といった単語の出現傾向が高くなり、さらに「コンサル歴ありフラグ」も同時に立っていると、「プロジェクト」「分析」「戦略」の出現傾向も高くなります。

最後に、結果変数も0か1しか取らない変数だと

$$
y_{j} \sim Bernoulli(logit(Z_{j} '* \beta))
$$

でモデリングできます。結果変数があるため、教師ありと呼ばれています。

テキストの効果分析で、こんな面倒なことをしなくても、単語文書行列を独立変数にし、結果変数を説明する回帰モデルを組めば良いという考え方もあるかもしれません。しかし、そのようなモデルを構築する際に何が得られるかを想像してみてください。おそらく多くの単語の係数が得られますが、その係数表は本当にビジネスやアカデミアの課題解決に寄与するのでしょうか?

係数の羅列だけでは、どのような特性を持つ文書が結果変数(例えば転職の成功など)に寄与しているのかを判断・解釈するのは難しいです。「依頼人」や「プロジェクト」などの単語が転職に成功した人の履歴書に現れやすいという結果が得られたと仮定しましょう。しかし、残念ながら、これだけでは意思決定者にとって「で?」という疑問しか生まれないと思います。

テキストデータのような高次元データは、何らかの次元圧縮を行わない限り、意思決定に活用しづらいです。そこで、単語の「共起関係」と該当するトピック、さらにはそのトピックの効果も推定できるインド料理過程モデルは、非常に便利なモデルとなります。

また、フラグとフラグの効果を別々に推定すればよいのではという意見もあるでしょう。しかし、インド料理過程のようにフラグを仮定した場合や、LDAをはじめとする伝統的なトピックモデルのようにトピック分布を仮定した場合、さらにはより単純なクラスタリングを仮定した場合でも、結局データを分類・整理する方法は無限に存在します。その無限にある整理方法の中で、私たちは結果変数と関係がありそうなものを選びたいと考えています。そのためには、同じモデルで推定し、単語の出現傾向だけでなく、結果変数に影響を与える観点も加味する必要があります。これにより、より意思決定の目的に適した結果を得ることができます。

これをStanで表現すると次のようになります。単語文書行列はゼロが多いため、ダウンサンプリングを実施したデータを想定しています。また、計算速度の向上のために、計算のベクトル化(LLMのベクトル化ではありません(笑))および分散処理も行っています。細かいパラメータの事前分布についての説明は省略します。

supervised_indian_buffet_process.stan
functions {
  real partial_sum_lpmf(
    array[] int response,

    int begin, int end,

    int vocab_type, int latent_flag_type, int latent_flag_combination_type,

    array[,] int latent_flag_combination_int, array[] vector latent_flag_combination_vector,

    array[] int word, array[] int frequency,

    array[] int text_begin, array[] int text_end,

    vector latent_flag_probability,

    vector word_base_intensity, array[] vector word_flag_intensity,

    vector beta_flag, real intercept
  ){
    vector[end - begin + 1] log_likelihood;
    int count = 1;

    for (i in begin:end){
    vector[latent_flag_combination_type] case_when;
    for (p in 1:latent_flag_combination_type){
      vector[vocab_type] mu = rep_vector(0.0, vocab_type);
      for (q in 1:latent_flag_type){
        mu += latent_flag_combination_int[p, q] * word_flag_intensity[q];
      }
      case_when[p] = bernoulli_lpmf(
        latent_flag_combination_int[p] | latent_flag_probability
      ) +
      poisson_log_lpmf(
        frequency[text_begin[i]:text_end[i]] |
        word_base_intensity[word[text_begin[i]:text_end[i]]] +
        mu[word[text_begin[i]:text_end[i]]]
      ) +
      bernoulli_logit_lpmf(
        response[count] | intercept + latent_flag_combination_vector[p] '* beta_flag
      );
    }
    log_likelihood[count] = log_sum_exp(case_when);
    count += 1;
    }

    return sum(log_likelihood);
  }
}
data {
  int vocab_type;
  int latent_flag_type;

  int latent_flag_combination_type;
  array[latent_flag_combination_type, latent_flag_type] int latent_flag_combination_int;
  array[latent_flag_combination_type] vector[latent_flag_type] latent_flag_combination_vector;

  int N;
  array[N] int word;
  array[N] int frequency;

  int text_type;
  array[text_type] int text_begin;
  array[text_type] int text_end;
  array[text_type] int response;
}
parameters {
  real<lower=0> alpha;
  vector<lower=0,upper=1>[latent_flag_type] eta;

  real<lower=0> word_base_lambda;
  vector<lower=0>[vocab_type] word_base_sigma;
  vector[vocab_type] word_base_intensity;

  real<lower=0> word_flag_lambda;
  array[latent_flag_type] vector<lower=0>[vocab_type] word_flag_sigma;
  array[latent_flag_type] vector[vocab_type] word_flag_intensity;

  real<lower=0> lambda_flag;
  vector<lower=0>[latent_flag_type] sigma_flag;
  vector[latent_flag_type] beta_flag;

  real intercept;
}
transformed parameters {
  vector[latent_flag_type] latent_flag_probability;
  for (i in 1:latent_flag_type){
    latent_flag_probability[i] = prod(eta[1:i]);
  }
}
model {
  alpha ~ gamma(0.01, 0.01);
  eta ~ beta(alpha, 1);

  word_base_lambda ~ gamma(0.01, 0.01);
  word_base_sigma ~ exponential((word_base_lambda^2)/2.0);
  word_base_intensity ~ normal(0, word_base_sigma);

  word_flag_lambda ~ gamma(0.01, 0.01);
  for (i in 1:latent_flag_type){
    word_flag_sigma[i] ~ exponential((word_flag_lambda^2)/2.0);
    word_flag_intensity[i] ~ normal(0, word_flag_sigma[i]);
  }

  lambda_flag ~ gamma(0.01, 0.01);
  sigma_flag ~ exponential((lambda_flag^2)/2.0);
  beta_flag ~ normal(0, sigma_flag);

  intercept ~ normal(0, 1);

  int grainsize = 1;

  target += reduce_sum(
    partial_sum_lupmf, response, grainsize,

    vocab_type, latent_flag_type, latent_flag_combination_type,

    latent_flag_combination_int, latent_flag_combination_vector,

    word, frequency,

    text_begin, text_end,

    latent_flag_probability,

    word_base_intensity, word_flag_intensity,

    beta_flag, intercept
  );
}

データ説明とモデル推定

今回の記事では、このデータを利用します:

詳細はEgami, Fong, Grimmer, Roberts, and Stewart(2022)の論文を確認していただきたいですが、簡潔に説明すると、アメリカの消費者金融保護局(Consumer Financial Protection Bureau、略称:CFPB)がどういう訴えをより迅速に解決するかを分析するためのデータです。結果変数は迅速な解決につながったかどうかを示すフラグです。

ここでは、時間の節約のため、ランダムに10000けんのデータを選んで分析を実施します:

set.seed(12345)

text_df <- readr::read_csv("CFPBDt.csv") |>
  dplyr::slice_sample(n = 10000) |>
  dplyr::mutate(
    text_id = dplyr::row_number()
  )

vocab_master <- tibble::tibble(
  vocab = colnames(text_df)[-1]
) |>
  dplyr::mutate(
    vocab_id = dplyr::row_number()
  )

long_text_df <- text_df |>
  dplyr::select(!resp) |>
  tidyr::pivot_longer(!text_id, names_to = "vocab", values_to = "count") |>
  split(~ text_id) |>
  purrr::map(
    \(this_df){
      # ダウンサンプリングを実施
      # 出現回数がゼロではない単語と同じ数の出現回数ゼロの単語をサンプリング
      nonzero_df <- this_df |>
        dplyr::filter(count > 0)

      downsampled_zero_df <- this_df |>
        dplyr::filter(count == 0) |>
        dplyr::slice_sample(n = nrow(nonzero_df))

      nonzero_df |>
        dplyr::bind_rows(
          downsampled_zero_df
        )
    },
    .progress = TRUE
  ) |>
  dplyr::bind_rows() |>
  dplyr::left_join(
    vocab_master, by = "vocab"
  )


text_begin_end_response_df <- long_text_df |>
  dplyr::mutate(
    location_id = dplyr::row_number()
  ) |>
  dplyr::summarise(
    begin = min(location_id),
    end = max(location_id),
    .by = text_id
  ) |>
  dplyr::left_join(
    text_df |>
      dplyr::select(text_id, resp) |>
      dplyr::mutate(
        resp_flag = dplyr::case_when(
          resp == "Yes" ~ 1,
          TRUE ~ 0
        )
      ),
    by = "text_id"
  )

latent_flag_combination <- tibble::tibble(
  one = c(0, 1),
  two = c(0, 1),
  three = c(0, 1),
  four = c(0, 1),
  five = c(0, 1)
) |>
  tidyr::complete(one, two, three, four, five)

data_list <- list(
  vocab_type = nrow(vocab_master),
  latent_flag_type = 5,

  latent_flag_combination_type = nrow(latent_flag_combination),
  latent_flag_combination_int = as.matrix(latent_flag_combination),
  latent_flag_combination_vector = as.matrix(latent_flag_combination),

  N = nrow(long_text_df),
  word = long_text_df$vocab_id,
  frequency = long_text_df$count,

  text_type = nrow(text_begin_end_response_df),
  text_begin = text_begin_end_response_df$begin,
  text_end = text_begin_end_response_df$end,
  response = text_begin_end_response_df$resp_flag
)


では、ここでモデルをコンパイルし:

m_sibp_init <- cmdstanr::cmdstan_model("supervised_indian_buffet_process.stan",
                                       cpp_options = list(
                                         stan_threads = TRUE
                                       )
                                       )

推定を実施します:

> m_sibp_estimate <- m_sibp_init$variational(
     seed = 12345,
     threads = 30,
     iter = 50000,
     data = data_list
 )
------------------------------------------------------------ 
EXPERIMENTAL ALGORITHM: 
  This procedure has not been thoroughly tested and may be unstable 
  or buggy. The interface is subject to change. 
------------------------------------------------------------ 
Gradient evaluation took 9.22142 seconds 
1000 transitions using 10 leapfrog steps per transition would take 92214.2 seconds. 
Adjust your expectations accordingly! 
Begin eta adaptation. 
Iteration:   1 / 250 [  0%]  (Adaptation) 
Iteration:  50 / 250 [ 20%]  (Adaptation) 
Iteration: 100 / 250 [ 40%]  (Adaptation) 
Iteration: 150 / 250 [ 60%]  (Adaptation) 
Iteration: 200 / 250 [ 80%]  (Adaptation) 
Iteration: 250 / 250 [100%]  (Adaptation) 
Success! Found best value [eta = 0.1]. 
Begin stochastic gradient ascent. 
  iter             ELBO   delta_ELBO_mean   delta_ELBO_med   notes  
   100      -885196.154             1.000            1.000 
   200      -787746.339             0.562            1.000 
   300      -754997.244             0.389            0.124 
   400      -736538.208             0.298            0.124 
   500      -725127.098             0.242            0.043 
   600      -717668.924             0.203            0.043 
   700      -712313.856             0.175            0.025 
   800      -708294.099             0.154            0.025 
   900      -705297.218             0.137            0.016 
  1000      -702973.216             0.124            0.016 
  1100      -701217.800             0.113            0.010 
  1200      -700041.435             0.104            0.010 
  1300      -699061.090             0.096            0.008   MEDIAN ELBO CONVERGED 
Drawing a sample of size 1000 from the approximate posterior...  
COMPLETED. 
Finished in  14001.5 seconds.

かなり時間がかかりました。インド料理過程モデルは観測値ごとのトピックのような中間変数を保持しないので、推定が終わって事後分布を書き出す際に大変なことになることはないです。

では、結果を保存します:

m_sibp_summary <- m_sibp_estimate$summary()

推定結果

まずは、フラグの出現確率を確認します;

> m_sibp_summary |>
+     dplyr::filter(stringr::str_detect(variable, "^latent_flag_probability\\["))
# A tibble: 5 × 7
  variable                    mean median      sd     mad    q5   q95
  <chr>                      <dbl>  <dbl>   <dbl>   <dbl> <dbl> <dbl>
1 latent_flag_probability[1] 0.412  0.412 0.0115  0.0117  0.394 0.431
2 latent_flag_probability[2] 0.250  0.250 0.00919 0.00922 0.236 0.266
3 latent_flag_probability[3] 0.160  0.160 0.00834 0.00845 0.147 0.174
4 latent_flag_probability[4] 0.159  0.159 0.00831 0.00836 0.146 0.173
5 latent_flag_probability[5] 0.154  0.154 0.00827 0.00835 0.141 0.168

モデルに最大限考慮してほしいフラグの数がすべて出現しているように見えるため、実際にはもう少し考慮すべきフラグを追加した方が良いかもしれません。

次に、各フラグが結果変数に与える効果を示す係数の値を確認します:

> m_sibp_summary |>
     dplyr::filter(stringr::str_detect(variable, "^beta_flag\\["))
# A tibble: 5 × 7
  variable        mean  median    sd   mad     q5   q95
  <chr>          <dbl>   <dbl> <dbl> <dbl>  <dbl> <dbl>
1 beta_flag[1] -0.0329 -0.0274 0.178 0.171 -0.332 0.253
2 beta_flag[2]  1.21    1.20   0.242 0.246  0.806 1.63 
3 beta_flag[3] -0.0662 -0.0651 0.264 0.258 -0.520 0.357
4 beta_flag[4]  1.31    1.30   0.270 0.279  0.880 1.74 
5 beta_flag[5]  1.35    1.36   0.293 0.290  0.862 1.83 

フラグ(トピック)2、4、5が含まれている場合、消費者金融保護局の対応速度が速くなるという結果が得られています。一方、フラグ1と3は統計的に有意ではないものの、若干のマイナス効果が見受けられます。

さて、ここで各フラグの代表的な単語を出します:

> m_sibp_summary |>
     dplyr::filter(stringr::str_detect(variable, "^word_flag_intensity\\[")) |>
     dplyr::mutate(
         id = variable |>
             purrr::map(
                 \(x){
                     as.integer(stringr::str_split(x, "\\[|\\]|,")[[1]][2:3])
                 }
             )
     ) |>
     tidyr::unnest_wider(id, names_sep = "_") |>
     dplyr::left_join(
         vocab_master, by = c("id_2" = "vocab_id")
     ) |>
     split(~ id_1) |>
     purrr::map(
         \(df){
             df |>
                 dplyr::arrange(dplyr::desc(mean)) |>
                 dplyr::select(vocab)
         }
     ) |>
     dplyr::bind_cols() |>
     print(n = 20)
New names:
 `vocab` -> `vocab...1`
 `vocab` -> `vocab...2`
 `vocab` -> `vocab...3`
 `vocab` -> `vocab...4`
 `vocab` -> `vocab...5`
# A tibble: 749 × 5
   vocab...1     vocab...2          vocab...3         vocab...4         vocab...5       
   <chr>         <chr>              <chr>             <chr>             <chr>           
 1 car           delete             firm              chase             overdraft       
 2 xxxx_xxxx     experian           xxxx_xxxx         ocwen             deposited       
 3 xxxx_payment  inquiries          xxxx_xxxxxxxxxxxx trust             points          
 4 late_fee      identity_theft     trust             escrow            branch          
 5 loans         deleted            debt_collector    modification      deposit         
 6 late_fees     xxxx_xxxxxxxxxxxx  ocwen             wells_fargo       cards           
 7 school        item               fdcpa             principal         chase           
 8 xxxx_payments fair_credit        alleged           xxxx_xxxx         transactions    
 9 harassing     credit_file        xxxxxxxxxxxx      property          purchases       
10 xxxx_paid     fcra               equifax           payoff            pending         
11 late_payment  xxxx_accounts      address_xxxx      taxes             union           
12 vehicle       inaccurate         debt              loan_modification banking         
13 mother        reporting_agencies vehicle           dated             citibank        
14 phone_xxxx    inquiry            court             servicer          business_days   
15 medical       belong             validation        xxxxxxxxxxxx      american        
16 xxxx_phone    validate           phone_xxxx        mortgage_xxxx     requirements    
17 paid_xxxx     creditors          evidence          closing           card_xxxx       
18 billing       equifax            public            agreement         check           
19 payments_xxxx chapter_xxxx       llc               xxxx_mortgage     checking_account
20 month_xxxx    validation         bureau            program           checks          
# ℹ 729 more rows
# ℹ Use `print(n = ...)` to see more rows

結果をまとめます:

フラグ2、4、5(係数値:1.2–1.35)→ 迅速な対応との強い正の関連

  • フラグ2(クレジットレポートファイルとアイデンティティの盗難、係数値:1.21): クレジットレポートファイル(credit_file)やアイデンティティの盗難(identity_theft)に関する訴えは、迅速な対応を受ける可能性が高いという結果になっています。他にも信用情報機関(experianequifaxの名前も特徴的なワードとして現れています

  • フラグ4(モーゲージとローンサービシング、係数値:1.31): モーゲージサービス(mortgage_xxxxxxxx_mortgage)、エスクロー口座(escrow)、変更(modification)に関連する問題は、金融的利害関係や消費者保護法により、迅速な対応を促すようです

  • フラグ5(銀行取引とオーバードラフト、係数値:1.35): 預金(depositeddeposit)、オーバードラフト(overdraft)に関連する問題も高い反応を示し、消費者金融保護局がこれらの争議を迅速に処理するための構造化されたプロセスを持っているためかもしれません

フラグ1と3(係数値:-0.03から-0.07)→ 迅速な対応に対する有意な影響なしまたはわずかに負の効果

  • フラグ1(自動車ローンと遅延料金、係数値:-0.03): 自動車ローン(carloans)や遅延料金(late_fees)に関連する問題は、迅速な対応を受けることがないようで、これは規制が少ないか優先度が低いためかもしれません

  • フラグ3(債務回収と法的問題、係数値:-0.07): 債務回収(debt_collector)に関する訴えは、迅速な反応が最も低く、これは第三者の回収業者の関与や法的な複雑さ、強力な規制の執行が不足しているためと考えられます

結論

いかがでしたか?

本記事で紹介したモデルを活用することで、テキストデータなどの非構造化データから、結果変数に影響を与える要因を発見し、その効果を推定することができます。もちろん、アカデミアでもビジネスでも、非構造化データから注目すべき要素を発見できた場合、次のステップとして、その要因を推定結果を参考にして定義し、コストをかけてアノテーションを行うことをお勧めします。そして、アノテーションされたデータを用いて、より深い分析を実施することが可能になります。

コストとリードタイムの観点から、LLMを利用したアノテーションを検討している方には、こちらの論文をおすすめします:

Using Large Language Model Annotations for the Social Sciences: A General Framework of Using Predicted Variables in Downstream Analyses

最後に、私たちと一緒に、データサイエンスの力で社会を改善したい方はこちらをご確認ください:

参考文献

Egami, Naoki, Christian J. Fong, Justin Grimmer, Margaret E. Roberts, and Brandon M. Stewart. "How to make causal inferences using texts." Science Advances 8.42 (2022): eabg2652.

Fong, Christian, and Justin Grimmer. "Discovery of treatments from text corpora." Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2016.

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?