はじめに
もう各所で解説されつくしたword2vecについて、自然言語処理の基本中の基本、ということで改めて学習してみようと思いました。gensimを使った実装例が多いと思いますが、ここではKeras+CNTKで実際に実装してみよう、という試みにしました。
NoSQL DBのMarkLogicの最新バージョン10がCNTKを搭載したたため、DB上のデータを直接使ってNLPを実現してみたい、という思いがあったりします。
本投稿のソースコードは以下に公開しております。
https://github.com/t2hk/word2vec_keras_cntk
環境
環境 | バージョン等 |
---|---|
Ubuntu | 16.04.6 LTS |
Keras | 2.2.4 |
Python | 3.6.8 |
CNTK | 2.7 |
CUDA | 10.1 |
GPU | GTX1070Ti |
参考にしたもの
主に以下の書籍やブログ、ソースコードを参考にさせて頂きました。
-
[ゼロから作るDeep Learning 2 自然言語処理編] (https://www.oreilly.co.jp/books/9784873118369/)
言わずと知れた名著です。機械学習のフレームワークに頼らず自分で実装するスタイルの解説なので深く理解できます。 -
[abaheti95氏 Deep-Learning word2vec] (https://github.com/abaheti95/Deep-Learning)
KerasとCNTKでword2vec(cbow)を実装している方がいらっしゃいました。今回は主にこちらを参考にさせていただき、CNTKバックエンドのKerasで実装してみました。 -
Adventure Machine Learning - A Word2Vec Keras tutorial
こちらもKerasでの実装を例としたチュートリアルです。skip-gramです。このブログは機械学習に関する解説がわかりやすいです。比較的読みやすい英語で書かれており、英語が苦手な方にもオススメです。
実装するword2vec(cbow)の概要
cbow(Continuous Bag of Words)を実装してみます。ベースは上記で紹介した[Deep-Learning word2vec] (https://github.com/abaheti95/Deep-Learning)を元にしています。
cbowの詳細については随所で語られているため、ここでは概要だけ記します。
該当のソースコードはcbow_train_onnx.pyになります。
-
モデルの概要
- 文章内の周辺の単語からその中心の単語を予測する
- 周辺の単語と中心の単語については関連性が高いため、単語ベクトルのWeightsを上げるように学習する
- ネガティブサンプリングした単語は関連性が低いため、単語ベクトルのWeightsを下げるように学習する
-
モデルの入力
- 中心となる単語
- 中心となる単語の前後の単語群(ウィンドウサイズ x 2)
- ネガティブサンプルの単語群(任意のサイズ)
-
モデルの出力
ネガティブサンプリングにより、正解と不正解の2値分類に置き換えています。- 周辺単語と中心単語は正解(=1)
- ネガティブサンプルは不正解(=0)
-
Kerasのモデル
実際に構築したモデルは以下のようになりました。
Layer (type) Output Shape Param # Connected to
==================================================================================================
word_index_input (InputLayer) (None, 1) 0
__________________________________________________________________________________________________
context_input (InputLayer) (None, 10) 0
__________________________________________________________________________________________________
embedding_1 (Embedding) multiple 103758300 word_index_input[0][0]
context_input[0][0]
negative_samples_input[0][0]
__________________________________________________________________________________________________
negative_samples_input (InputLa (None, 5) 0
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, 300) 0 embedding_1[1][0]
__________________________________________________________________________________________________
dot_1 (Dot) (None, 1) 0 embedding_1[0][0]
lambda_1[0][0]
__________________________________________________________________________________________________
dot_2 (Dot) (None, 5) 0 embedding_1[2][0]
lambda_1[0][0]
学習データについて
日本語版Wikipediaの全文データの「jawiki-latest-pages-articles.xml.bz2」を使用します。
日本語版Wikipedia全文データをテキスト化する方法は随所で語られているため詳細は省きます。今回はwp2txtでXMLからテキストファイルに変換しました。
学習用入力データの生成
Wikipediaのテキストファイルを元に、モデルの入力データをPythonで生成します。大量にデータが存在するため適度にファイルを分割しておき、マルチプロセスで並列処理しました。
該当のソースコードはgen_prepared_data.pyやmulti_vocab_generator.pyになります。
MeCabで分かち書きし、特定の品詞の単語のみを学習対象としています。
MeCabはPythonから呼び出し、辞書は「mecab-ipadic-neologd」を使用しています。
学習対象の単語は以下としました。下記のように単語数を絞らないと、GPUメモリが不足してしまいました。
- 名詞(数と非自立を除く)
- 動詞(自立のみ、かつ、2文字以上)
- 形容詞(2文字以上)
- Wikipedia全体で60回以上使用されている単語
以下のように学習データを作成しました。
項目 | 値 |
---|---|
ウィンドウサイズ | 5 |
ネガティブサンプル数 | 5 |
単語数 | 345,858 |
学習用入力データ数 | 535,567,601 |
入力する学習データは以下のようなフォーマットにしました。
連番, 単語, ウィンドウサイズ分の前後の単語, ネガティブサンプルの単語
[1],[7],[0 0 0 0 0 103859 4775 14454 7290 3342],[228868 299358 258025 97607 13132]
[2],[103859],[0 0 0 0 7 4775 14454 7290 3342 0],[101266 143424 22165 164608 303077]
[3],[4775],[0 0 0 7 103859 14454 7290 3342 0 0],[172734 135301 56157 35406 27560]
[4],[14454],[0 0 7 103859 4775 7290 3342 0 0 0],[59589 80319 309995 209001 238735]
[5],[7290],[0 7 103859 4775 14454 3342 0 0 0 0],[239954 103781 319647 155381 153859]
学習結果
学習は以下のパラメータで実行しました。
項目 | 値 |
---|---|
embeddedレイヤーの次元数 | 300 |
バッチサイズ | 10,000 |
1epochのステップ数 | 53,557 |
epoch数 | 5 |
実行結果は以下の通りです。
項目 | 値 |
---|---|
1epochの処理時間 | 約2時間40分 |
学習中のGPUメモリ使用量 | 7,649MiB (8,119MiB中) |
Epoch 1/5
53557/53557 - 12744s 238ms/step - loss: 1.3724 - dot_1_loss: 1.1638 - dot_2_loss: 0.1967
Epoch 2/5
53557/53557 - 9368s 175ms/step - loss: 0.9370 - dot_1_loss: 0.7406 - dot_2_loss: 0.1797
Epoch 3/5
53557/53557 - 9375s 175ms/step - loss: 0.9526 - dot_1_loss: 0.7411 - dot_2_loss: 0.1938
Epoch 4/5
53557/53557 - 9397s 175ms/step - loss: 0.9586 - dot_1_loss: 0.7330 - dot_2_loss: 0.2059
Epoch 5/5
53557/53557 - 9379s 175ms/step - loss: 0.9843 - dot_1_loss: 0.7456 - dot_2_loss: 0.2178
5エポック実行してみましたが、全然lossが収束していない・・・。対象のデータが多いからでしょうか。
データを1千万件に絞って、追加で5エポック実行してみました。これは収束傾向が見えました。
Epoch 6/10
1000/1000 - 154s 154ms/step - loss: 0.8937 - dot_1_loss: 0.6649 - dot_2_loss: 0.2119
Epoch 7/10
1000/1000 - 154s 154ms/step - loss: 0.7043 - dot_1_loss: 0.4860 - dot_2_loss: 0.2051
Epoch 8/10
1000/1000 - 154s 154ms/step - loss: 0.6044 - dot_1_loss: 0.3919 - dot_2_loss: 0.2009
Epoch 9/10
1000/1000 - 154s 154ms/step - loss: 0.5423 - dot_1_loss: 0.3355 - dot_2_loss: 0.1970
Epoch 10/10
1000/1000 - 153s 153ms/step - loss: 0.4986 - dot_1_loss: 0.2986 - dot_2_loss: 0.1922
遊んでみる
学習した単語ベクトルを利用してみます。該当のソースコードはcbow_eval.pyです。
コサイン類似度、類似単語トップ5、類推を実行してみます。
===== cosine similarity =====
猫 : ライオン = 0.369179904460907
猫 : 犬 = 0.6505978107452393
ライオン : 犬 = 0.45467498898506165
動物の種別の近さよりも、ペットとしての意味の方が近いと判定されている模様です。
そこで、類似単語のトップ5を見てみます。
===== most similar =====
[query] 猫
ネコ: 0.7435717582702637
犬: 0.6505978107452393
ウサギ: 0.6034150719642639
飼っ: 0.5997058749198914
飼い主: 0.5615102052688599
ペット関連の単語ばかりですね。
そして最後に「King - Man + Women = Queen」でおなじみの類推です。
良い感じで類推できています、と思いきや日本と東京の関係はアメリカならワシントンD.C.が正解ですね。関係性の高い単語にもあがってきていません。学習が足りないかなぁ。
===== analogy =====
[analogy] 日本:東京 = アメリカ:?
ニューヨーク: 6.172818183898926
米国: 5.168918609619141
カリフォルニア州: 5.112800598144531
ニューヨーク州: 4.972364902496338
新宿: 4.692124843597412
[analogy] 王:男 = 女王:?
女: 5.910060882568359
美女: 5.535874843597412
恋: 4.4083428382873535
花嫁: 4.368666648864746
恋人: 4.366179466247559
まとめ
word2vecのcbowをKerasで実装してみました。negative samplingを導入し2値分類問題としています。
GTX1070Tiの8GBのGPUメモリでは、300次元で35万ほどの単語を学習するのがギリギリでした。
データ量が多いためなのかlossの収束が見られませんでしたが、類推などは良い結果が出ています。
品詞によって学習する単語を意味がありそうなものに絞りましたが、テストデータも絞る工夫が必要そうです。
このモデルはONNX形式で出力しているので、次回はMarkLogic10のCNTKで使ってみようと思います。
DB内のコンテンツを直接、学習や推論に利用できれば面白いことが出来そうです。