5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Oracleから公開されたTribuoを気を取り直してやってみたら中の人は熱い感じの人だった

Last updated at Posted at 2020-09-29

##前置き
リアルタイムでいろいろ更新されているため、情報整理のために、再度チュートリアルの内容をまとめました。本記事はこちらの内容を元に書いたものです。

##資料について
元ネタはこちらを参照してください。
日本語が必要ならこちらを参照してください。
アヤメデータのダウンロードはこちらを参照してください。
ソースコードはこちらを参照してください。
Java Docはこちらを参照してください。

##変更点
CSVデータロードの不具合は修正されています。v4.0.1 ※

CSVReaderで、末尾に余計な改行があるファイルを読み込めない問題を修正しました。 IDX(すなわちMNIST)フォーマットのデータセットを読めるようにIDXDataSourceを追加しました。 libsvmファイルではなくIDXファイルからMNISTを読み込むように設定チュートリアルを更新しました。

中の人の話では、ドキュメントにv4.0.0て書いてるの見落としてたわー、直しておいたわーって言ってました。

一応気が付いているので書いておくと、Introductionに出てくる下のコードの3行目は、第一引数にmodelが必要です。

1. var trainSet = new MutableDataset<>(new LibSVMDataSource(Paths.get("train-data"),new LabelFactory()));
2. var model    = new LogisticRegressionTrainer().train(trainSet);
3. var eval     = new LabelEvaluator().evaluate(new LibSVMDataSource(Paths.get("test-data"),trainSet.getOutputFactory()));
var eval = new LabelEvaluator()
                .evaluate(model, new LibSVMDataSource(Paths.get("test-data"), trainSet.getOutputFactory()));

追記 2020/09/30
直してコミットしておいたって返事来てました。

##その他のネタ
あれとかこれとかgitのIssuesにあげてみたのですが、まじめに回答されてしまいました。ネタ記事のつもりだったのですが、真剣に付き合っていただきありがとうございました。

中の人の話では、Tribuoの公式ページは、もう直したわーとのことです。nullチェックとjavadocは今週直しておくわーとのことです。対応が早い。

この中の人Tensorflow javaのコミッターとのこと。本人が言ってました。

FP16でGPUを使った計算については、JVMはすでにaarch64をサポートしてるから、Tensorflow JavaとONNX RuntimeのJava APIが、すぐ対応すはずだよ、Tribuoもこいつら追いかけるぜーって中の人。が言ってました。現状でも、Tensorflow JavaかONNX Runtime使えばできるぜーって言ってました。

熱量すごいなこの人たち。私は日本でこんなことしたい。
OpenAIが全然オープンじゃないし。

##セットアップ(Getting Started)
前回からバージョン変更しています。

mavenでは下記のように設定します。

<dependency>
    <groupId>org.tribuo</groupId>
    <artifactId>tribuo-all</artifactId>
    <version>4.0.1</version>
    <type>pom</type>
</dependency>

gradleではドキュメント記載の方法ではうまくjarがダウンロード出来ません。pomファイルだけダウンロードされます。
下記のように設定する必要があります。

    api ('org.tribuo:tribuo-all:4.0.1@pom') {
        transitive = true
    }

kotlinだとまた違う動作するらしいです。ドキュメントにタブ追加して、書いておくけど、gradleのために、わざわざタブ作ってGroovyとkotlinで2重に記載するのは嫌だとお怒りです。

アヤメのデータを取得します。

アヤメデータのダウンロード先
https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data

通常のmainメソッドを持つクラスを作成して、Getting Startedの通りに実装を行います。

SampleTribuo
package org.project.eden.adam;

import java.io.IOException;
import java.nio.file.Paths;

import org.tribuo.DataSource;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.dtree.CARTClassificationTrainer;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.evaluation.TrainTestSplitter;

/**
 * @author jashika
 *
 */
public class SampleTribuo {

