LoginSignup
3
2

word2vecは円周率の法則を発見できるか?

Posted at

はじめに

こんにちは、事業会社で働いているデータサイエンティストです。

この前ウイスキーを飲んでいた時、そういえば、円周率の数字はランダムであるとよく言われますが、基本的にエビデンスとして使われるのは、0から9までの数字の出現頻度がほぼ一緒という集計結果なんですが、全体の出現頻度が一緒というのは、別にパターンがないことを意味しないのでは?と思いました。Tverで秘密のケンミンショー極を見ながら。

数学の分野で円周率の数字が完全にランダムであると証明されているかは知らないですが、どんなに厳密な数学を使った理論(政治学の分野ではこのようなデータを使わない数理モデルのことを形式理論、formal theoryと呼びます)でも、結局実証分析の検証を受けないといけません。これは政治学、経済学、法学、社会学、物理学、化学、医学など全ての学問の形式理論に共通します。もちろん数学も例外ではありません。

そこで、単純集計ではなく、

3.14159265358979323846264338327950288419716939937510582097494459230781640

を文字として扱ったら、word2vecなどの自然言語処理の技術が、円周率の数字の出現パターンを突き止めてくれるのではないかと思って、早速実験に移りました。

結論

結論から先にいいます。

完全にダメでした、、、

ただ、ダメでも結果を残したいです。

モデルがうまくいかない理由は様々です。今までの経験からすると、適切なモデルを選定できなかったのはもちろんありますが、同じくらいの頻度で発生する問題として、コードのミスの他に、学習データと検証データを誤って混ぜたり、wifi接続に問題があったり、OSが古かったりなど、色々あります。

正直、小さなミスで素敵な分析のアイデアを葬り去るのは、ビジネス的にもアカデミック的にも勿体無いです。失敗を共有すれば、誰かが改善策を提案し、そこからよりいいものが生まれてきます。

私は今上場企業のデータサイエンスチームのリーダーを務めています。失敗してもそのままコードを全部削除するのではなく、きちんとその経験と学びを安心して残せる、それが私が創りたい組織の文化です。なので、失敗した結果も、しっかり記事にします。

モデル

はじめにのところで説明したように、今回の記事では、円周率の数字を単語と見なし、ベイズword2vecで分析を行います。

ざっくりいうと、円周率の中のとある区間の

59X30

のXは0から9の数字のどれなのかを判断させます。

もしうまく判断できれば、word2vecが円周率の法則を学習できたといえます。

定式化

前の記事

とほぼ同様ですが、文脈の位置によって異なるウエイトを設定しています。これは自然言語処理の文脈で説明すると、予測対象の単語に近い単語はより多くの情報を提供し、遠い単語はあまり予測に役立たないという仮説に基づいています。

$$ 重要度_{標準化前}\sim Normal(0,1) $$

$$ 重要度 = softmax(重要度_{標準化前}) $$

  • 全ての単語wについて

$$ sigma_{w} \sim Gamma(0.1, 0.1) $$

$$ word\space embedding_{w} \sim Normal(0,sigma_{w}) $$

$$ word\space context_{w} \sim Normal(0,1) $$

  • 全ての本物フラグ判定iについて

$$ \eta = word\space embedding_{word_{i}} '* \Sigma_{c \in 文脈_{i}} 重要度_{目標単語との相対位置_{c}} * word\space context_{c} $$

$$ 本物フラグ \sim Bernoulli(logit(\eta)) $$

Stanでの実装

Stanでの実装はこちらです:

