はじめに
こんにちは、事業会社で働いているデータサイエンティストです。
今回の記事は、前回の記事の続きとして、教師ありインド料理過程モデルでテキストデータの効果を推定する方法を紹介します。
モデルの構造はFong and Grimmer(2016)を参考にしていますが、少し構造を変えています。
では、早速説明に入ります!
モデル定式化
まず、このモデルでどのようなことを達成したいかを説明します。前回の記事で挙げた履歴書の分析の例を思い出していただきたいです。データには「弁護士歴ありフラグ」や「コンサル歴ありフラグ」はありませんが、弁護士歴のある方の履歴書には、おそらく「訴状」「裁判所」「依頼人」といった単語の出現頻度が高く、コンサル歴のある方の履歴書には「プロジェクト」「分析」「戦略」といった単語の出現頻度が高いと考えられます。したがって、履歴書というテキストデータの内容の中には、目に見えないフラグ情報が隠されていると推測できます。
このモデルは、隠れたフラグが履歴書のテキスト生成と目的変数に同時に影響を与えていると仮定し、各フラグ(トピック)に関連性の高い単語とフラグの効果を同時に推定することを目的としています。
まず、前回の記事で説明したように、棒折り過程で各フラグの出現頻度$\pi$をサンプリングします:
$$
\eta_{k} \sim Beta(\alpha, 1)
$$
$$
\pi_{k} = \prod_{z = 1}^{k}\eta_{k}
$$
各観測値におけるフラグ$z_{j,k}$は、$\pi_k$の確率で1になるベルヌーイ分布からサンプリングされます:
$$
z_{j,k} \sim Bernoulli(\pi_{k})
$$
最後に、例えば観測値$j$のフラグをすべて集めてベクトル$Z_j$にすると、
$$
単語出現回数_{j} \sim Poisson(exp(base + Z_j '* topic \text{_} intencity))
$$
ただし、$単語出現回数_{j}$は観測値の単語文書行列の行で、$base$はトピックに影響されない単語の出現傾向で、要するに切片のようなものです。最後の$topic\text{_}intensity$は各トピックの単語出現傾向を示します。「弁護士歴ありフラグ」が立っている人の「訴状」「裁判所」「依頼人」といった単語の出現傾向が高くなり、さらに「コンサル歴ありフラグ」も同時に立っていると、「プロジェクト」「分析」「戦略」の出現傾向も高くなります。
最後に、結果変数も0か1しか取らない変数だと
$$
y_{j} \sim Bernoulli(logit(Z_{j} '* \beta))
$$
でモデリングできます。結果変数があるため、教師ありと呼ばれています。
テキストの効果分析で、こんな面倒なことをしなくても、単語文書行列を独立変数にし、結果変数を説明する回帰モデルを組めば良いという考え方もあるかもしれません。しかし、そのようなモデルを構築する際に何が得られるかを想像してみてください。おそらく多くの単語の係数が得られますが、その係数表は本当にビジネスやアカデミアの課題解決に寄与するのでしょうか?
係数の羅列だけでは、どのような特性を持つ文書が結果変数(例えば転職の成功など)に寄与しているのかを判断・解釈するのは難しいです。「依頼人」や「プロジェクト」などの単語が転職に成功した人の履歴書に現れやすいという結果が得られたと仮定しましょう。しかし、残念ながら、これだけでは意思決定者にとって「で?」という疑問しか生まれないと思います。
テキストデータのような高次元データは、何らかの次元圧縮を行わない限り、意思決定に活用しづらいです。そこで、単語の「共起関係」と該当するトピック、さらにはそのトピックの効果も推定できるインド料理過程モデルは、非常に便利なモデルとなります。
また、フラグとフラグの効果を別々に推定すればよいのではという意見もあるでしょう。しかし、インド料理過程のようにフラグを仮定した場合や、LDAをはじめとする伝統的なトピックモデルのようにトピック分布を仮定した場合、さらにはより単純なクラスタリングを仮定した場合でも、結局データを分類・整理する方法は無限に存在します。その無限にある整理方法の中で、私たちは結果変数と関係がありそうなものを選びたいと考えています。そのためには、同じモデルで推定し、単語の出現傾向だけでなく、結果変数に影響を与える観点も加味する必要があります。これにより、より意思決定の目的に適した結果を得ることができます。
これをStanで表現すると次のようになります。単語文書行列はゼロが多いため、ダウンサンプリングを実施したデータを想定しています。また、計算速度の向上のために、計算のベクトル化(LLMのベクトル化ではありません(笑))および分散処理も行っています。細かいパラメータの事前分布についての説明は省略します。
functions {
real partial_sum_lpmf(
array[] int response,
int begin, int end,
int vocab_type, int latent_flag_type, int latent_flag_combination_type,
array[,] int latent_flag_combination_int, array[] vector latent_flag_combination_vector,
array[] int word, array[] int frequency,
array[] int text_begin, array[] int text_end,
vector latent_flag_probability,
vector word_base_intensity, array[] vector word_flag_intensity,
vector beta_flag, real intercept
){
vector[end - begin + 1] log_likelihood;
int count = 1;
for (i in begin:end){
vector[latent_flag_combination_type] case_when;
for (p in 1:latent_flag_combination_type){
vector[vocab_type] mu = rep_vector(0.0, vocab_type);
for (q in 1:latent_flag_type){
mu += latent_flag_combination_int[p, q] * word_flag_intensity[q];
}
case_when[p] = bernoulli_lpmf(
latent_flag_combination_int[p] | latent_flag_probability
) +
poisson_log_lpmf(
frequency[text_begin[i]:text_end[i]] |
word_base_intensity[word[text_begin[i]:text_end[i]]] +
mu[word[text_begin[i]:text_end[i]]]
) +
bernoulli_logit_lpmf(
response[count] | intercept + latent_flag_combination_vector[p] '* beta_flag
);
}
log_likelihood[count] = log_sum_exp(case_when);
count += 1;
}
return sum(log_likelihood);
}
}
data {
int vocab_type;
int latent_flag_type;
int latent_flag_combination_type;
array[latent_flag_combination_type, latent_flag_type] int latent_flag_combination_int;
array[latent_flag_combination_type] vector[latent_flag_type] latent_flag_combination_vector;
int N;
array[N] int word;
array[N] int frequency;
int text_type;
array[text_type] int text_begin;
array[text_type] int text_end;
array[text_type] int response;
}
parameters {
real<lower=0> alpha;
vector<lower=0,upper=1>[latent_flag_type] eta;
real<lower=0> word_base_lambda;
vector<lower=0>[vocab_type] word_base_sigma;
vector[vocab_type] word_base_intensity;
real<lower=0> word_flag_lambda;
array[latent_flag_type] vector<lower=0>[vocab_type] word_flag_sigma;
array[latent_flag_type] vector[vocab_type] word_flag_intensity;
real<lower=0> lambda_flag;
vector<lower=0>[latent_flag_type] sigma_flag;
vector[latent_flag_type] beta_flag;
real intercept;
}
transformed parameters {
vector[latent_flag_type] latent_flag_probability;
for (i in 1:latent_flag_type){
latent_flag_probability[i] = prod(eta[1:i]);
}
}
model {
alpha ~ gamma(0.01, 0.01);
eta ~ beta(alpha, 1);
word_base_lambda ~ gamma(0.01, 0.01);
word_base_sigma ~ exponential((word_base_lambda^2)/2.0);
word_base_intensity ~ normal(0, word_base_sigma);
word_flag_lambda ~ gamma(0.01, 0.01);
for (i in 1:latent_flag_type){
word_flag_sigma[i] ~ exponential((word_flag_lambda^2)/2.0);
word_flag_intensity[i] ~ normal(0, word_flag_sigma[i]);
}
lambda_flag ~ gamma(0.01, 0.01);
sigma_flag ~ exponential((lambda_flag^2)/2.0);
beta_flag ~ normal(0, sigma_flag);
intercept ~ normal(0, 1);
int grainsize = 1;
target += reduce_sum(
partial_sum_lupmf, response, grainsize,
vocab_type, latent_flag_type, latent_flag_combination_type,
latent_flag_combination_int, latent_flag_combination_vector,
word, frequency,
text_begin, text_end,
latent_flag_probability,
word_base_intensity, word_flag_intensity,
beta_flag, intercept
);
}
データ説明とモデル推定
今回の記事では、このデータを利用します:
詳細はEgami, Fong, Grimmer, Roberts, and Stewart(2022)の論文を確認していただきたいですが、簡潔に説明すると、アメリカの消費者金融保護局(Consumer Financial Protection Bureau、略称:CFPB)がどういう訴えをより迅速に解決するかを分析するためのデータです。結果変数は迅速な解決につながったかどうかを示すフラグです。
ここでは、時間の節約のため、ランダムに10000けんのデータを選んで分析を実施します:
set.seed(12345)
text_df <- readr::read_csv("CFPBDt.csv") |>
dplyr::slice_sample(n = 10000) |>
dplyr::mutate(
text_id = dplyr::row_number()
)
vocab_master <- tibble::tibble(
vocab = colnames(text_df)[-1]
) |>
dplyr::mutate(
vocab_id = dplyr::row_number()
)
long_text_df <- text_df |>
dplyr::select(!resp) |>
tidyr::pivot_longer(!text_id, names_to = "vocab", values_to = "count") |>
split(~ text_id) |>
purrr::map(
\(this_df){
# ダウンサンプリングを実施
# 出現回数がゼロではない単語と同じ数の出現回数ゼロの単語をサンプリング
nonzero_df <- this_df |>
dplyr::filter(count > 0)
downsampled_zero_df <- this_df |>
dplyr::filter(count == 0) |>
dplyr::slice_sample(n = nrow(nonzero_df))
nonzero_df |>
dplyr::bind_rows(
downsampled_zero_df
)
},
.progress = TRUE
) |>
dplyr::bind_rows() |>
dplyr::left_join(
vocab_master, by = "vocab"
)
text_begin_end_response_df <- long_text_df |>
dplyr::mutate(
location_id = dplyr::row_number()
) |>
dplyr::summarise(
begin = min(location_id),
end = max(location_id),
.by = text_id
) |>
dplyr::left_join(
text_df |>
dplyr::select(text_id, resp) |>
dplyr::mutate(
resp_flag = dplyr::case_when(
resp == "Yes" ~ 1,
TRUE ~ 0
)
),
by = "text_id"
)
latent_flag_combination <- tibble::tibble(
one = c(0, 1),
two = c(0, 1),
three = c(0, 1),
four = c(0, 1),
five = c(0, 1)
) |>
tidyr::complete(one, two, three, four, five)
data_list <- list(
vocab_type = nrow(vocab_master),
latent_flag_type = 5,
latent_flag_combination_type = nrow(latent_flag_combination),
latent_flag_combination_int = as.matrix(latent_flag_combination),
latent_flag_combination_vector = as.matrix(latent_flag_combination),
N = nrow(long_text_df),
word = long_text_df$vocab_id,
frequency = long_text_df$count,
text_type = nrow(text_begin_end_response_df),
text_begin = text_begin_end_response_df$begin,
text_end = text_begin_end_response_df$end,
response = text_begin_end_response_df$resp_flag
)
では、ここでモデルをコンパイルし:
m_sibp_init <- cmdstanr::cmdstan_model("supervised_indian_buffet_process.stan",
cpp_options = list(
stan_threads = TRUE
)
)
推定を実施します:
> m_sibp_estimate <- m_sibp_init$variational(
seed = 12345,
threads = 30,
iter = 50000,
data = data_list
)
------------------------------------------------------------
EXPERIMENTAL ALGORITHM:
This procedure has not been thoroughly tested and may be unstable
or buggy. The interface is subject to change.
------------------------------------------------------------
Gradient evaluation took 9.22142 seconds
1000 transitions using 10 leapfrog steps per transition would take 92214.2 seconds.
Adjust your expectations accordingly!
Begin eta adaptation.
Iteration: 1 / 250 [ 0%] (Adaptation)
Iteration: 50 / 250 [ 20%] (Adaptation)
Iteration: 100 / 250 [ 40%] (Adaptation)
Iteration: 150 / 250 [ 60%] (Adaptation)
Iteration: 200 / 250 [ 80%] (Adaptation)
Iteration: 250 / 250 [100%] (Adaptation)
Success! Found best value [eta = 0.1].
Begin stochastic gradient ascent.
iter ELBO delta_ELBO_mean delta_ELBO_med notes
100 -885196.154 1.000 1.000
200 -787746.339 0.562 1.000
300 -754997.244 0.389 0.124
400 -736538.208 0.298 0.124
500 -725127.098 0.242 0.043
600 -717668.924 0.203 0.043
700 -712313.856 0.175 0.025
800 -708294.099 0.154 0.025
900 -705297.218 0.137 0.016
1000 -702973.216 0.124 0.016
1100 -701217.800 0.113 0.010
1200 -700041.435 0.104 0.010
1300 -699061.090 0.096 0.008 MEDIAN ELBO CONVERGED
Drawing a sample of size 1000 from the approximate posterior...
COMPLETED.
Finished in 14001.5 seconds.
かなり時間がかかりました。インド料理過程モデルは観測値ごとのトピックのような中間変数を保持しないので、推定が終わって事後分布を書き出す際に大変なことになることはないです。
では、結果を保存します:
m_sibp_summary <- m_sibp_estimate$summary()
推定結果
まずは、フラグの出現確率を確認します;
> m_sibp_summary |>
+ dplyr::filter(stringr::str_detect(variable, "^latent_flag_probability\\["))
# A tibble: 5 × 7
variable mean median sd mad q5 q95
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 latent_flag_probability[1] 0.412 0.412 0.0115 0.0117 0.394 0.431
2 latent_flag_probability[2] 0.250 0.250 0.00919 0.00922 0.236 0.266
3 latent_flag_probability[3] 0.160 0.160 0.00834 0.00845 0.147 0.174
4 latent_flag_probability[4] 0.159 0.159 0.00831 0.00836 0.146 0.173
5 latent_flag_probability[5] 0.154 0.154 0.00827 0.00835 0.141 0.168
モデルに最大限考慮してほしいフラグの数がすべて出現しているように見えるため、実際にはもう少し考慮すべきフラグを追加した方が良いかもしれません。
次に、各フラグが結果変数に与える効果を示す係数の値を確認します:
> m_sibp_summary |>
dplyr::filter(stringr::str_detect(variable, "^beta_flag\\["))
# A tibble: 5 × 7
variable mean median sd mad q5 q95
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 beta_flag[1] -0.0329 -0.0274 0.178 0.171 -0.332 0.253
2 beta_flag[2] 1.21 1.20 0.242 0.246 0.806 1.63
3 beta_flag[3] -0.0662 -0.0651 0.264 0.258 -0.520 0.357
4 beta_flag[4] 1.31 1.30 0.270 0.279 0.880 1.74
5 beta_flag[5] 1.35 1.36 0.293 0.290 0.862 1.83
フラグ(トピック)2、4、5が含まれている場合、消費者金融保護局の対応速度が速くなるという結果が得られています。一方、フラグ1と3は統計的に有意ではないものの、若干のマイナス効果が見受けられます。
さて、ここで各フラグの代表的な単語を出します:
> m_sibp_summary |>
dplyr::filter(stringr::str_detect(variable, "^word_flag_intensity\\[")) |>
dplyr::mutate(
id = variable |>
purrr::map(
\(x){
as.integer(stringr::str_split(x, "\\[|\\]|,")[[1]][2:3])
}
)
) |>
tidyr::unnest_wider(id, names_sep = "_") |>
dplyr::left_join(
vocab_master, by = c("id_2" = "vocab_id")
) |>
split(~ id_1) |>
purrr::map(
\(df){
df |>
dplyr::arrange(dplyr::desc(mean)) |>
dplyr::select(vocab)
}
) |>
dplyr::bind_cols() |>
print(n = 20)
New names:
• `vocab` -> `vocab...1`
• `vocab` -> `vocab...2`
• `vocab` -> `vocab...3`
• `vocab` -> `vocab...4`
• `vocab` -> `vocab...5`
# A tibble: 749 × 5
vocab...1 vocab...2 vocab...3 vocab...4 vocab...5
<chr> <chr> <chr> <chr> <chr>
1 car delete firm chase overdraft
2 xxxx_xxxx experian xxxx_xxxx ocwen deposited
3 xxxx_payment inquiries xxxx_xxxxxxxxxxxx trust points
4 late_fee identity_theft trust escrow branch
5 loans deleted debt_collector modification deposit
6 late_fees xxxx_xxxxxxxxxxxx ocwen wells_fargo cards
7 school item fdcpa principal chase
8 xxxx_payments fair_credit alleged xxxx_xxxx transactions
9 harassing credit_file xxxxxxxxxxxx property purchases
10 xxxx_paid fcra equifax payoff pending
11 late_payment xxxx_accounts address_xxxx taxes union
12 vehicle inaccurate debt loan_modification banking
13 mother reporting_agencies vehicle dated citibank
14 phone_xxxx inquiry court servicer business_days
15 medical belong validation xxxxxxxxxxxx american
16 xxxx_phone validate phone_xxxx mortgage_xxxx requirements
17 paid_xxxx creditors evidence closing card_xxxx
18 billing equifax public agreement check
19 payments_xxxx chapter_xxxx llc xxxx_mortgage checking_account
20 month_xxxx validation bureau program checks
# ℹ 729 more rows
# ℹ Use `print(n = ...)` to see more rows
結果をまとめます:
フラグ2、4、5(係数値:1.2–1.35)→ 迅速な対応との強い正の関連
-
フラグ2(クレジットレポートファイルとアイデンティティの盗難、係数値:1.21): クレジットレポートファイル(
credit_file
)やアイデンティティの盗難(identity_theft
)に関する訴えは、迅速な対応を受ける可能性が高いという結果になっています。他にも信用情報機関(experian
、equifax
)の名前も特徴的なワードとして現れています -
フラグ4(モーゲージとローンサービシング、係数値:1.31): モーゲージサービス(
mortgage_xxxx
、xxxx_mortgage
)、エスクロー口座(escrow
)、変更(modification
)に関連する問題は、金融的利害関係や消費者保護法により、迅速な対応を促すようです -
フラグ5(銀行取引とオーバードラフト、係数値:1.35): 預金(
deposited
、deposit
)、オーバードラフト(overdraft
)に関連する問題も高い反応を示し、消費者金融保護局がこれらの争議を迅速に処理するための構造化されたプロセスを持っているためかもしれません
フラグ1と3(係数値:-0.03から-0.07)→ 迅速な対応に対する有意な影響なしまたはわずかに負の効果
-
フラグ1(自動車ローンと遅延料金、係数値:-0.03): 自動車ローン(
car
、loans
)や遅延料金(late_fees
)に関連する問題は、迅速な対応を受けることがないようで、これは規制が少ないか優先度が低いためかもしれません -
フラグ3(債務回収と法的問題、係数値:-0.07): 債務回収(
debt_collector
)に関する訴えは、迅速な反応が最も低く、これは第三者の回収業者の関与や法的な複雑さ、強力な規制の執行が不足しているためと考えられます
結論
いかがでしたか?
本記事で紹介したモデルを活用することで、テキストデータなどの非構造化データから、結果変数に影響を与える要因を発見し、その効果を推定することができます。もちろん、アカデミアでもビジネスでも、非構造化データから注目すべき要素を発見できた場合、次のステップとして、その要因を推定結果を参考にして定義し、コストをかけてアノテーションを行うことをお勧めします。そして、アノテーションされたデータを用いて、より深い分析を実施することが可能になります。
コストとリードタイムの観点から、LLMを利用したアノテーションを検討している方には、こちらの論文をおすすめします:
最後に、私たちと一緒に、データサイエンスの力で社会を改善したい方はこちらをご確認ください:
参考文献
Egami, Naoki, Christian J. Fong, Justin Grimmer, Margaret E. Roberts, and Brandon M. Stewart. "How to make causal inferences using texts." Science Advances 8.42 (2022): eabg2652.
Fong, Christian, and Justin Grimmer. "Discovery of treatments from text corpora." Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2016.