主旨
1.文章を投入すると、その文章が何を説明しているのか判断
2.あらかじめ学習データを投入する教師あり学習で、単純ベイズ分類器を利用
3.学習させている分類の割合と、文中の単語の頻出度合いを元に分類を決定している
(但し、学習データの事前確率は各分類で等価、つまり各分類につき1個の説明文書を投入)
4.サンプルソースなので、学習データの永続化などは行っていない
5.数学的に厳密に正しいという自信はない・・・です
つまり、、、
以下のソースで、分類を正しく返してくれるサンプルモジュール
main.java
public static void main(String argv[]) {
NaiveBayesClassifier classifier = new NaiveBayesClassifier();
/**
* 分類と説明をセットにして学習させる
*/
classifier.learn("青龍",
"青竜(せいりゅう、せいりょう、拼音: qīnglóng チンロン)は、" +
"中国の伝説上の神獣、四神(四象)の1つ。東方青竜。蒼竜(そうりゅう)ともいう。");
classifier.learn("朱雀",
"朱雀は南方を守護する神獣とされる。翼を広げた鳳凰様の鳥形で表される。" +
"朱は赤であり、五行説では南方の色とされる。" +
"鳳凰、不死鳥、フェニックス、インド神話に登場するガルーダ等と同一起源とする説や同一視される。");
classifier.learn("白虎",
"白虎(びゃっこ、拼音: báihŭ パイフー)は、中国の伝説上の神獣である四神の1つで、" +
"西方を守護する。白は、五行説では西方の色とされる。");
classifier.learn("玄武",
"玄武は、北方を守護する、水神。" +
"「玄」は「黒」を意味し、黒は五行説では「北方」の色とされ、「水」を表す。");
classifier.learn("黄龍",
"黄竜、中国の伝承五行思想に現れる黄色の竜。黄金に輝く竜であると言う異説もある。" +
"四神の中心的存在、または、四神の長とも呼ばれている。" +
"四神が東西南北の守護獣なのに対し、中央を守るとされる。");
/**
* 適当な説明文で判定させてみる
*/
Map<String, Double> scoreMap
= classifier.judgeCategory("中国の五行思想において四神(四聖獣)の中央に位置するといわれている黄金の龍。");
scoreMap.entrySet().stream()
.sorted( (e1, e2) -> {
return e2.getValue().compareTo(e1.getValue());
}).findFirst().ifPresent(e -> System.out.println(e.getKey()));
scoreMap
= classifier.judgeCategory("陰陽五行説で玄・黒の色とされる北方を守護し、亀の姿であると言われる。");
scoreMap.entrySet().stream()
.sorted( (e1, e2) -> {
return e2.getValue().compareTo(e1.getValue());
}).findFirst().ifPresent(e -> System.out.println(e.getKey()));
scoreMap
= classifier.judgeCategory("鳳凰と同一視されることも多く、その場合様々な鳥獣の混ざり合った、美しい翼を持つ鳥とされる。");
scoreMap.entrySet().stream()
.sorted( (e1, e2) -> {
return e2.getValue().compareTo(e1.getValue());
}).findFirst().ifPresent(e -> System.out.println(e.getKey()));
}
実行結果
黄龍
玄武
朱雀
事前準備
簡単なのでPOMで。
文章から単語を抽出するのにKuromojiを使ってます。
pom.xml(依存性部分のみ)
<dependencies>
<dependency>
<artifactId>lucene-core</artifactId>
<groupId>org.apache.lucene</groupId>
<version>5.1.0</version>
</dependency>
<dependency>
<artifactId>lucene-analyzers-kuromoji</artifactId>
<groupId>org.apache.lucene</groupId>
<version>5.1.0</version>
</dependency>
</dependencies>
実装サンプル
NaiveBayesClassifier.java
/**
* 単純ベイズ分類器を利用
* ① <学習> 事前に投入した分類の割合 と 分類ごとの説明文の単語の頻出度合いを保持
* ② <分類> 上記のデータを元に分類ごとのスコアを算出
* @author ryutaro_hakozaki
*/
public class NaiveBayesClassifier {
/**
* 利用(蓄積)するデータは以下の3つのみ
* ① 説明文に登場した全ての単語の種類数
* ② 投入事前データの分類カウント(何の分類を何回学ばせたか)
* ③ 分類ごとの、説明文に登場した単語の回数カウントマップ
*/
// ①
private final Set<String> wordSet = new HashSet<>();
// ②
private final Map<String, Integer> categoryCountMap = new HashMap<>();
// ③
private final Map<String, Map<String, Integer>> categoryWordCountMap = new HashMap<>();
/**
* 学習はこれだけ。
* ① 登場する単語の種類を追加
* ② 指定された分類のカウントアップ(事前確率)
* ③ 分類別の単語の出現数をカウントアップ
*/
public void learn(String category, String description){
try {
// 説明文から単語を抽出
toWords(description)
.stream().sequential().forEach(w -> {
// ③ 設定済の分類の単語と出現回数マップを取得
Map<String, Integer> wordCountMap =
null == categoryWordCountMap.get(category)
? new HashMap<>() : categoryWordCountMap.get(category);
// ③ 単語マップを更新
if(wordCountMap.containsKey(w)) {
wordCountMap.put(w, wordCountMap.get(w) + 1);
}else{
wordCountMap.put(w, 1);
}
categoryWordCountMap.put(category, wordCountMap);
// ① 登場した単語で新しいものがあれば追加
wordSet.add(w);
});
// ② 分類の学習回数を単純にカウントアップ
if(categoryCountMap.containsKey(category)){
categoryCountMap.put(category, categoryCountMap.get(category) + 1);
} else {
categoryCountMap.put(category, 1);
}
} catch (IOException ex) {
Logger.getLogger(MachineLearningFirstStep.class.getName()).log(Level.SEVERE, null, ex);
}
}
/**
* 説明文から分類を判定
* ※ 数学的に厳密にあっている自信はありません
* ① 事前確率 = 学習データの分類の割合とする
* ② 単語の条件付確率 = 学習データの分類に出てきた各単語の割合の総和
*
* ① * ② がスコアとなる
* ⇒ アンダーフローを防ぐため、対数を取って乗算を和算にしている
* ⇒ ゼロ頻度問題を防ぐためにラプラススムージング(単語の出現回数に+1)を行う
*/
public Map<String, Double> judgeCategory(String description){
Map<String, Double> categoryScoreMap = new HashMap<>();
try {
List<String> words = toWords(description);
categoryCountMap.keySet()
.stream().forEach((category) -> {
// ①を算出
double score
= Math.log((double)categoryCountMap.get(category) / categoryCountMap.values().stream().mapToDouble(d -> d).sum());
// ②を算出
score = words
.stream()
.map((w) -> (categoryWordCountMap.get(category).containsKey(w)) ? categoryWordCountMap.get(category).get(w) : 0)
.map((wordCount) -> {
wordCount++;
return wordCount;
}).map((wordCount) ->
Math.log(
wordCount / (categoryWordCountMap.get(category).values().stream().mapToDouble(d -> d).sum() + wordSet.size())
)
).reduce(score, (accumulator, _item) -> accumulator + _item);
// 格納
categoryScoreMap.put(category, score);
});
} catch (IOException ex) {
Logger.getLogger(MachineLearningFirstStep.class.getName()).log(Level.SEVERE, null, ex);
}
return categoryScoreMap;
}
/**
* ドキュメントから単語を抽出
* 本筋ではないので細かい説明は省略
*/
private List<String> toWords(String doc) throws IOException {
List<String> ret = new ArrayList<>();
AttributeFactory factory = AttributeFactory.DEFAULT_ATTRIBUTE_FACTORY;
try(JapaneseTokenizer ts = new JapaneseTokenizer(factory, null, false, JapaneseTokenizer.DEFAULT_MODE)){
ts.setReader(new StringReader(doc.toLowerCase()));
PartOfSpeechAttribute ps = ts.addAttribute(PartOfSpeechAttribute.class);
CharTermAttribute ct = ts.addAttribute(CharTermAttribute.class);
ts.reset();
while(ts.incrementToken()){
if(ps.getPartOfSpeech().startsWith("名詞")) ret.add(ct.toString());
}
} catch (IOException e) {
}
return ret;
}
}