bayesian_word2vec.stan
functions {
  real partial_sum_lpmf(
    array[] int flag,
    int start, int end,
    
    array[] int word,
    array[] int word_lead_1,
    array[] int word_lag_1,
    array[] int word_lead_2,
    array[] int word_lag_2,
    array[] int word_lead_3,
    array[] int word_lag_3,
    array[] int word_lead_4,
    array[] int word_lag_4,
    array[] int word_lead_5,
    array[] int word_lag_5,
    array[] int word_lead_6,
    array[] int word_lag_6,
    array[] int word_lead_7,
    array[] int word_lag_7,
    array[] int word_lead_8,
    array[] int word_lag_8,
    array[] int word_lead_9,
    array[] int word_lag_9,
    array[] int word_lead_10,
    array[] int word_lag_10,
    
    vector importance,
    array[] vector word_embedding,
    array[] vector word_context
  ){
    vector[end - start + 1] lambda;
    int count = 1;
    for (i in start:end){
      lambda[count] = word_embedding[word[i]] '* (word_context[word_lead_1[i]] * importance[1] + word_context[word_lag_1[i]] * importance[2] + 
                                                  word_context[word_lead_2[i]] * importance[3] + word_context[word_lag_2[i]] * importance[4] + 
                                                  word_context[word_lead_3[i]] * importance[5] + word_context[word_lag_3[i]] * importance[6] +
                                                  word_context[word_lead_4[i]] * importance[7] + word_context[word_lag_4[i]] * importance[8] + 
                                                  word_context[word_lead_5[i]] * importance[9] + word_context[word_lag_5[i]] * importance[10] + 
                                                  word_context[word_lead_6[i]] * importance[11] + word_context[word_lag_6[i]] * importance[12] +
                                                  word_context[word_lead_7[i]] * importance[13] + word_context[word_lag_7[i]] * importance[14] + 
                                                  word_context[word_lead_8[i]] * importance[15] + word_context[word_lag_8[i]] * importance[16] + 
                                                  word_context[word_lead_9[i]] * importance[17] + word_context[word_lag_9[i]] * importance[18] +
                                                  word_context[word_lead_10[i]] * importance[19] + word_context[word_lag_10[i]] * importance[20]
      );
      count += 1;
    }
    return bernoulli_logit_lupmf(flag | lambda);
  }
}
data {
  int<lower=1> N;
  int<lower=1> M;
  int<lower=1> word_type;
  
  array[N] int<lower=1,upper=word_type> word;
  array[N] int<lower=1,upper=word_type> word_lead_1;
  array[N] int<lower=1,upper=word_type> word_lag_1;
  array[N] int<lower=1,upper=word_type> word_lead_2;
  array[N] int<lower=1,upper=word_type> word_lag_2;
  array[N] int<lower=1,upper=word_type> word_lead_3;
  array[N] int<lower=1,upper=word_type> word_lag_3;
  array[N] int<lower=1,upper=word_type> word_lead_4;
  array[N] int<lower=1,upper=word_type> word_lag_4;
  array[N] int<lower=1,upper=word_type> word_lead_5;
  array[N] int<lower=1,upper=word_type> word_lag_5;
  array[N] int<lower=1,upper=word_type> word_lead_6;
  array[N] int<lower=1,upper=word_type> word_lag_6;
  array[N] int<lower=1,upper=word_type> word_lead_7;
  array[N] int<lower=1,upper=word_type> word_lag_7;
  array[N] int<lower=1,upper=word_type> word_lead_8;
  array[N] int<lower=1,upper=word_type> word_lag_8;
  array[N] int<lower=1,upper=word_type> word_lead_9;
  array[N] int<lower=1,upper=word_type> word_lag_9;
  array[N] int<lower=1,upper=word_type> word_lead_10;
  array[N] int<lower=1,upper=word_type> word_lag_10;
  array[N] int<lower=0,upper=1> flag;
  
  int<lower=0> val_N;
  array[val_N] int<lower=1,upper=word_type> val_word;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_1;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_1;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_2;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_2;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_3;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_3;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_4;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_4;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_5;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_5;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_6;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_6;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_7;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_7;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_8;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_8;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_9;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_9;
  array[val_N] int<lower=1,upper=word_type> val_word_lead_10;
  array[val_N] int<lower=1,upper=word_type> val_word_lag_10;
  array[val_N] int<lower=0,upper=1> val_flag;
}
parameters{
  vector[20] importance_unnormalized;
  vector<lower=0>[word_type] word_sigma;
  array[word_type] vector[M] word_embedding;
  array[word_type] vector[M] word_context;
}
transformed parameters {
  vector[20] importance;
  importance = softmax(importance_unnormalized);
}
model{
  importance_unnormalized ~ normal(0, 1);
  
  word_sigma ~ gamma(0.1, 0.1);
  
  for (w in 1:word_type){
    word_embedding[w] ~ normal(0, word_sigma[w]);
    word_context[w] ~ normal(0, 1);
  }
  
  int grainsize = 1;
  
  target += reduce_sum(
    partial_sum_lupmf, flag, grainsize,
    
    word,
    word_lead_1,
    word_lag_1,
    word_lead_2,
    word_lag_2,
    word_lead_3,
    word_lag_3,
    word_lead_4,
    word_lag_4,
    word_lead_5,
    word_lag_5,
    word_lead_6,
    word_lag_6,
    word_lead_7,
    word_lag_7,
    word_lead_8,
    word_lag_8,
    word_lead_9,
    word_lag_9,
    word_lead_10,
    word_lag_10,
    
    importance,
    word_embedding,
    word_context
  );
}
generated quantities {
  real metric_accuracy;
  real metric_recall;
  real metric_precision;
  real metric_f1_score;
  {
    array[val_N] int predicted;
    int TP = 0;
    int TN = 0;
    int FP = 0;
    int FN = 0;
    
    for (i in 1:val_N){
      predicted[i] = bernoulli_logit_rng(word_embedding[val_word[i]] '* (word_context[val_word_lead_1[i]] * importance[1] + word_context[val_word_lag_1[i]] * importance[2] + 
                                                                         word_context[val_word_lead_2[i]] * importance[3] + word_context[val_word_lag_2[i]] * importance[4] + 
                                                                         word_context[val_word_lead_3[i]] * importance[5] + word_context[val_word_lag_3[i]] * importance[6] +
                                                                         word_context[val_word_lead_4[i]] * importance[7] + word_context[val_word_lag_4[i]] * importance[8] + 
                                                                         word_context[val_word_lead_5[i]] * importance[9] + word_context[val_word_lag_5[i]] * importance[10] + 
                                                                         word_context[val_word_lead_6[i]] * importance[11] + word_context[val_word_lag_6[i]] * importance[12] +
                                                                         word_context[val_word_lead_7[i]] * importance[13] + word_context[val_word_lag_7[i]] * importance[14] + 
                                                                         word_context[val_word_lead_8[i]] * importance[15] + word_context[val_word_lag_8[i]] * importance[16] + 
                                                                         word_context[val_word_lead_9[i]] * importance[17] + word_context[val_word_lag_9[i]] * importance[18] +
                                                                         word_context[val_word_lead_10[i]] * importance[19] + word_context[val_word_lag_10[i]] * importance[20]
                                                                         )
                                        );
      if (val_flag[i] == 1 && predicted[i] == 1){
        TP += 1;
      }
      else if (val_flag[i] == 0 && predicted[i] == 0){
        TN += 1;
      }
      else if (val_flag[i] == 0 && predicted[i] == 1){
        FP += 1;
      }
      else if (val_flag[i] == 1 && predicted[i] == 0){
        FN += 1;
      }
    }
    metric_accuracy = ((TP + TN) * 1.0)/(TP + TN + FP + FN);
    metric_recall = (TP * 1.0)/(TP + FN);
    metric_precision = (TP * 1.0)/(TP + FP);
    metric_f1_score = (2.0 * metric_recall * metric_precision)/(metric_recall + metric_precision);
  }
}

