Javaで機械学習したくない?
機械学習といえばPythonって風潮ですが、Javaが得意な人は、今更Python学ぶのかなという気持ちになるでしょう。できれば、得意なJavaで機械学習できたら嬉しいですね。自分が開発したソフトに組み込めるし。
そんなことで、機械学習エンジンのWekaのAPIで色々と試してみようということで書き始めたいと思います。(間違いがあったらご指摘お願いします)
まずは、数値データの推測からはじめます。
こちらの書籍(機械学習コレクション Weka入門) https://www.amazon.co.jp/%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E3%82%B3%E3%83%AC%E3%82%AF%E3%82%B7%E3%83%A7%E3%83%B3-Weka%E5%85%A5%E9%96%80-I%E3%83%BB-BOOKS-%E5%92%8C%E7%94%B0/dp/4777520889 に出ているサンプルをAPIを使ってやってみることにしましょう。
CSVデータは以下の通り。末尾から2行目の売上が「?」となっています。これを推測してみます。test.csvとでもしましょうか。
弁当,レジ前,スイーツ類,コーヒー類,その他,売上
8,8,9,8,8,55
8,7,8,9,8,57
7,6,7,6,6,43
8,7,8,7,6,47
4,6,5,6,4,36
4,5,6,5,5,37
6,6,7,6,4,39
5,6,7,6,5,40
5,6,7,6,5,40
5,6,7,6,5,40
6,6,7,7,5,42
6,7,8,8,7,46
8,7,7,8,8,48
9,7,8,7,8,52
3,4,4,4,5,30
3,6,4,4,4,33
3,5,6,4,3,34
3,4,6,4,4,?
10,9,10,9,7,60
mini-weka
Wekaをそのまま利用してもいいのですが、GUIを使わないならば、mini-wekaでもOKだと思います。( https://github.com/fracpete/mini-weka )
APIを利用してみる
package example20220110;
import java.io.File;
import java.io.IOException;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.classifiers.evaluation.output.prediction.CSV;
import weka.classifiers.evaluation.output.prediction.HTML;
import weka.classifiers.evaluation.output.prediction.PlainText;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Instances;
import weka.core.converters.CSVLoader;
public class Main {
public static void main(String[] argv) throws Exception {
Main myWeka = new Main();
myWeka.run();
}
void run() throws IOException, Exception {
CSVLoader loader = new CSVLoader();
loader.setSource(new File("test.csv"));
Instances tr = loader.getDataSet();
// 多層パーセプトロン(ニューラルネットワークを利用)
MultilayerPerceptron mp = new MultilayerPerceptron();
// ランダムフォレストにしたければ以下のようにする
// RandomForest mp = new RandomForest();
// デフォルトに近いオプション
mp.setOptions(new String[]{"-L", "0.3", "-M", "0.2", "-N", "500", "-V", "0", "-S", "0", "-E", "20", "-H", "a"});
PrintPredict pp = new PrintPredict();
System.out.println("---PLAIN---");
System.out.println(pp.getPlainText(mp, tr));
System.out.println("---CSV---");
System.out.println(pp.getCSV(mp, tr));
}
class PrintPredict {
private String Output(Classifier cls, Instances instances, AbstractOutput output) {
try {
instances.setClassIndex(instances.numAttributes() - 1);
cls.buildClassifier(instances);
StringBuffer writer = new StringBuffer();
output.setBuffer(writer);
output.setHeader(instances);
output.printClassifications(cls, instances);
return writer.toString();
} catch (Exception e) {
return e.getMessage();
}
}
String getPlainText(Classifier cls, Instances instances) {
PlainText output = new PlainText();
return Output(cls, instances, output);
}
String getCSV(Classifier cls, Instances instances) {
CSV output = new CSV();
return Output(cls, instances, output);
}
String getHTML(Classifier cls, Instances instances) {
HTML output = new HTML();
return Output(cls, instances, output);
}
}
}
出力
---PLAIN---
1 55 53.92 -1.08
2 57 55.332 -1.668
3 43 41.299 -1.701
4 47 45.9 -1.1
5 36 34.783 -1.217
6 37 34.776 -2.224
7 39 39.444 0.444
8 40 38.792 -1.208
9 40 38.792 -1.208
10 40 38.792 -1.208
11 42 40.724 -1.276
12 46 46.396 0.396
13 48 47.158 -0.842
14 52 51.856 -0.144
15 30 29.172 -0.828
16 33 31.837 -1.163
17 34 34.147 0.147
18 ? 32.867 ?
19 60 58.724 -1.276
---CSV---
1,55,53.92,-1.08
2,57,55.332,-1.668
3,43,41.299,-1.701
4,47,45.9,-1.1
5,36,34.783,-1.217
6,37,34.776,-2.224
7,39,39.444,0.444
8,40,38.792,-1.208
9,40,38.792,-1.208
10,40,38.792,-1.208
11,42,40.724,-1.276
12,46,46.396,0.396
13,48,47.158,-0.842
14,52,51.856,-0.144
15,30,29.172,-0.828
16,33,31.837,-1.163
17,34,34.147,0.147
18,?,32.867,?
19,60,58.724,-1.276
32.867と推測されたようです。