    /**
     * @param args mainメソッドの引数。
     * @throws IOException ファイルの読み込みエラー時にスローされる。
     */
    public static void main(String[] args) throws IOException {

        // ラベル付きアヤメデータを読み込む
        var irisHeaders = new String[] { "sepalLength", "sepalWidth", "petalLength", "petalWidth", "species" };
        DataSource<Label> irisData = new CSVLoader<>(new LabelFactory()).loadDataSource(Paths.get("bezdekIris.data"),
                irisHeaders[4], irisHeaders);
        // ※読み込むデータのパスが埋め込まれているため、各自の環境に合わせてください。

        // 読み込んだアヤメデータをトレーニング用に70%、テスト用に30%に分割
        var splitIrisData = new TrainTestSplitter<>(irisData, 0.7, 1L);
        var trainData = new MutableDataset<>(splitIrisData.getTrain());
        var testData = new MutableDataset<>(splitIrisData.getTest());

        // 決定木学習を使用することができる 
        var cartTrainer = new CARTClassificationTrainer();
        Model<Label> tree = cartTrainer.train(trainData);

        // ロジスティック回帰を使用することもできる
        var linearTrainer = new LogisticRegressionTrainer();
        Model<Label> linear = linearTrainer.train(trainData);

        // 最終的には、未知のデータから予測を行う
        // 予測は、出力名(すなわちラベル)と、スコア/確率となる。
        Prediction<Label> prediction = linear.predict(testData.getExample(0));

        // 完全なテストデータセットを評価して、精度、F1などを計算してもよい。
        LabelEvaluation evaluation = new LabelEvaluator().evaluate(linear, testData);

        // 手動で評価を検査することもできる
        double acc = evaluation.accuracy();

        // 0.978を返す
        // フォーマットされた評価文字列を表示する。
        System.out.println(evaluation.toString());
    }
}

SampleTribuoという名前で、1クラスを作ったのみです。
そのまま実行します。

実行結果
9月 29, 2020 5:12:51 午後 org.tribuo.data.csv.CSVIterator getRow
警告: Ignoring extra newline at line 151
9月 29, 2020 5:12:51 午後 org.tribuo.classification.sgd.linear.LinearSGDTrainer train
情報: Training SGD classifier with 105 examples
9月 29, 2020 5:12:51 午後 org.tribuo.classification.sgd.linear.LinearSGDTrainer train
情報: Labels - (0,Iris-versicolor,34), (1,Iris-virginica,35), (2,Iris-setosa,36)
Class                           n          tp          fn          fp      recall        prec          f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022

#各項目の説明

作成したSampleTribuoのコードの中で、下記の部分が、アヤメデータのヘッダー部分になります。

// ラベル付きアヤメデータを読み込む
var irisHeaders = new String[] { "sepalLength", "sepalWidth", "petalLength", "petalWidth", "species" };

ヘッダーの名称があらわしているのは、下記のようになります。

・sepalLength = がく(へた)の長さ
・sepalWidth = がく(へた)の幅
・petalLength = 花びらの長さ
・petalWidth = 花びらの幅
・species = 種類

またアヤメデータの種類に着目すると、下記の3種類のデータが含まれているのがわかります。下記はダウンロードしたアヤメデータから抜粋したものです。

・5.8,4.0,1.2,0.2,Iris-setosa
・6.7,2.5,5.8,1.8,Iris-virginica
・5.5,2.6,4.4,1.2,Iris-versicolor

種類については下記のようになります。アヤメの種類らしいです。

・Iris setosa = ヒオウギアヤメ
・Iris versicolor = バージカラー
・Iris virginica = バージニカ(コバノズイナ?)

Tribuoのサンプルコードで、一番最後にsystem.out.printlnで出力しているのは、LabelEvaluationクラスです。そして、LabelEvaluationクラスは、linearTrainerの学習済みModelと、テストデータを受け取っています。下記が該当箇所となります。前項で既出の出力結果はデータセットの評価結果となります。

SampleTribuo

                       中略

// ロジスティック回帰を使用することもできる
var linearTrainer = new LogisticRegressionTrainer();
Model<Label> linear = linearTrainer.train(trainData);