前処理

まずは、10万桁の円周率をtxtファイルを読み込み、小数点を消して、全ての数字を単語であるかのように分割した後、文脈である前後の10個の数字を抽出してカラムに格納します:

pi_df <- readr::read_lines("pi.txt") |> 
  stringr::str_remove_all("\\.") |> 
  stringr::str_remove_all(" ") |> 
  stringr::str_split("") |> 
  unlist() |> 
  tibble::tibble(digit = _) |>
  dplyr::mutate(
    # Stanは1始まりの言語なので、実際の円周率の0を1にするため、全てに1を足す
    digit = as.integer(digit) + 1,
    digit_lead_1 = dplyr::lead(digit, 1),
    digit_lag_1 = dplyr::lag(digit, 1),
    digit_lead_2 = dplyr::lead(digit, 2),
    digit_lag_2 = dplyr::lag(digit, 2),
    digit_lead_3 = dplyr::lead(digit, 3),
    digit_lag_3 = dplyr::lag(digit, 3),
    digit_lead_4 = dplyr::lead(digit, 4),
    digit_lag_4 = dplyr::lag(digit, 4),
    digit_lead_5 = dplyr::lead(digit, 5),
    digit_lag_5 = dplyr::lag(digit, 5),
    digit_lead_6 = dplyr::lead(digit, 6),
    digit_lag_6 = dplyr::lag(digit, 6),
    digit_lead_7 = dplyr::lead(digit, 7),
    digit_lag_7 = dplyr::lag(digit, 7),
    digit_lead_8 = dplyr::lead(digit, 8),
    digit_lag_8 = dplyr::lag(digit, 8),
    digit_lead_9 = dplyr::lead(digit, 9),
    digit_lag_9 = dplyr::lag(digit, 9),
    digit_lead_10 = dplyr::lead(digit, 10),
    digit_lag_10 = dplyr::lag(digit, 10)
  ) |>
  tidyr::drop_na()

