Help us understand the problem. What is going on with this article?

機械学習(文章から分類を判定)のサンプルソース(JAVA)

More than 3 years have passed since last update.

主旨

1.文章を投入すると、その文章が何を説明しているのか判断
2.あらかじめ学習データを投入する教師あり学習で、単純ベイズ分類器を利用
3.学習させている分類の割合と、文中の単語の頻出度合いを元に分類を決定している
(但し、学習データの事前確率は各分類で等価、つまり各分類につき1個の説明文書を投入)
4.サンプルソースなので、学習データの永続化などは行っていない
5.数学的に厳密に正しいという自信はない・・・です

つまり、、、

以下のソースで、分類を正しく返してくれるサンプルモジュール

main.java
    public static void main(String argv[]) {

        NaiveBayesClassifier classifier = new NaiveBayesClassifier();

        /**
         * 分類と説明をセットにして学習させる
         */
        classifier.learn("青龍",
                "青竜(せいりゅう、せいりょう、拼音: qīnglóng チンロン)は、" + 
                "中国の伝説上の神獣、四神(四象)の1つ。東方青竜。蒼竜(そうりゅう)ともいう。");

        classifier.learn("朱雀",
                "朱雀は南方を守護する神獣とされる。翼を広げた鳳凰様の鳥形で表される。" + 
                "朱は赤であり、五行説では南方の色とされる。" + 
                "鳳凰、不死鳥、フェニックス、インド神話に登場するガルーダ等と同一起源とする説や同一視される。");

        classifier.learn("白虎",
                "白虎(びゃっこ、拼音: báihŭ パイフー)は、中国の伝説上の神獣である四神の1つで、" + 
                "西方を守護する。白は、五行説では西方の色とされる。");

        classifier.learn("玄武",
                "玄武は、北方を守護する、水神。" +
                "「玄」は「黒」を意味し、黒は五行説では「北方」の色とされ、「水」を表す。");

        classifier.learn("黄龍",
                "黄竜、中国の伝承五行思想に現れる黄色の竜。黄金に輝く竜であると言う異説もある。" + 
                "四神の中心的存在、または、四神の長とも呼ばれている。" + 
                "四神が東西南北の守護獣なのに対し、中央を守るとされる。");

        /**
         * 適当な説明文で判定させてみる
         */
        Map<String, Double> scoreMap
            = classifier.judgeCategory("中国の五行思想において四神(四聖獣)の中央に位置するといわれている黄金の龍。");
        scoreMap.entrySet().stream()
            .sorted( (e1, e2) -> { 
                return e2.getValue().compareTo(e1.getValue()); 
            }).findFirst().ifPresent(e -> System.out.println(e.getKey()));

        scoreMap 
            = classifier.judgeCategory("陰陽五行説で玄・黒の色とされる北方を守護し、亀の姿であると言われる。");
        scoreMap.entrySet().stream()
            .sorted( (e1, e2) -> { 
                return e2.getValue().compareTo(e1.getValue()); 
            }).findFirst().ifPresent(e -> System.out.println(e.getKey()));

        scoreMap 
            = classifier.judgeCategory("鳳凰と同一視されることも多く、その場合様々な鳥獣の混ざり合った、美しい翼を持つ鳥とされる。");
        scoreMap.entrySet().stream()
            .sorted( (e1, e2) -> { 
                return e2.getValue().compareTo(e1.getValue()); 
            }).findFirst().ifPresent(e -> System.out.println(e.getKey()));

    }

実行結果
黄龍
玄武
朱雀

事前準備

簡単なのでPOMで。
文章から単語を抽出するのにKuromojiを使ってます。

pom.xml(依存性部分のみ)
    <dependencies>
        <dependency>
            <artifactId>lucene-core</artifactId>
            <groupId>org.apache.lucene</groupId>
            <version>5.1.0</version>
        </dependency>        
        <dependency>
            <artifactId>lucene-analyzers-kuromoji</artifactId>
            <groupId>org.apache.lucene</groupId>
            <version>5.1.0</version>
        </dependency>
    </dependencies>

実装サンプル

NaiveBayesClassifier.java
/**
 * 単純ベイズ分類器を利用
 * ① <学習> 事前に投入した分類の割合 と 分類ごとの説明文の単語の頻出度合いを保持
 * ② <分類> 上記のデータを元に分類ごとのスコアを算出
 * @author ryutaro_hakozaki
 */
public class NaiveBayesClassifier {

    /**
     * 利用(蓄積)するデータは以下の3つのみ
     * ① 説明文に登場した全ての単語の種類数 
     * ② 投入事前データの分類カウント(何の分類を何回学ばせたか)
     * ③ 分類ごとの、説明文に登場した単語の回数カウントマップ
     */

