LoginSignup
0
1

More than 1 year has passed since last update.

【Java初心者】Wekaによる機械学習 01-数値データの推測

Last updated at Posted at 2022-01-17

Javaで機械学習したくない?

機械学習といえばPythonって風潮ですが、Javaが得意な人は、今更Python学ぶのかなという気持ちになるでしょう。できれば、得意なJavaで機械学習できたら嬉しいですね。自分が開発したソフトに組み込めるし。

そんなことで、機械学習エンジンのWekaのAPIで色々と試してみようということで書き始めたいと思います。(間違いがあったらご指摘お願いします)

まずは、数値データの推測からはじめます。

こちらの書籍(機械学習コレクション Weka入門) https://www.amazon.co.jp/%E6%A9%9F%E6%A2%B0%E5%AD%A6%E7%BF%92%E3%82%B3%E3%83%AC%E3%82%AF%E3%82%B7%E3%83%A7%E3%83%B3-Weka%E5%85%A5%E9%96%80-I%E3%83%BB-BOOKS-%E5%92%8C%E7%94%B0/dp/4777520889 に出ているサンプルをAPIを使ってやってみることにしましょう。

CSVデータは以下の通り。末尾から2行目の売上が「?」となっています。これを推測してみます。test.csvとでもしましょうか。

弁当,レジ前,スイーツ類,コーヒー類,その他,売上
8,8,9,8,8,55
8,7,8,9,8,57
7,6,7,6,6,43
8,7,8,7,6,47
4,6,5,6,4,36
4,5,6,5,5,37
6,6,7,6,4,39
5,6,7,6,5,40
5,6,7,6,5,40
5,6,7,6,5,40
6,6,7,7,5,42
6,7,8,8,7,46
8,7,7,8,8,48
9,7,8,7,8,52
3,4,4,4,5,30
3,6,4,4,4,33
3,5,6,4,3,34
3,4,6,4,4,?
10,9,10,9,7,60

mini-weka

Wekaをそのまま利用してもいいのですが、GUIを使わないならば、mini-wekaでもOKだと思います。( https://github.com/fracpete/mini-weka )

APIを利用してみる

package example20220110;

import java.io.File;
import java.io.IOException;
import weka.classifiers.Classifier;
import weka.classifiers.evaluation.output.prediction.AbstractOutput;
import weka.classifiers.evaluation.output.prediction.CSV;
import weka.classifiers.evaluation.output.prediction.HTML;
import weka.classifiers.evaluation.output.prediction.PlainText;
import weka.classifiers.functions.MultilayerPerceptron;
import weka.core.Instances;
import weka.core.converters.CSVLoader;

public class Main {

    public static void main(String[] argv) throws Exception {
        Main myWeka = new Main();
        myWeka.run();
    }

    void run() throws IOException, Exception {
        CSVLoader loader = new CSVLoader();
        loader.setSource(new File("test.csv"));
        Instances tr = loader.getDataSet();
        // 多層パーセプトロン(ニューラルネットワークを利用)
        MultilayerPerceptron mp = new MultilayerPerceptron();

        // ランダムフォレストにしたければ以下のようにする
        // RandomForest mp = new RandomForest();

        // デフォルトに近いオプション
        mp.setOptions(new String[]{"-L", "0.3", "-M", "0.2", "-N", "500", "-V", "0", "-S", "0", "-E", "20", "-H", "a"});
        PrintPredict pp = new PrintPredict();
        System.out.println("---PLAIN---");
        System.out.println(pp.getPlainText(mp, tr));
        System.out.println("---CSV---");
        System.out.println(pp.getCSV(mp, tr));
    }

    class PrintPredict {

        private String Output(Classifier cls, Instances instances, AbstractOutput output) {
            try {
                instances.setClassIndex(instances.numAttributes() - 1);
                cls.buildClassifier(instances);
                StringBuffer writer = new StringBuffer();
                output.setBuffer(writer);
                output.setHeader(instances);
                output.printClassifications(cls, instances);
                return writer.toString();
            } catch (Exception e) {
                return e.getMessage();
            }
        }

        String getPlainText(Classifier cls, Instances instances) {
            PlainText output = new PlainText();
            return Output(cls, instances, output);
        }

        String getCSV(Classifier cls, Instances instances) {
            CSV output = new CSV();
            return Output(cls, instances, output);
        }

        String getHTML(Classifier cls, Instances instances) {
            HTML output = new HTML();
            return Output(cls, instances, output);
        }

    }
}

出力

---PLAIN---
        1     55         53.92      -1.08  
        2     57         55.332     -1.668 
        3     43         41.299     -1.701 
        4     47         45.9       -1.1   
        5     36         34.783     -1.217 
        6     37         34.776     -2.224 
        7     39         39.444      0.444 
        8     40         38.792     -1.208 
        9     40         38.792     -1.208 
       10     40         38.792     -1.208 
       11     42         40.724     -1.276 
       12     46         46.396      0.396 
       13     48         47.158     -0.842 
       14     52         51.856     -0.144 
       15     30         29.172     -0.828 
       16     33         31.837     -1.163 
       17     34         34.147      0.147 
       18          ?     32.867          ? 
       19     60         58.724     -1.276 

---CSV---
1,55,53.92,-1.08
2,57,55.332,-1.668
3,43,41.299,-1.701
4,47,45.9,-1.1
5,36,34.783,-1.217
6,37,34.776,-2.224
7,39,39.444,0.444
8,40,38.792,-1.208
9,40,38.792,-1.208
10,40,38.792,-1.208
11,42,40.724,-1.276
12,46,46.396,0.396
13,48,47.158,-0.842
14,52,51.856,-0.144
15,30,29.172,-0.828
16,33,31.837,-1.163
17,34,34.147,0.147
18,?,32.867,?
19,60,58.724,-1.276

32.867と推測されたようです。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1