データフレイムはこのような形になっています:

> pi_df
# A tibble: 99,980 × 21
   digit digit_lead_1 digit_lag_1 digit_lead_2 digit_lag_2 digit_lead_3 digit_lag_3 digit_lead_4 digit_lag_4 digit_lead_5
   <dbl>        <dbl>       <dbl>        <dbl>       <dbl>        <dbl>       <dbl>        <dbl>       <dbl>        <dbl>
 1     6            9           4           10           6            8           7           10           3            4
 2     9           10           6            8           4           10           6            4           7            3
 3    10            8           9           10           6            4           4            3           6            4
 4     8           10          10            4           9            3           6            4           4            9
 5    10            4           8            3          10            4           9            9           6            5
 6     4            3          10            4           8            9          10            5           9            7
 7     3            4           4            9          10            5           8            7          10            3
 8     4            9           3            5           4            7          10            3           8            7
 9     9            5           4            7           3            3           4            7          10            5
10     5            7           9            3           4            7           3            5           4            4
# ℹ 99,970 more rows
# ℹ 11 more variables: digit_lag_5 <dbl>, digit_lead_6 <dbl>, digit_lag_6 <dbl>, digit_lead_7 <dbl>, digit_lag_7 <dbl>,
#   digit_lead_8 <dbl>, digit_lag_8 <dbl>, digit_lead_9 <dbl>, digit_lag_9 <dbl>, digit_lead_10 <dbl>,
#   digit_lag_10 <dbl>
# ℹ Use `print(n = ...)` to see more rows

次に、学習データとテストデータにわけますが、モデルが単に数字を丸暗記するのを防ぐため、最後の10000桁を全てテストデータにします。

test_id <- (nrow(pi_df) - 10000 + 1):nrow(pi_df)

pi_df_train <- pi_df[-test_id,] |>
  dplyr::mutate(
    flag = 1
  ) |>
  dplyr::bind_rows(
    pi_df[-test_id,] |>
      dplyr::mutate(
       # デタラメな数字を入れて判別させる
        digit = sample(pi_df$digit, dplyr::n()),
        flag = 0
      )
  )

pi_df_test <- pi_df[test_id,] |>
  dplyr::mutate(
    flag = 1
  ) |>
  dplyr::bind_rows(
    pi_df[test_id,] |>
      dplyr::mutate(
        digit = sample(pi_df$digit, dplyr::n()),
        flag = 0
      )
  )

data_list <- list(
  N = nrow(pi_df_train),
  M = 10,
  word_type = 10,
  
  word = pi_df_train$digit,
  word_lag_1 = pi_df_train$digit_lag_1,
  word_lead_1 = pi_df_train$digit_lead_1,
  word_lag_2 = pi_df_train$digit_lag_2,
  word_lead_2 = pi_df_train$digit_lead_2,
  word_lag_3 = pi_df_train$digit_lag_3,
  word_lead_3 = pi_df_train$digit_lead_3,
  word_lag_4 = pi_df_train$digit_lag_4,
  word_lead_4 = pi_df_train$digit_lead_4,
  word_lag_5 = pi_df_train$digit_lag_5,
  word_lead_5 = pi_df_train$digit_lead_5,
  word_lag_6 = pi_df_train$digit_lag_6,
  word_lead_6 = pi_df_train$digit_lead_6,
  word_lag_7 = pi_df_train$digit_lag_7,
  word_lead_7 = pi_df_train$digit_lead_7,
  word_lag_8 = pi_df_train$digit_lag_8,
  word_lead_8 = pi_df_train$digit_lead_8,
  word_lag_9 = pi_df_train$digit_lag_9,
  word_lead_9 = pi_df_train$digit_lead_9,
  word_lag_10 = pi_df_train$digit_lag_10,
  word_lead_10 = pi_df_train$digit_lead_10,
  flag = pi_df_train$flag ,
  
  val_N = nrow(pi_df_test),
  val_word = pi_df_test$digit,
  val_word_lag_1 = pi_df_test$digit_lag_1,
  val_word_lead_1 = pi_df_test$digit_lead_1,
  val_word_lag_2 = pi_df_test$digit_lag_3,
  val_word_lead_2 = pi_df_test$digit_lead_3,
  val_word_lag_3 = pi_df_test$digit_lag_3,
  val_word_lead_3 = pi_df_test$digit_lead_3,
  val_word_lag_4 = pi_df_test$digit_lag_4,
  val_word_lead_4 = pi_df_test$digit_lead_4,
  val_word_lag_5 = pi_df_test$digit_lag_5,
  val_word_lead_5 = pi_df_test$digit_lead_5,
  val_word_lag_6 = pi_df_test$digit_lag_6,
  val_word_lead_6 = pi_df_test$digit_lead_6,
  val_word_lag_7 = pi_df_test$digit_lag_7,
  val_word_lead_7 = pi_df_test$digit_lead_7,
  val_word_lag_8 = pi_df_test$digit_lag_8,
  val_word_lead_8 = pi_df_test$digit_lead_8,
  val_word_lag_9 = pi_df_test$digit_lag_9,
  val_word_lead_9 = pi_df_test$digit_lead_9,
  val_word_lag_10 = pi_df_test$digit_lag_10,
  val_word_lead_10 = pi_df_test$digit_lead_10,
  val_flag = pi_df_test$flag
)