// 完全なテストデータセットを評価して、精度、F1などを計算してもよい。
LabelEvaluation evaluation = new LabelEvaluator().evaluate(linear, testData);

                       中略

// フォーマットされた評価文字列を表示する。
System.out.println(evaluation.toString());

LabelEvaluationクラスのJava Docをのぞいてみると、下記のような内容が書かれていました。個別に出力させたり、HTMLで出力させたるする機能が実装されています。

JavaDocのLabelEvaluation項目から抜粋
double	accuracy()     評価の総合的な精度。

double	accuracy(Label label)  評価のラベルあたりの精度。

double	AUCROC(Label label) ROC曲線の下の面積。

double	averageAUCROC(boolean weighted) ラベル間で平均化されたROC曲線の下の面積。

double	averagedPrecision(Label label) 所定の閾値での精度の加重平均を取ることで、精度-リコール曲線を要約し、重みはその閾値で達成されたリコールを表します。

LabelEvaluationUtil.PRCurve	precisionRecallCurve(Label label) 単一ラベルの精度リコール曲線を計算します。

static String	toFormattedString(LabelEvaluation evaluation) このメソッドは、ターミナルでの表示に適した、適切なタブと改行付きのきれいにフォーマットされた文字列出力を生成します。

default String	toHTML() この評価を表すHTML形式の文字列を返します。

static String	toHTML(LabelEvaluation evaluation) このメソッドは、Webページへの組み込みに適した、適切なタブと改行付きのHTMLフォーマットの文字列出力を生成します。


static String	toFormattedString(LabelEvaluation evaluation) このメソッドは、ターミナルでの表示に適した、適切なタブと改行付きのきれいにフォーマットされた文字列出力を生成します。

static String	toHTML(LabelEvaluation evaluation) このメソッドは、Webページへの組み込みに適した、適切なタブと改行付きのHTMLフォーマットの文字列出力を生成します。

下記の部分が未知のデータの予測を行っている部分になります。とはいっても、未知のデータはもっていないので、分割された30%分のデータですが。


                       中略

// 最終的には、未知のデータから予測を行う。
// 予測は、出力名(すなわちラベル)と、スコア/確率となる。
Prediction<Label> prediction = linear.predict(testData.getExample(0));

// 追加しました。
// 未知のデータ(テストデータ)の内容を出力する。
System.out.println("入力するテストデータ: " + testData.getExample(0));

// Modelから予測データを出力する。
System.out.println("予測結果: " + prediction.getOutput());
                       中略

元ネタをそのまま実装していますが、テストデータの1件目を予測して、そのままなにもしていませんので、こちらを出力してみます。

9月 29, 2020 9:34:34 午後 org.tribuo.data.csv.CSVIterator getRow
警告: Ignoring extra newline at line 151
9月 29, 2020 9:34:34 午後 org.tribuo.classification.sgd.linear.LinearSGDTrainer train
情報: Training SGD classifier with 105 examples
9月 29, 2020 9:34:34 午後 org.tribuo.classification.sgd.linear.LinearSGDTrainer train
情報: Labels - (0,Iris-versicolor,34), (1,Iris-virginica,35), (2,Iris-setosa,36)
入力するテストデータ: ArrayExample(numFeatures=4,output=Iris-setosa,weight=1.0,features=[(petalLength, 1.3)(petalWidth, 0.3), (sepalLength, 4.5), (sepalWidth, 2.3), ])
予測結果: (Iris-setosa,0.8752021370239595)

入力されたデータは、ヒオウギアヤメ(Iris-setosa)のデータでそれぞれ、がくの長さ(sepalLength)は4.5、がくの幅は(sepalWidth)は2.3、花びらの長さ(petalLength)は1.3、花びらの幅(petalWidth)は0.3となります。

それに対して予測結果は87%の確立で、ヒオウギアヤメ(Iris-setosa)と予測されました。当たってますね。

