LoginSignup
2
3

More than 5 years have passed since last update.

Deeplearning4jでGloVeを使ってみる

Posted at

JavaのライブラリであるDeeplearning4jで、GloVeを使ってみようと思います。

前提

コーパス

あらかじめ学習させたいコーパスは用意しておいてください。日本語コーパスの場合は分かち書きを行っておきます。分かち書きの際には、動詞などを基本形(原形)に直しておいたほうがいいかも。

学習

コーパスのテキストファイルはinput.txtとしておきます。できたモデルはmodel.txtで保存します。

ModelBuild.java
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

import java.io.*;

public class ModelBuild {
    public static void main( String[] args ) throws Exception{

        //コーパスファイルの読み込み
        System.out.println("データを読み込んでいます...");
        File inputFile = new File("input.txt");

        //文章データクラスとして読み込む
        SentenceIterator iter = new BasicLineIterator(inputFile);

        //トークナイザー(単語分割)クラスを作成する
        System.out.println("トークナイザーを作成します...");
        TokenizerFactory t = new DefaultTokenizerFactory();
        t.setTokenPreProcessor(new CommonPreprocessor());

        //モデルの作成
        System.out.println("モデルを作成しています...");
        Glove glove = new Glove.Builder()
                .iterate(iter) //文章データクラス
                .tokenizerFactory(t) //単語分解クラス
                .alpha(0.75) //重み付け関数の指数におけるパラメータ
                .learningRate(0.1) //初期学習率
                .epochs(25) //トレーニング中の訓練コーパス上の反復回数
                .layerSize(300) //ベクトルの次元数
                .maxMemory(2) //最大メモリ使用量
                .xMax(100) //重み関数のカットオフ
                .batchSize(1000) //1回のミニバッチで学習する単語数
                .windowSize(10) //ウィンドウサイズ
                .shuffle(true)
                .symmetric(true)
                .build();

        //学習
        System.out.println("学習しています...");
        glove.fit();

        //モデルの保存
        System.out.println("モデルを保存しています...");
        WordVectorSerializer.writeWordVectors(glove, "model.txt");

        System.out.println("プログラム終了です");
    }
}

評価

Evaluation.java
import java.io.File;
import java.io.FileNotFoundException;
import java.io.UnsupportedEncodingException;
import java.util.Collection;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;

public class Evaluation {

    public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException {
        //モデルファイルの読み込み
        System.out.println("モデルファイルを読み込んでいます...");
        File inputFile = new File(args[0]);
        WordVectors vec = WordVectorSerializer.loadTxtVectors(inputFile);

        //当該単語の類似単語上位10件を表示(例として「天気」)
        System.out.println("類似している単語上位10件...");
        String  word        = "天気";
        int     ranking     = 10;
        Collection<String>  similarTop10    = vec.wordsNearest( word , ranking );
        System.out.println( String.format( "Similar word to 「%s」 is %s" , word , similarTop10 ) );

        // コサイン類似度を表示(例として「晴れ」と「雨」)
        System.out.println( "コサイン類似度を表示..." );
        String  word1       = "晴れ";
        String  word2       = "雨";
        double  similarity  = vec.similarity( word1 , word2 );
        System.out.println( String.format( "The similarity between 「%s」 and 「%s」 is %f" , word1 , word2 , similarity ) );
    }
}

コードを参考にさせていただいたページ

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3