LoginSignup
0
1

More than 3 years have passed since last update.

Oracleから公開されたTribuoをやってみた。Tribuo - A Java prediction library (v4.0)

Last updated at Posted at 2020-09-26

実際にやってみた。
結論から言うととっても不安になった。あと、フレームワークを名乗るなら、英語でいいからjavadocはちゃんと書け。
それから、マニュアル直すのではなく、実装コードをきちんと直したほうがいいぞ。

セットアップ

依存性の解決にはmavenを使用しました。プロジェクト作成後、下記のようにtribuoをpomファイルに設定します。

<dependency>
    <groupId>org.tribuo</groupId>
    <artifactId>tribuo-all</artifactId>
    <version>4.0.0</version>
    <type>pom</type>
</dependency>

ロジスティック回帰モデルを学習し、評価するには下記のようにすればよいと書かれているため、下記のコードをjavaクラスに記述します。

var trainSet = new MutableDataset<>(new LibSVMDataSource("train-data",new LabelFactory()));
var model    = new LogisticRegressionTrainer().train(trainSet);
var eval     = new LabelEvaluator().evaluate(new LibSVMDataSource("test-data",trainSet.getOutputFactory()));

クラス名はTribuoSampleとしました。
は?ナニコレ?

コンストラクタがないエラー01.png

LibSVMDataSourceのコンストラクタ部分です。Stringを引数にとるコンストラクタはないです。

LibSVMDataSourceコンストラクタ.png

しかもこのjava.nio.file.Pathインターフェースって?
java.nio.file.Pathsのほうで実装ではないということは・・・。
もしかして、Path.ofとスタティックにアクセスしているからモダンなのか?

        // こう書けって言ってるの?
        var trainSet = new MutableDataset<>(new LibSVMDataSource<>(Path.of(new URL("file:train-data").getPath()),new LabelFactory()));

        // それともこう?
        var trainSet = new MutableDataset<>(new LibSVMDataSource<>(Paths.get("train-data"),new LabelFactory()));

        // だったらURLでいいんでない?冗長だし。
        var trainSet = new MutableDataset<>(new LibSVMDataSource<>(new URL("file:train-data"),new LabelFactory()));

マニュアル直すか、このクラス本体でStringを受け付けるかになるが、MutableDataset.classを直すべきでしょうね。こんな感じでとりあえずできるでしょうし。

MutableDataset

    public LibSVMDataSource(String url, OutputFactory<T> outputFactory) throws IOException {
        this(null,new URL(url),outputFactory,false,false,0);
    }

    public LibSVMDataSource(String url, OutputFactory<T> outputFactory, boolean zeroIndexed, int maxFeatureID) throws IOException {
        this(null,new URL(url),outputFactory,true,zeroIndexed,maxFeatureID);
    }

ただ、コンストラクタの処理で、urlも、pathもチェックしてないんですよね。
下記がMutableDataset.classの該当コード。

LibSVMDataSource
    private LibSVMDataSource(Path path, URL url, OutputFactory<T> outputFactory, boolean rangeSet, boolean zeroIndexed, int maxFeatureID) throws IOException {
        this.outputFactory = outputFactory;
        this.path = path;
        this.url = url;
        this.rangeSet = rangeSet;
        if (rangeSet) {
            this.zeroIndexed = zeroIndexed;
            this.minFeatureID = zeroIndexed ? 0 : 1;
            if (maxFeatureID < minFeatureID + 1) {
                throw new IllegalArgumentException("maxFeatureID must be positive, found " + maxFeatureID);
            }
            this.maxFeatureID = maxFeatureID;
        }
        read();
    }

どこまで行っちゃうのかっていうと、
コンストラクタから呼ばれるLibSVMDataSource#readの中。
下記が該当のコード。