これだけでは、あまり面白くないので、30%のデータを予測してみます。下記のようにコードを変更しました。

SampleTribuo
                       中略
for (var example : testData) {
    // 正解
    var correctAnswer = example.getOutput().getLabel();
    // 予測結果
    var predictResult = linear.predict(example);

    if (correctAnswer.equals(predictResult.getOutput().getLabel())) {
        System.out.println("予測結果->正解: " + predictResult.getOutputScores());
    } else {
        System.out.println("予測結果->不正解: " + predictResult.getOutputScores() + "正解->" + correctAnswer);
    }
}
                       中略
実行結果
9月 29, 2020 10:35:18 午後 org.tribuo.data.csv.CSVIterator getRow
警告: Ignoring extra newline at line 151
9月 29, 2020 10:35:18 午後 org.tribuo.classification.sgd.linear.LinearSGDTrainer train
情報: Training SGD classifier with 105 examples
9月 29, 2020 10:35:18 午後 org.tribuo.classification.sgd.linear.LinearSGDTrainer train
情報: Labels - (0,Iris-versicolor,34), (1,Iris-virginica,35), (2,Iris-setosa,36)
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.12465430241942166), Iris-virginica=(Iris-virginica,1.4356055661867684E-4), Iris-setosa=(Iris-setosa,0.8752021370239595)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.022985338660181973), Iris-virginica=(Iris-virginica,4.886599414569194E-6), Iris-setosa=(Iris-setosa,0.9770097747404035)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.8743791900480306), Iris-virginica=(Iris-virginica,0.11972594367759447), Iris-setosa=(Iris-setosa,0.005894866274374883)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.13752628567026762), Iris-virginica=(Iris-virginica,0.8624557718013333), Iris-setosa=(Iris-setosa,1.794252839916067E-5)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.754662316591029), Iris-virginica=(Iris-virginica,0.24288192724655125), Iris-setosa=(Iris-setosa,0.00245575616241976)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.1114591917841135), Iris-virginica=(Iris-virginica,0.8885339140853985), Iris-setosa=(Iris-setosa,6.89413048798891E-6)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.018368556099769975), Iris-virginica=(Iris-virginica,2.8979851163953406E-6), Iris-setosa=(Iris-setosa,0.9816285459151136)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.8336662376434194), Iris-virginica=(Iris-virginica,0.16052879318949909), Iris-setosa=(Iris-setosa,0.005804969167081375)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.7401640849157867), Iris-virginica=(Iris-virginica,0.25814459323929695), Iris-setosa=(Iris-setosa,0.0016913218449164506)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.37734213887670864), Iris-virginica=(Iris-virginica,0.6225430896188512), Iris-setosa=(Iris-setosa,1.1477150444008675E-4)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.08165259920209698), Iris-virginica=(Iris-virginica,0.9183413054167227), Iris-setosa=(Iris-setosa,6.095381180367614E-6)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.9166451164592634), Iris-virginica=(Iris-virginica,0.0710098849793183), Iris-setosa=(Iris-setosa,0.012344998561418293)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.12496197859394245), Iris-virginica=(Iris-virginica,0.8750207513327715), Iris-setosa=(Iris-setosa,1.7270073286161786E-5)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.08832616235771579), Iris-virginica=(Iris-virginica,0.9116697571921494), Iris-setosa=(Iris-setosa,4.080450134879567E-6)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.7606313121553816), Iris-virginica=(Iris-virginica,0.23595375784889286), Iris-setosa=(Iris-setosa,0.0034149299957255148)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.2838330052617929), Iris-virginica=(Iris-virginica,0.7160953733169281), Iris-setosa=(Iris-setosa,7.162142127894781E-5)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.8962905028787402), Iris-virginica=(Iris-virginica,0.076593124958417), Iris-setosa=(Iris-setosa,0.027116372162842753)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.8781294741330369), Iris-virginica=(Iris-virginica,0.11263604829308826), Iris-setosa=(Iris-setosa,0.009234477573874762)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.02598246093746204), Iris-virginica=(Iris-virginica,9.258026355622071E-6), Iris-setosa=(Iris-setosa,0.9740082810361823)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.4125332441403861), Iris-virginica=(Iris-virginica,0.5874113240653356), Iris-setosa=(Iris-setosa,5.54317942782377E-5)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.005480793710204649), Iris-virginica=(Iris-virginica,3.36057635603426E-7), Iris-setosa=(Iris-setosa,0.9945188702321598)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.7157983219006878), Iris-virginica=(Iris-virginica,0.28236588916463473), Iris-setosa=(Iris-setosa,0.0018357889346774156)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.041785356794614356), Iris-virginica=(Iris-virginica,1.63804767321482E-5), Iris-setosa=(Iris-setosa,0.9581982627286535)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.03557299656267042), Iris-virginica=(Iris-virginica,1.572396823310113E-5), Iris-setosa=(Iris-setosa,0.9644112794690964)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.7445227755500248), Iris-virginica=(Iris-virginica,0.25336879249494487), Iris-setosa=(Iris-setosa,0.0021084319550302984)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.23201927045744059), Iris-virginica=(Iris-virginica,0.7679353747675968), Iris-setosa=(Iris-setosa,4.535477496262948E-5)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.6136153688335659), Iris-virginica=(Iris-virginica,0.38487059274540525), Iris-setosa=(Iris-setosa,0.0015140384210286877)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.0037281020760574794), Iris-virginica=(Iris-virginica,1.291282392314054E-7), Iris-setosa=(Iris-setosa,0.9962717687957032)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.04755707288123218), Iris-virginica=(Iris-virginica,1.934947705107781E-5), Iris-setosa=(Iris-setosa,0.9524235776417168)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.19993977432885074), Iris-virginica=(Iris-virginica,0.8000350473542086), Iris-setosa=(Iris-setosa,2.5178316940654208E-5)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.04523347766974972), Iris-virginica=(Iris-virginica,2.000119721397452E-5), Iris-setosa=(Iris-setosa,0.9547465211330363)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.029794358762199272), Iris-virginica=(Iris-virginica,8.7737002425928E-6), Iris-setosa=(Iris-setosa,0.9701968675375582)}
予測結果->不正解: {Iris-versicolor=(Iris-versicolor,0.5733177742115076), Iris-virginica=(Iris-virginica,0.42626475284848847), Iris-setosa=(Iris-setosa,4.1747294000396004E-4)}正解 -> Iris-virginica
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.029877036683050945), Iris-virginica=(Iris-virginica,7.161847786201107E-6), Iris-setosa=(Iris-setosa,0.970115801469163)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.7821104478321824), Iris-virginica=(Iris-virginica,0.21391875776045305), Iris-setosa=(Iris-setosa,0.003970794407364584)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.019534516466634153), Iris-virginica=(Iris-virginica,5.375894384912967E-6), Iris-setosa=(Iris-setosa,0.9804601076389811)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.05383775370019657), Iris-virginica=(Iris-virginica,0.9461587580877066), Iris-setosa=(Iris-setosa,3.4882120966764877E-6)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.12393648653270413), Iris-virginica=(Iris-virginica,0.8760476481453248), Iris-setosa=(Iris-setosa,1.58653219710635E-5)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.8991747020299696), Iris-virginica=(Iris-virginica,0.09631094366524456), Iris-setosa=(Iris-setosa,0.004514354304785789)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.8925507000636339), Iris-virginica=(Iris-virginica,0.10101559809853458), Iris-setosa=(Iris-setosa,0.00643370183783148)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.759603537166626), Iris-virginica=(Iris-virginica,0.23777484597390816), Iris-setosa=(Iris-setosa,0.00262161685946576)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.01560114791426806), Iris-virginica=(Iris-virginica,2.3647440157063515E-6), Iris-setosa=(Iris-setosa,0.9843964873417164)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.23490596676084197), Iris-virginica=(Iris-virginica,0.7650569443472615), Iris-setosa=(Iris-setosa,3.708889189648049E-5)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.5604116022032899), Iris-virginica=(Iris-virginica,0.4392143166585561), Iris-setosa=(Iris-setosa,3.740811381540314E-4)}
予測結果->正解: {Iris-versicolor=(Iris-versicolor,0.13060096881100045), Iris-virginica=(Iris-virginica,0.8693951745606155), Iris-setosa=(Iris-setosa,3.856628384041485E-6)}