これで前処理が終わりました!

詳細はこちらの記事を参照してください:

結果

早速学習を行いましょう:

> m_bw2v_estimate <- m_bw2v_init$variational(
     seed = 123,
     threads = 10,
     iter = 50000,
     data = data_list
 )
------------------------------------------------------------ 
EXPERIMENTAL ALGORITHM: 
  This procedure has not been thoroughly tested and may be unstable 
  or buggy. The interface is subject to change. 
------------------------------------------------------------ 
Gradient evaluation took 0.29859 seconds 
1000 transitions using 10 leapfrog steps per transition would take 2985.9 seconds. 
Adjust your expectations accordingly! 
Begin eta adaptation. 
Iteration:   1 / 250 [  0%]  (Adaptation) 
Iteration:  50 / 250 [ 20%]  (Adaptation) 
Iteration: 100 / 250 [ 40%]  (Adaptation) 
Iteration: 150 / 250 [ 60%]  (Adaptation) 
Iteration: 200 / 250 [ 80%]  (Adaptation) 
Success! Found best value [eta = 1] earlier than expected. 
Begin stochastic gradient ascent. 
  iter             ELBO   delta_ELBO_mean   delta_ELBO_med   notes  
   100      -126595.847             1.000            1.000 
   200      -125031.502             0.506            1.000 
   300      -124954.147             0.338            0.013 
   400      -124896.407             0.253            0.013 
   500      -124896.439             0.203            0.001   MEDIAN ELBO CONVERGED 
Drawing a sample of size 1000 from the approximate posterior...  
COMPLETED. 
Finished in  214.6 seconds.

精度を見てみると、、、

> m_bw2v_summary |>
     dplyr::filter(stringr::str_detect(variable, "metric"))
# A tibble: 4 × 7
  variable          mean median      sd     mad    q5   q95
  <chr>            <dbl>  <dbl>   <dbl>   <dbl> <dbl> <dbl>
1 metric_accuracy  0.500  0.500 0.00346 0.00348 0.494 0.506
2 metric_recall    0.500  0.500 0.00580 0.00563 0.490 0.509
3 metric_precision 0.500  0.500 0.00347 0.00350 0.495 0.506
4 metric_f1_score  0.500  0.500 0.00417 0.00406 0.493 0.507

おーい、ランダムやないかい!コイン投げでの予測と変わんないやん!どれも0.5と統計的に有意な差がないです(信頼区間q5 ~ q95に0.5が入っている)、、、

ちなみに、これはStanでword2vecを実装できないわけではないです。実際もっと複雑な状態空間動的word2vecを実装できました。

なので、実装にミスがなければ、word2vecは円周率のパターンを突き止められないという結論になります。

考察

今回の失敗した分析結果から以下のことがいえると思います。

  • そもそも円周率は本当にランダムで予測できない
  • 筆者のword2vecの実装が良くないだけ
  • word2vecが単純すぎて、実はトランスフォーマなどもっと複雑なモデルを使えばうまくいく

興味ある方はぜひチャレンジしてみてください!

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2