StarSpaceとは?
Facebook Researchが出しているOSSの自然言語処理ツールです。様々なタスクに対して用いることのできる分散表現を効率よく学習できるツールです。以下が、公式のgithubにStarSpaceで扱えるタスクとして挙げられていたものの例になります。
- 単語や文、ドキュメントレベルの分散表現の学習
- 情報抽出:エンティティ、ドキュメント、オブジェクトのランキング
- テキスト分類、そのほかのラベリングタスク
- メトリック、類似性の学習:文やドキュメントの類似性の学習
- コンテンツベース、協調フィルタリングベースのレコメンデーション
などなど。。。
それ以外にも様々なタスクを行えるツールとなっています。また、共通的なベクトル埋め込み空間で異なるタイプのオブジェクトを表現できるようしているのが特徴であるようです。名前の由来としては、Starが*(wildcard)、Spaceが共通の空間を表していて、互いの分散表現を比較できることからきているとのことです。
Facebookが他に出している自然言語処理ツールのfastTextに続いて公開したツールでもあり、気になっていたので試してみました。
インストール
自分の動作環境
- macOS High Sierra 10.13.4
インストール要件
- (gcc-4.6.3 or newer) or (clang-3.3 or newer)
インストール手順
- ソースコードをcloneしてくる
$ git clone https://github.com/facebookresearch/Starspace.git
- gcc5を以下のサイトを参考にインストール(今回はなんとなくバージョン5で)
- boostをインストール
$ brew install boost
- makeコマンドを実行したところ以下のエラーに遭遇した
$ make
g++ -pthread -std=gnu++11 -O3 -funroll-loops -g -c src/utils/normalize.cpp
In file included from /opt/local/include/gcc5/c++/bits/postypes.h:40:0,
from /opt/local/include/gcc5/c++/bits/char_traits.h:40,
from /opt/local/include/gcc5/c++/string:40,
from src/utils/normalize.h:12,
from src/utils/normalize.cpp:10:
/opt/local/include/gcc5/c++/cwchar:44:19: fatal error: wchar.h: No such file or directory
compilation terminated.
make: *** [normalize.o] Error 1
- 該当する対処方法があったので、以下のサイトを参考にXcodeのツールをインストールするコマンドを実行
$ xcode-select --install
- 再度makeを実行したところ別のエラーに遭遇
g++ -pthread -std=gnu++11 -O3 -funroll-loops -g -c src/utils/normalize.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -g -c src/dict.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -g -c src/utils/args.cpp
g++ -pthread -std=gnu++11 -O3 -funroll-loops -I/usr/local/bin/boost_1_63_0/ -g -c src/proj.cpp
In file included from src/proj.h:14:0,
from src/proj.cpp:11:
src/matrix.h:26:42: fatal error: boost/numeric/ublas/matrix.hpp: No such file or directory
compilation terminated.
make: *** [proj.o] Error 1
- makefileのboostのパスを設定していなかったことによるエラーでした。自分の場合、
$ brew install boost
で以下のディレクトリにboostがインストールされていました。- BOOST_DIR = /usr/local/Cellar/boost/1.67.0_1/include/
- makefileのBOOST_DIRを変更
BOOST_DIR = /usr/local/bin/boost_1_63_0/
↓
BOOST_DIR = /usr/local/Cellar/boost/1.67.0_1/include/
- 再度、makeを実行。無事コンパイル完了。
- helpが表示されるか確認
$ ./starspace --help
TagSpace word / tag embeddings(マルチクラスのテキスト分類)
TagSpace word / tag embeddings(論文)という例をまず動かしてみます。TagSpace word / tag embeddingsは、あるテキストに該当する複数のタグ(クラス)を予測するための分散表現を学習することができます。テキストとタグの単語の両方を使って分散表現を学習し、テキストからタグを予測するモデルを構築できます。
論文より、提案された手法は、55億の単語を学習して、100,000のハッシュタグを予測可能ということです。すごい。。。
The proposed approach is trained on up to 5.5 billion words predicting 100,000 possible hashtags.
公式のgithubのinput file format例は以下になります。
restaurant has great food #yum #restaurant
イメージとしては、ツイッターのテキストからテキストが属する複数のハッシュタグを予測するタスクでしょうか。マルチクラス分類に適用できるので、活用範囲は広そうです。
- git cloneしてきたディレクトリにexampleのscriptがあったので、以下のコマンドで実行してみます。
$ bash examples/classification_ag_news.sh
example scriptを実行すると、AG's News Topic Classification Datasetというデータセットをダウンロードしてきて学習を行ってくれます。
AG's News Topic Classification Datasetは、4つのトピックをクラスとしてもつデータセットで1つのトピックあたりtraining用に30000、test用に1900のサンプルがあります。4つのクラスを合わせると、training用サンプルが120,000、test用サンプルが7,600です。
4つのクラス情報は、ダウンロードしてきたclasses.txtに入っていて、以下になります。
World
Sports
Business
Sci/Tech
また、生のデータであるtrain.csvの1つのサンプルデータは以下になっており、""で囲まれた1列目がクラスのラベル情報、2列目がtitle、3列目がdescriptionになります。
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
example scriptで前処理されたデータのサンプルを1つ見てみます。ラベルは__label__x
の形式で読み込んでStarSpaceで計算を行っています。
__label__1 , indian violence kills 6 , injures 51 , six persons , including two border security force ( bsf ) men , were killed and 51 others injured in three blasts triggered by united liberation front of asom ( ulfa ) militants
example scriptを実行した後の表示される結果は以下になります。学習する分散表現の次元数が10で小さいので、1分もかからずに終わりました。
$ bash examples/classification_ag_news.sh
Downloading dataset ag_news
Compiling StarSpace
make: Nothing to be done for `opt'.
Start to train on ag_news data:
Arguments:
lr: 0.01
dim: 10
epoch: 5
maxTrainTime: 8640000
saveEveryEpoch: 0
loss: hinge
margin: 0.05
similarity: dot
maxNegSamples: 3
negSearchLimit: 5
thread: 20
minCount: 1
minCountLabel: 1
label: __label__
ngrams: 1
bucket: 2000000
adagrad: 0
trainMode: 0
fileFormat: fastText
normalizeText: 0
dropoutLHS: 0
dropoutRHS: 0
Start to initialize starspace model.
Build dict from input file : /tmp/starspace/data/ag_news.train
Read 5M words
Number of words in dictionary: 95811
Number of labels in dictionary: 4
Loading data from file : /tmp/starspace/data/ag_news.train
Total number of examples loaded : 120000
Initialized model weights. Model size :
matrix : 95815 10
Training epoch 0: 0.01 0.002
Epoch: 100.0% lr: 0.008100 loss: 0.008115 eta: <1min tot: 0h0m1s (20.0%)
---+++ Epoch 0 Train error : 0.00644227 +++--- ☃
Training epoch 1: 0.008 0.002
Epoch: 100.0% lr: 0.006017 loss: 0.003892 eta: <1min tot: 0h0m3s (40.0%)
---+++ Epoch 1 Train error : 0.00397125 +++--- ☃
Training epoch 2: 0.006 0.002
Epoch: 100.0% lr: 0.004017 loss: 0.003494 eta: <1min tot: 0h0m4s (60.0%)
---+++ Epoch 2 Train error : 0.00338104 +++--- ☃
Training epoch 3: 0.004 0.002
Epoch: 100.0% lr: 0.002000 loss: 0.003006 eta: <1min tot: 0h0m6s (80.0%)
---+++ Epoch 3 Train error : 0.00295286 +++--- ☃
Training epoch 4: 0.002 0.002
Epoch: 100.0% lr: -0.000000 loss: 0.002587 eta: <1min tot: 0h0m7s (100.0%)
---+++ Epoch 4 Train error : 0.00262243 +++--- ☃
Saving model to file : /tmp/starspace/models/ag_news
Saving model in tsv format : /tmp/starspace/models/ag_news.tsv
Start to evaluate trained model:
Arguments:
lr: 0.01
dim: 10
epoch: 5
maxTrainTime: 8640000
saveEveryEpoch: 0
loss: hinge
margin: 0.05
similarity: dot
maxNegSamples: 10
negSearchLimit: 50
thread: 10
minCount: 1
minCountLabel: 1
label: __label__
ngrams: 1
bucket: 2000000
adagrad: 1
trainMode: 0
fileFormat: fastText
normalizeText: 0
dropoutLHS: 0
dropoutRHS: 0
Start to load a trained starspace model.
STARSPACE-2017-2
Initialized model weights. Model size :
matrix : 95815 10
Model loaded.
Loading data from file : /tmp/starspace/data/ag_news.test
Total number of examples loaded : 7600
------Loaded model args:
Arguments:
lr: 0.01
dim: 10
epoch: 5
maxTrainTime: 8640000
saveEveryEpoch: 0
loss: hinge
margin: 0.05
similarity: dot
maxNegSamples: 3
negSearchLimit: 5
thread: 10
minCount: 1
minCountLabel: 1
label: __label__
ngrams: 1
bucket: 2000000
adagrad: 1
trainMode: 0
fileFormat: fastText
normalizeText: 0
dropoutLHS: 0
dropoutRHS: 0
Predictions use 4 known labels.
Evaluation Metrics :
hit@1: 0.9175 hit@10: 1 hit@20: 1 hit@50: 1 mean ranks : 1.1025 Total examples : 7600
テスト用のサンプルに対し、hit@1で0.9175なので良さげな精度で分類できていますね。
- 以下のコマンドを実行して、学習で作成されたモデルを再度読み込んでprediction結果を出力させてみます。
./starspace test -testFile /tmp/starspace/data/ag_news.test -model /tmp/starspace/models/ag_news -predictionFile predictions.txt
以下がprediction結果の1つの例になります。この例はちゃんとSportsのクラスとして分類されています。
Example 0:
LHS:
, four players shoot opening-round 66 , , ala . -- grace park , looking to clinch second place in the player of the year race , birdied the final hole thursday to gain a share of the lead after the first round of the tournament of champions .
RHS:
__label__2
Predictions:
(++) [0.209871] __label__2
(--) [-0.0409123] __label__3
(--) [-0.0630488] __label__4
(--) [-0.0941172] __label__1
:
:
ちなみに分散表現は、/tmp/starspace/models/ag_news.tsv
を確認すると以下のような値になっているのがわかります。
you 0.0517028 0.0632872 -0.0886393 -0.0326751 -0.0343589 0.062882 0.0229753 -0.00139947 0.0396142 0.035067
set 0.0062284 0.034237 0.0105888 -0.0300584 0.00787413 -0.00723825 0.0320939 0.0129017 0.0224701 0.00541018
european -0.0211528 -0.0308988 0.0846774 0.0104599 0.0217948 -0.062741 -0.0414285 0.02002 -0.0412714 -0.00655955
before 0.0662441 0.0324028 -0.0278512 0.00877903 -0.0279142 0.0364988 -0.0500219 0.011719 0.0185126 0.0202915
chief -0.10016 -0.0920889 -0.00988563 0.0314743 0.0207762 -0.00694755 0.0799257 -0.059415 -0.0262858 -0.0397386
lead 0.0166031 0.0406471 0.0176889 -0.0079458 0.00753064 -0.0399314 0.0280651 0.00753941 0.0194812 0.0288038
technology -0.111373 -0.20321 -0.230317 0.14528 -0.0319123 0.264915 -0.0247894 -0.189683 -0.056927 -0.0913001
com 0.1435 0.0565506 -0.169092 0.0104515 -0.0409452 0.182425 -0.108167 0.00969619 0.0484618 0.0146816
talks -0.132501 -0.0956215 0.0586796 0.0332124 0.0425119 -0.0445733 0.0536436 -0.0447683 -0.0432121 -0.0541774
cup 0.214038 0.21899 0.22578 -0.146971 -0.00575729 -0.211818 -0.0364427 0.229259 0.0749109 0.114307
league 0.105266 0.11454 0.176639 -0.0696982 0.0259426 -0.166549 -0.0516321 0.143722 0.0402037 0.0724694
american -0.0468925 0.00557206 0.0219382 -0.0345875 0.0411395 -0.0749748 0.105936 -0.00246951 0.0176306 -0.0282046
just 0.0318377 0.0455117 -0.0381685 -0.0189038 -0.0041167 0.0333977 0.0183844 -0.00186247 0.0222026 0.00371659
search -0.0629342 -0.0745887 -0.0725334 0.0367046 -0.014612 0.0953445 -0.00631663 -0.0807699 -0.0200482 -0.0429865
computer -0.0913863 -0.126859 -0.287999 0.0950781 -0.0119644 0.282886 -0.0011539 -0.195667 -0.0161312 -0.0648938
space 0.0118128 -0.221955 -0.160012 0.181186 -0.0430855 0.272821 -0.245684 -0.147063 -0.10963 -0.031327
online 0.0278405 -0.0733343 -0.224693 0.081582 -0.0471127 0.261379 -0.0733302 -0.130254 0.0252417 -0.011354
what 0.0422958 0.0202962 -0.0794863 -0.0105392 -0.0304567 0.105753 -0.00637522 0.00359807 0.0309032 -0.00036712
country -0.0991791 -0.0863653 0.0596176 0.0548506 0.0256906 -0.0383166 -0.0200991 -0.0274331 -0.0642891 -0.0304374
:
:
まとめ
今回は、StarSpaceでテキスト分類だけですが試してみました。以下のようなことを日本語データでやってみたいなぁと思ったり。。。
- ツイートを自動分類(画像認識、自然言語処理、音声認識とかみたいなカテゴリで)
- 商品名、商品説明から商品のカテゴリ分類
また、論文を読んでアルゴリズムを理解するのと、StarSpaceで実行できる他のタスクもやりたいなぁ(願望)