1件不正解となっていますが、バージカラー(Iris-versicolor)の可能性は57% バージニカ(Iris-virginica)の可能性は42%と予測されています。逆に正解しているデータでも、バージカラー(Iris-versicolor)の可能性56%と、結構低い確率での正解も存在しています。学習させるデータが少ないので偏りが出るのでしょう。ここから、精度を上げていくのが楽しいはずです。

SampleTribuoでは、ロジスティクス回帰を利用して学習していますが、他にも利用できるクラスが準備されています。Trainerインターフェースの実装はたくさんあるので、これらを試していくだけでも結構時間がかかりそうです。やはりCPUでの学習よりも、GPUでの学習を行うようにしていきたいと思いますので、学習モデルそのものよりも先に、ONNXの連携部分か、TensorFlowの連携部分から見ていくのが良いような気がします。
image.png

まぁ機械学習なんて、ここ2、3日しかやったことないですが・・・。

あとがき

作成したSampleTribuoでは、毎回テストデータを読み込んで学習をすることになるので、学習済みデータを保存する方法を探していたのですが、ドキュメントからは見つかりませんでした。ソースコードからは、Tensorfrowを使ってTensorflowModelを保存する方法と、CSVを保存する方法は提供されてるっぽいことが分かるのですが、初めて触るコードなのでいまいち、癖というか思想が分かりません。こんなのが、ここにあるはずっていう感が働かないというか・・・。