LibSVMDataSource#read

    private void read() throws IOException {
        int pos = 0;
        ArrayList<HashMap<Integer,Double>> processedData = new ArrayList<>();
        ArrayList<String> labels = new ArrayList<>();

        // Idiom copied from Files.readAllLines,
        // but this doesn't require keeping the whole file in RAM.
        String line;
        // Parse the libsvm file, ignoring malformed lines.
        try (BufferedReader r = new BufferedReader(new InputStreamReader(url.openStream(),StandardCharsets.UTF_8))) {
            for (;;) {
                line = r.readLine();
                if (line == null) {
                    break;
                }
                pos++;
                String[] fields = splitPattern.split(line);
                try {
                    boolean valid = true;
                    HashMap<Integer, Double> features = new HashMap<>();
                    for (int i = 1; i < fields.length && valid; i++) {
                        int ind = fields[i].indexOf(':');
                        if (ind < 0) {
                            logger.warning(String.format("Weird line at %d", pos));
                            valid = false;
                        }
                        String ids = fields[i].substring(0, ind);
                        int id = Integer.parseInt(ids);
                        if ((!rangeSet) && (maxFeatureID < id)) {
                            maxFeatureID = id;
                        }
                        if ((!rangeSet) && (minFeatureID > id)) {
                            minFeatureID = id;
                        }
                        double val = Double.parseDouble(fields[i].substring(ind + 1));
                        Double value = features.put(id, val);
                        if (value != null) {
                            logger.warning(String.format("Repeated features at line %d", pos));
                            valid = false;
                        }
                    }
                    if (valid) {
                        // Store the label
                        labels.add(fields[0]);
                        // Store the features
                        processedData.add(features);
                    } else {
                        throw new IOException("Invalid LibSVM format file");
                    }
                } catch (NumberFormatException ex) {
                    logger.warning(String.format("Weird line at %d", pos));
                    throw new IOException("Invalid LibSVM format file", ex);
                }
            }
        }

tryの中でurl.openStream()ってやっちゃってるし。キャッチするのNumberFormatExceptionだけだしなぁ。
メンバ変数のdescriptionを見ると、urlかpathのどちらかが必須になっているけど。

LibSVMDataSource
    // url is the store of record.
    @Config(description="URL to load the data from. Either this or path must be set.")
    private URL url;

    @Config(description="Path to load the data from. Either this or url must be set.")
    private Path path;

LibSVMDataSource#postConfigで両方nullだったらのチェックはしているけど、これじゃダメじゃん。

LibSVMDataSource#postConfig
    @Override
    public void postConfig() throws IOException {
        if (maxFeatureID != Integer.MIN_VALUE) {
            rangeSet = true;
            minFeatureID = zeroIndexed ? 0 : 1;
            if (maxFeatureID < minFeatureID + 1) {
                throw new IllegalArgumentException("maxFeatureID must be positive, found " + maxFeatureID);
            }
        }
        if ((url == null) && (path == null)) {
            throw new PropertyException("","path","At most one of url and path must be set.");
        } else if ((url != null) && (path != null) && !path.toUri().toURL().equals(url)) {
            throw new PropertyException("","path","At most one of url and path must be set");
        } else if (path != null) {
            // url is the store of record.
            try {
                url = path.toUri().toURL();
            } catch (MalformedURLException e) {
                throw new PropertyException(e,"","path","Path was not a valid URL");
            }
        }
        read();
    }

こんなコード書いたら誰も処理しないよね。

TribuoSample
public class TribuoSample {

    /**
     * @param args mainメソッドの引数。
     */
    public static void main(String[] args) {

        URL url = null;

        try {
            var trainSet = new MutableDataset<>(
                    new LibSVMDataSource<>(url, new LabelFactory()));
        } catch (IOException e) {
            // TODO 自動生成された catch ブロック
            e.printStackTrace();
        }
    }
}

これを実行すると・・・。

StackTrace
Exception in thread "main" java.lang.NullPointerException
    at org.tribuo.datasource.LibSVMDataSource.read(LibSVMDataSource.java:204)
    at org.tribuo.datasource.LibSVMDataSource.<init>(LibSVMDataSource.java:125)
    at org.tribuo.datasource.LibSVMDataSource.<init>(LibSVMDataSource.java:105)
    at org.project.eden.adam.TribuoSample.main(TribuoSample.java:28)

この場合、利用者に伝えたいメッセージは、「urlかpathのいずれかが設定されていることは必須の項目なんだけど、
あなたがurlに設定した値にはnullが設定されていたよ。」ってことを伝えなければならないはずなんだけど、想定していないエラーだかスタックトレース吐き出して処理が止まってしまうわけだ。業務アプリならまだしも、oracleの名前でだしたフレームワークなのだから、このような落ち方はいかがなものかと思う。pathの場合には、コンストラクタでpathオブジェクトにアクセスしちゃうから、その場でヌルポだしね。

これがその実装

LibSVMDataSource
    public LibSVMDataSource(Path path, OutputFactory<T> outputFactory) throws IOException {
        this(path,path.toUri().toURL(),outputFactory,false,false,0);
    }

サンプルを下記のようにして実行。

TribuoSample
public class TribuoSample {

    /**
     * @param args mainメソッドの引数。
     */
    public static void main(String[] args) {

        Path path = null;

        try {
            var trainSet = new MutableDataset<>(
                    new LibSVMDataSource<>(path, new LabelFactory()));

        } catch (IOException e) {
            // TODO 自動生成された catch ブロック
            //e.printStackTrace();
        }
    }
}

結果は見るまでもなく、NullPointerException。

Exception in thread "main" java.lang.NullPointerException
    at org.tribuo.datasource.LibSVMDataSource.<init>(LibSVMDataSource.java:97)
    at org.project.eden.adam.TribuoSample.main(TribuoSample.java:28)

ドキュメントのトップページに出てくる、1行目からこんなになってるとは思わなかった。
サンプル実行は別のページで。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1