    // ①
    private final Set<String> wordSet = new HashSet<>();    
    // ②
    private final Map<String, Integer> categoryCountMap = new HashMap<>();
    // ③
    private final Map<String, Map<String, Integer>> categoryWordCountMap = new HashMap<>();

    /**
     * 学習はこれだけ。
     * ① 登場する単語の種類を追加
     * ② 指定された分類のカウントアップ(事前確率)
     * ③ 分類別の単語の出現数をカウントアップ
     */
    public void learn(String category, String description){

        try {
            // 説明文から単語を抽出
            toWords(description)
                .stream().sequential().forEach(w -> {

                    // ③ 設定済の分類の単語と出現回数マップを取得
                    Map<String, Integer> wordCountMap =  
                        null == categoryWordCountMap.get(category) 
                            ? new HashMap<>() : categoryWordCountMap.get(category);

                    // ③ 単語マップを更新
                    if(wordCountMap.containsKey(w)) {
                        wordCountMap.put(w, wordCountMap.get(w) + 1);
                    }else{
                        wordCountMap.put(w, 1);
                    }
                    categoryWordCountMap.put(category, wordCountMap);

                    // ① 登場した単語で新しいものがあれば追加
                    wordSet.add(w);

                });

            // ② 分類の学習回数を単純にカウントアップ
            if(categoryCountMap.containsKey(category)){
                categoryCountMap.put(category, categoryCountMap.get(category) + 1);
            } else {
                categoryCountMap.put(category, 1);
            }
        } catch (IOException ex) {
            Logger.getLogger(MachineLearningFirstStep.class.getName()).log(Level.SEVERE, null, ex);
        }

    }    

    /**
     * 説明文から分類を判定
     * ※ 数学的に厳密にあっている自信はありません
     * ① 事前確率 = 学習データの分類の割合とする
     * ② 単語の条件付確率 = 学習データの分類に出てきた各単語の割合の総和
     * 
     * ① * ② がスコアとなる
     * ⇒ アンダーフローを防ぐため、対数を取って乗算を和算にしている
     * ⇒ ゼロ頻度問題を防ぐためにラプラススムージング(単語の出現回数に+1)を行う
     */
    public Map<String, Double> judgeCategory(String description){

        Map<String, Double> categoryScoreMap = new HashMap<>();
        try {
            List<String> words = toWords(description);

            categoryCountMap.keySet()
                .stream().forEach((category) -> {

                    // ①を算出
                    double score
                        = Math.log((double)categoryCountMap.get(category) / categoryCountMap.values().stream().mapToDouble(d -> d).sum());                

                    // ②を算出
                    score = words
                        .stream()
                        .map((w) -> (categoryWordCountMap.get(category).containsKey(w)) ? categoryWordCountMap.get(category).get(w) : 0)
                        .map((wordCount) -> { 
                            wordCount++;
                            return wordCount;
                        }).map((wordCount) -> 
                            Math.log(
                                wordCount / (categoryWordCountMap.get(category).values().stream().mapToDouble(d -> d).sum() + wordSet.size())
                            )
                        ).reduce(score, (accumulator, _item) -> accumulator + _item);

                // 格納
                categoryScoreMap.put(category, score);

            });

        } catch (IOException ex) {
            Logger.getLogger(MachineLearningFirstStep.class.getName()).log(Level.SEVERE, null, ex);
        }
        return categoryScoreMap;

    }

    /**
     * ドキュメントから単語を抽出
     * 本筋ではないので細かい説明は省略
     */
    private List<String> toWords(String doc) throws IOException {
        List<String> ret = new ArrayList<>();
        AttributeFactory factory = AttributeFactory.DEFAULT_ATTRIBUTE_FACTORY;
        try(JapaneseTokenizer ts = new JapaneseTokenizer(factory, null, false, JapaneseTokenizer.DEFAULT_MODE)){
            ts.setReader(new StringReader(doc.toLowerCase()));
            PartOfSpeechAttribute ps = ts.addAttribute(PartOfSpeechAttribute.class);
            CharTermAttribute ct = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            while(ts.incrementToken()){
                if(ps.getPartOfSpeech().startsWith("名詞")) ret.add(ct.toString());
            }
        } catch (IOException e) {
        }
        return ret;
    }    

}
hakozaki
B2Bの商取引を電子化するプラットフォーマー(サービスプロバイダ)の会社で 自社の技術系ラボの責任者をしております。 新規事業やFintechプラットフォームの実現にむけて AI(ML)やブロックチェーンを用いた技術基盤の構築や協業での実証検証を主に行っています。 技術のこと・ビジネスのこと・これからの業界の動向のこと 色々情報交換させて頂ければ幸いです。
infomart
40万社以上の企業に商取引を電子データ化して効率化・利便性をもたらすBtoBプラットフォームを提供しています。
https://www.infomart.co.jp/index.asp
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away