自前でシリアライズ、デシリアライズするのかと勝手に考えていたら、Issuesに出てました。

質問内容は下記のようなものです。

モデルやデータセットをディスクにシリアライズする例をチュートリアルやドキュメントに含めることはできますか?ONNXの使用方法についての追加のヘルプもあると助かります。

それに対する回答がこちらです。

確かに、ドキュメントにシリアライズの例を追加することができます。現時点では、シリアライズは標準的なjava.io.Serializableメカニズムを使用しています。モデルやデータセットをロードして保存する例は、Tribuoに付属のデモプログラムで見ることができます。デシリアライズ攻撃を防ぐために、シリアライズ許可リストを使用する必要があることに注意してください。

現時点でのONNXローダの最良の例は、ユニットテストにあります。サードパーティ製モデルのロードプロセスを経たチュートリアルを追加する予定です。

つまり、ドキュメントに載っていないですし、java.io.Serializableをつかって自前でどうぞってことみたいです。ONNXローダの使い方は、テストコードにあると・・・。自分でソース検索しろってことですね。
テストコード探してみたけど、ロードはしてました。あれ?セーブは?

OnnxRuntimeの中を見たのですが、ロードのみですね。「JNI バインディング用のスタティックローダー。パブリックAPIはありませんが、このパッケージの様々なクラスから呼び出されます。」って書いてますし。
ai.onnxruntimeのパッケージの中も、org.tribuo.interop.onnxのパッケージの中にもなさそう。そもそも、現時点では、「シリアライズは標準的なjava.io.Serializableメカニズムを使用しています。」と言ってますから、Tensorflowで出来そうなだけで、公式には、今のところは無いんでしょうね。

機能はそこそこ出来上がっているのですが、いかんせん、英語でもドキュメントが追い付いていないです。ところどころjavadoc入っていないとかありますし。聞けば教えてくれそうではありますが、いちいち聞くよりも、試したほうが早いし、でもめんどくさいし見たいな感じです。いまのところは、うーんといった感じがありますが、何となく一気に化けそうな気もしますよ。中の人結構熱い人ですしね。

次はどこに手を付けていくか考えてみます。

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?