JavaのライブラリであるDeeplearning4jで、GloVeを使ってみようと思います。
前提
- Deeplearning4jやGloVeが何かについては特に説明しません。下記のサイト等が参考になるかもです。
- 今回はJava言語・Eclipseで実装しました。EclipseでDeeplearning4Jを使うときの準備は下記のサイトが参考になるかもです。
- Java DeepLearning4j+Eclipse 環境構築 - (注意)2016年12月現在、手順が若干変わっているようです。当該記事のコメント欄をご覧ください。
コーパス
あらかじめ学習させたいコーパスは用意しておいてください。日本語コーパスの場合は分かち書きを行っておきます。分かち書きの際には、動詞などを基本形(原形)に直しておいたほうがいいかも。
学習
コーパスのテキストファイルはinput.txtとしておきます。できたモデルはmodel.txtで保存します。
ModelBuild.java
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import java.io.*;
public class ModelBuild {
public static void main( String[] args ) throws Exception{
//コーパスファイルの読み込み
System.out.println("データを読み込んでいます...");
File inputFile = new File("input.txt");
//文章データクラスとして読み込む
SentenceIterator iter = new BasicLineIterator(inputFile);
//トークナイザー(単語分割)クラスを作成する
System.out.println("トークナイザーを作成します...");
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
//モデルの作成
System.out.println("モデルを作成しています...");
Glove glove = new Glove.Builder()
.iterate(iter) //文章データクラス
.tokenizerFactory(t) //単語分解クラス
.alpha(0.75) //重み付け関数の指数におけるパラメータ
.learningRate(0.1) //初期学習率
.epochs(25) //トレーニング中の訓練コーパス上の反復回数
.layerSize(300) //ベクトルの次元数
.maxMemory(2) //最大メモリ使用量
.xMax(100) //重み関数のカットオフ
.batchSize(1000) //1回のミニバッチで学習する単語数
.windowSize(10) //ウィンドウサイズ
.shuffle(true)
.symmetric(true)
.build();
//学習
System.out.println("学習しています...");
glove.fit();
//モデルの保存
System.out.println("モデルを保存しています...");
WordVectorSerializer.writeWordVectors(glove, "model.txt");
System.out.println("プログラム終了です");
}
}
評価
Evaluation.java
import java.io.File;
import java.io.FileNotFoundException;
import java.io.UnsupportedEncodingException;
import java.util.Collection;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
public class Evaluation {
public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException {
//モデルファイルの読み込み
System.out.println("モデルファイルを読み込んでいます...");
File inputFile = new File(args[0]);
WordVectors vec = WordVectorSerializer.loadTxtVectors(inputFile);
//当該単語の類似単語上位10件を表示(例として「天気」)
System.out.println("類似している単語上位10件...");
String word = "天気";
int ranking = 10;
Collection<String> similarTop10 = vec.wordsNearest( word , ranking );
System.out.println( String.format( "Similar word to 「%s」 is %s" , word , similarTop10 ) );
// コサイン類似度を表示(例として「晴れ」と「雨」)
System.out.println( "コサイン類似度を表示..." );
String word1 = "晴れ";
String word2 = "雨";
double similarity = vec.similarity( word1 , word2 );
System.out.println( String.format( "The similarity between 「%s」 and 「%s」 is %f" , word1 , word2 , similarity ) );
}
}