※原文は Intro classification with Irises を参照してください。
分類チュートリアル
このチュートリアルでは、Fisherの有名なアヤメ(アイリス)データセットを使って、Tribuoの分類モデルを使ってアヤメ(アイリス)種を予測する方法を紹介します(今は2020年ですが、デモではまだ1936年のデータセットを使っています。次回は90年代のMNISTを使いますのでご安心ください)。ここでは、単純なロジスティック回帰に焦点を当て、Tribuoが各モデルの内部に保存しているデータの出所とメタデータを調査します。
セットアップ
アヤメ(アイリス)のデータセットのコピーを取得する必要があります。
wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data
まず必要なTribuoのjarライブラリをロードします。ここでは、分類実験ジャーとjson interop jarライブラリを使って、証明情報を読み書きしています。
jars ./tribuo-classification-experiments-4.0.0-jar-with-dependencies.jar
%jars ./tribuo-json-4.0.0-jar-with-dependencies.jar
import java.nio.file.Paths;
基本のorg.tribuoパッケージからすべてをインポートし、シンプルなCSVローダーと分類パッケージもインポートします。ロジスティック回帰を構築しようとしているので、それも必要になります。
import org.tribuo.*;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
これらのインポートは来歴システムのためのものです。
import com.fasterxml.jackson.databind.*;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.config.json.*;
データの読み込み
Tribuoでは、すべての予測タイプは、入力から適切なOutputサブクラスを作成することができるOutputFactoryの実装に関連付けられています。ここでは、マルチクラス分類を実行しているので、LabelFactoryを使用します。次に、labelFactoryをシンプルなCSVLoaderに渡して、DataSourceにすべての列を読み込みます。
var labelFactory = new LabelFactory();
var csvLoader = new CSVLoader<>(labelFactory);
アヤメ(アイリス)のコピーにはカラムヘッダがないので、ヘッダを作成し、パスとどの変数を出力するか(この場合は "species")とともにロードメソッドに供給します。アヤメ(アイリス)にはあらかじめ定義された訓練/テストの分割がないので、70%のデータを訓練に使用して、分割を作成することにします。
var irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
var irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"),"species",irisHeaders);
var irisSplitter = new TrainTestSplitter<>(irisesSource,0.7,1L);
トレーニングデータソースとテストデータソースをそれぞれのデータセットに投入する。これらのデータセットは、特徴領域や出力領域など、必要なメタデータをすべて計算します。学習データセットにはMutableDatasetを使用するのがベストです。これでデータセットが揃ったので、モデルを学習する準備ができました。
var trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
var testingDataset = new MutableDataset<>(irisSplitter.getTest());
System.out.println(String.format("Training data size = %d, number of features = %d, number of classes = %d",trainingDataset.size(),trainingDataset.getFeatureMap().size(),trainingDataset.getOutputInfo().size()));
System.out.println(String.format("Testing data size = %d, number of features = %d, number of classes = %d",testingDataset.size(),testingDataset.getFeatureMap().size(),testingDataset.getOutputInfo().size()));
Training data size = 105, number of features = 4, number of classes = 3
Testing data size = 45, number of features = 4, number of classes = 3
Training the model
それでは、トレーナーのインスタンスを作成して、デフォルトのハイパーパラメータを見てみましょう。これらのパラメータを完全に制御するために、完全に設定可能なLinearSGDTrainerを直接使用することができます。
Trainer<Label> trainer = new LogisticRegressionTrainer();
System.out.println(trainer.toString());
LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)
これは、ロジスティック損失を用いた線形モデルで、AdaGradを用いて5エポックで学習したものです。
それでは、モデルを訓練してみましょう。他のパッケージと同様に、訓練アルゴリズムと訓練データがあれば、訓練は非常に簡単です。
Model<Label> irisModel = trainer.train(trainingDataset);
モデルの評価
モデルを学習したら、それがどれくらい学習できているのかを評価する必要があります。このために、適切な評価器が何であるかをlabelFactoryに尋ね(または直接インスタンス化し)、評価器にモデルとテストデータセットを渡します。また、dataestの代わりにデータソースを渡すこともできます。LabelEvaluator クラスは、一般的な分類メトリックをすべて実装しており、それぞれを個別に検査することができます。LabelEvaluator.toString() は、メトリクスのきれいにフォーマットされた要約を生成します。
var evaluator = new LabelEvaluator();
var evaluation = evaluator.evaluate(irisModel,testingDataset);
System.out.println(evaluation.toString());
Class n tp fn fp recall prec f1
Iris-versicolor 16 16 0 1 1.000 0.941 0.970
Iris-virginica 15 14 1 0 0.933 1.000 0.966
Iris-setosa 14 14 0 0 1.000 1.000 1.000
Total 45 44 1 1
Accuracy 0.978
Micro Average 0.978 0.978 0.978
Macro Average 0.978 0.980 0.978
Balanced Error Rate 0.022
precision(精度)、recall(リコール)、F1は、多クラス分類器を評価する際に使用される標準的な指標です。
また、混同行列を表示することもできます。
System.out.println(evaluation.getConfusionMatrix().toString());
Iris-versicolor Iris-virginica Iris-setosa
Iris-versicolor 16 0 0
Iris-virginica 1 14 0
Iris-setosa
モデルメタデータ
Tribuoは、構築されたすべてのモデルの特徴領域と出力領域を追跡します。これにより、元の学習データにアクセスせずにLIMEのようなテクニックを実行したり、特定の入力が学習モデルの範囲内にあるかどうかのチェックを追加したりすることが可能になります。
Irisesモデルの特徴領域を見てみましょう。
var featureMap = irisModel.getFeatureIDMap();
for (var v : featureMap) {
System.out.println(v.toString());
System.out.println();
}
CategoricalFeature(name=petalLength,id=0,count=105,map={1.2=1, 6.9=1, 3.6=1, 3.0=1, 1.7=4, 4.9=4, 4.4=3, 3.5=2, 5.9=2, 5.4=1, 4.0=4, 1.4=12, 4.5=4, 5.0=2, 5.5=3, 6.7=2, 3.7=1, 1.9=1, 6.0=2, 5.2=1, 5.7=2, 4.2=2, 4.7=2, 4.8=4, 1.6=4, 5.8=2, 3.8=1, 6.3=1, 3.3=1, 1.0=1, 5.6=4, 5.1=5, 4.6=3, 4.1=2, 1.5=9, 1.3=4, 3.9=3, 6.6=1, 6.1=2})
CategoricalFeature(name=petalWidth,id=1,count=105,map={2.0=3, 0.5=1, 1.2=3, 0.3=6, 1.6=2, 0.1=3, 0.4=5, 2.5=3, 2.3=4, 1.7=2, 1.1=3, 2.1=4, 0.6=1, 1.4=6, 1.0=5, 2.4=1, 1.8=12, 0.2=20, 1.9=4, 1.5=7, 1.3=8, 2.2=2})
CategoricalFeature(name=sepalLength,id=2,count=105,map={6.9=3, 6.4=3, 7.4=1, 4.9=4, 4.4=1, 5.9=3, 5.4=5, 7.2=3, 7.7=3, 5.0=8, 6.2=2, 5.5=5, 6.7=7, 6.0=3, 5.2=2, 6.5=3, 5.7=4, 4.7=2, 4.8=3, 5.8=4, 5.3=1, 6.8=3, 6.3=5, 7.3=1, 5.6=6, 5.1=7, 4.6=4, 7.6=1, 7.1=1, 6.6=2, 6.1=5})
CategoricalFeature(name=sepalWidth,id=3,count=105,map={2.0=1, 2.8=10, 3.6=4, 2.3=3, 2.5=5, 3.1=8, 3.8=4, 3.0=19, 2.6=4, 4.4=1, 3.3=4, 3.5=4, 2.4=2, 3.2=10, 2.9=5, 3.7=3, 3.4=6, 2.2=2, 3.9=2, 4.2=1, 2.7=7})
4つの特徴と、それらの値のヒストグラムを見ることができます。この情報は、各特徴からサンプリングしたり、LIMEのような局所的な説明変数の候補例を構築したり、範囲を確認したりするのに利用できます。特徴情報はモデル学習時に凍結されているので、特徴集合が疎な場合(NLP問題ではよくあることですが)には、学習集合中に特徴が何回発生したかを確認するのにも使えます。
モデル証明書
最近のアプリケーションでは,多くの異なる種類のMLモデルが配備されており,アプリケーションの様々な側面を支援しています。しかし、ほとんどのMLパッケージは、モデルの追跡と再構築をサポートしていません。Tribuoでは、各モデルがその実績を追跡します。どのようにして作成されたのか、いつ作成されたのか、どのようなデータが関係しているのかを知ることができます。ここでは、アイリスモデルのデータの実績を見てみましょう。デフォルトでは、Tribuo は、各証明書オブジェクトの toString() メソッドを使用することによって、人間が読みやすい適度な形式で証明書を表示します。すべての情報はプログラムからアクセスできます。
var provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));
TrainTestSplitter(
class-name = org.tribuo.evaluation.TrainTestSplitter
source = CSVLoader(
class-name = org.tribuo.data.csv.CSVLoader
outputFactory = LabelFactory(
class-name = org.tribuo.classification.LabelFactory
)
response-name = species
separator = ,
quote = "
path = file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data
file-modified-time = 1999-12-14T15:12:39-05:00
resource-hash = 0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC
)
train-proportion = 0.7
seed = 1
size = 150
is-train = true
)
特定のランダムシードと分割率を使用して、2つに分割されたデータソース上でモデルが学習されていることがわかります。元のデータソースはCSVファイルで、ファイルの修正時刻とSHA-256ハッシュも記録されています。
同様に、訓練者の出所を調べることで、訓練アルゴリズムを知ることができます。
ここでは、予想通り、我々のモデルは勾配降下アルゴリズムとしてAdaGradを使用したLogisticRegressionTrainerを使用して訓練されていることがわかります。
別の記録を残したい場合は、モデルから実績を抽出してjsonファイルとして保存することができます(または、デプロイされたモデルから実績を取り消すこともできます)。
ObjectMapper objMapper = new ObjectMapper();
objMapper.registerModule(new JsonProvenanceModule());
objMapper = objMapper.enable(SerializationFeature.INDENT_OUTPUT);
jsonの実績は冗長ですが、人間が読める別のシリアル化フォーマットを提供しています。
System.out.println(jsonProvenance);
[ {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
"object-name" : "linearsgdmodel-0",
"object-class-name" : "org.tribuo.classification.sgd.linear.LinearSGDModel",
"provenance-class" : "org.tribuo.provenance.ModelProvenance",
"map" : {
"instance-values" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.MapMarshalledProvenance",
"map" : { }
},
"tribuo-version" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "tribuo-version",
"value" : "4.0.1",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
},
"trainer" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "trainer",
"value" : "logisticregressiontrainer-2",
"provenance-class" : "org.tribuo.provenance.impl.TrainerProvenanceImpl",
"additional" : "",
"is-reference" : true
},
"trained-at" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "trained-at",
"value" : "2020-08-31T20:24:37.854775-04:00",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance",
"additional" : "",
"is-reference" : false
},
"dataset" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "dataset",
"value" : "mutabledataset-1",
"provenance-class" : "org.tribuo.provenance.DatasetProvenance",
"additional" : "",
"is-reference" : true
},
"class-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "class-name",
"value" : "org.tribuo.classification.sgd.linear.LinearSGDModel",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
}
}
}, {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
"object-name" : "mutabledataset-1",
"object-class-name" : "org.tribuo.MutableDataset",
"provenance-class" : "org.tribuo.provenance.DatasetProvenance",
"map" : {
"num-features" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "num-features",
"value" : "4",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
"additional" : "",
"is-reference" : false
},
"num-examples" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "num-examples",
"value" : "105",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
"additional" : "",
"is-reference" : false
},
"num-outputs" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "num-outputs",
"value" : "3",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
"additional" : "",
"is-reference" : false
},
"tribuo-version" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "tribuo-version",
"value" : "4.0.1",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
},
"datasource" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "datasource",
"value" : "traintestsplitter-3",
"provenance-class" : "org.tribuo.evaluation.TrainTestSplitter$SplitDataSourceProvenance",
"additional" : "",
"is-reference" : true
},
"transformations" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance",
"list" : [ ]
},
"is-sequence" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "is-sequence",
"value" : "false",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
"additional" : "",
"is-reference" : false
},
"is-dense" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "is-dense",
"value" : "false",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
"additional" : "",
"is-reference" : false
},
"class-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "class-name",
"value" : "org.tribuo.MutableDataset",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
}
}
}, {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
"object-name" : "logisticregressiontrainer-2",
"object-class-name" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
"provenance-class" : "org.tribuo.provenance.impl.TrainerProvenanceImpl",
"map" : {
"seed" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "seed",
"value" : "12345",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance",
"additional" : "",
"is-reference" : false
},
"minibatchSize" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "minibatchSize",
"value" : "1",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
"additional" : "",
"is-reference" : false
},
"train-invocation-count" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "train-invocation-count",
"value" : "0",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
"additional" : "",
"is-reference" : false
},
"is-sequence" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "is-sequence",
"value" : "false",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
"additional" : "",
"is-reference" : false
},
"shuffle" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "shuffle",
"value" : "true",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
"additional" : "",
"is-reference" : false
},
"epochs" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "epochs",
"value" : "5",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
"additional" : "",
"is-reference" : false
},
"optimiser" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "optimiser",
"value" : "adagrad-4",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
"additional" : "",
"is-reference" : true
},
"host-short-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "host-short-name",
"value" : "Trainer",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
},
"class-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "class-name",
"value" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
},
"objective" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "objective",
"value" : "logmulticlass-5",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
"additional" : "",
"is-reference" : true
},
"loggingInterval" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "loggingInterval",
"value" : "1000",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
"additional" : "",
"is-reference" : false
}
}
}, {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
"object-name" : "traintestsplitter-3",
"object-class-name" : "org.tribuo.evaluation.TrainTestSplitter",
"provenance-class" : "org.tribuo.evaluation.TrainTestSplitter$SplitDataSourceProvenance",
"map" : {
"train-proportion" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "train-proportion",
"value" : "0.7",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
"additional" : "",
"is-reference" : false
},
"seed" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "seed",
"value" : "1",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance",
"additional" : "",
"is-reference" : false
},
"size" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "size",
"value" : "150",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
"additional" : "",
"is-reference" : false
},
"source" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "source",
"value" : "csvloader-6",
"provenance-class" : "org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance",
"additional" : "",
"is-reference" : true
},
"class-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "class-name",
"value" : "org.tribuo.evaluation.TrainTestSplitter",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
},
"is-train" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "is-train",
"value" : "true",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
"additional" : "",
"is-reference" : false
}
}
}, {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
"object-name" : "adagrad-4",
"object-class-name" : "org.tribuo.math.optimisers.AdaGrad",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
"map" : {
"epsilon" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "epsilon",
"value" : "0.1",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
"additional" : "",
"is-reference" : false
},
"initialLearningRate" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "initialLearningRate",
"value" : "1.0",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
"additional" : "",
"is-reference" : false
},
"initialValue" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "initialValue",
"value" : "0.0",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
"additional" : "",
"is-reference" : false
},
"host-short-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "host-short-name",
"value" : "StochasticGradientOptimiser",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
},
"class-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "class-name",
"value" : "org.tribuo.math.optimisers.AdaGrad",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
}
}
}, {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
"object-name" : "logmulticlass-5",
"object-class-name" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
"map" : {
"host-short-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "host-short-name",
"value" : "LabelObjective",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
},
"class-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "class-name",
"value" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
}
}
}, {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
"object-name" : "csvloader-6",
"object-class-name" : "org.tribuo.data.csv.CSVLoader",
"provenance-class" : "org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance",
"map" : {
"resource-hash" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "resource-hash",
"value" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance",
"additional" : "SHA256",
"is-reference" : false
},
"path" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "path",
"value" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.URLProvenance",
"additional" : "",
"is-reference" : false
},
"file-modified-time" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "file-modified-time",
"value" : "1999-12-14T15:12:39-05:00",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance",
"additional" : "",
"is-reference" : false
},
"quote" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "quote",
"value" : "\"",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance",
"additional" : "",
"is-reference" : false
},
"response-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "response-name",
"value" : "species",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
},
"outputFactory" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "outputFactory",
"value" : "labelfactory-7",
"provenance-class" : "org.tribuo.classification.LabelFactory$LabelFactoryProvenance",
"additional" : "",
"is-reference" : true
},
"separator" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "separator",
"value" : ",",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance",
"additional" : "",
"is-reference" : false
},
"class-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "class-name",
"value" : "org.tribuo.data.csv.CSVLoader",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
}
}
}, {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
"object-name" : "labelfactory-7",
"object-class-name" : "org.tribuo.classification.LabelFactory",
"provenance-class" : "org.tribuo.classification.LabelFactory$LabelFactoryProvenance",
"map" : {
"class-name" : {
"marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
"key" : "class-name",
"value" : "org.tribuo.classification.LabelFactory",
"provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
"additional" : "",
"is-reference" : false
}
}
} ]
別の方法として、モデルの証明書は Model.toString() の出力にも存在しますが、この形式は機械可読ではありません。
linear-sgd-model - Model(class-name=org.tribuo.classification.sgd.linear.LinearSGDModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=SplitDataSourceProvenance(className=org.tribuo.evaluation.TrainTestSplitter,innerSourceProvenance=CSV(class-name=org.tribuo.data.csv.CSVLoader,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),response-name=species,separator=,,quote=",path=file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data,file-modified-time=1999-12-14T15:12:39-05:00,resource-hash=SHA-256[0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC]),trainProportion=0.7,seed=1,size=150,isTrain=true),transformations=[],is-sequence=false,is-dense=false,num-examples=105,num-features=4,num-outputs=3,tribuo-version=4.0.1),trainer=Trainer(class-name=org.tribuo.classification.sgd.linear.LogisticRegressionTrainer,seed=12345,minibatchSize=1,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=1.0,initialValue=0.0,host-short-name=StochasticGradientOptimiser),objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),loggingInterval=1000,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2020-08-31T20:24:37.854775-04:00,instance-values={},tribuo-version=4.0.1)
評価には、テストデータの実績とともにモデルの実績を記録する実績もあります。JSON 実績の別の形式を使用しています。しかし、これは少し精度が落ます。そのかわり、読みやすくなっています。この形式は参照に適していますが、すべてを文字列に変換しているため、元の実績オブジェクトを再構築するためには使用できません。
String jsonEvaluationProvenance = objMapper.writeValueAsString(ProvenanceUtil.convertToMap(evaluation.getProvenance()));
System.out.println(jsonEvaluationProvenance);
{
"tribuo-version" : "4.0.1",
"dataset-provenance" : {
"num-features" : "4",
"num-examples" : "45",
"num-outputs" : "3",
"tribuo-version" : "4.0.1",
"datasource" : {
"train-proportion" : "0.7",
"seed" : "1",
"size" : "150",
"source" : {
"resource-hash" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
"path" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
"file-modified-time" : "1999-12-14T15:12:39-05:00",
"quote" : "\"",
"response-name" : "species",
"outputFactory" : {
"class-name" : "org.tribuo.classification.LabelFactory"
},
"separator" : ",",
"class-name" : "org.tribuo.data.csv.CSVLoader"
},
"class-name" : "org.tribuo.evaluation.TrainTestSplitter",
"is-train" : "false"
},
"transformations" : [ ],
"is-sequence" : "false",
"is-dense" : "false",
"class-name" : "org.tribuo.MutableDataset"
},
"class-name" : "org.tribuo.provenance.EvaluationProvenance",
"model-provenance" : {
"instance-values" : { },
"tribuo-version" : "4.0.1",
"trainer" : {
"seed" : "12345",
"minibatchSize" : "1",
"train-invocation-count" : "0",
"is-sequence" : "false",
"shuffle" : "true",
"epochs" : "5",
"optimiser" : {
"epsilon" : "0.1",
"initialLearningRate" : "1.0",
"initialValue" : "0.0",
"host-short-name" : "StochasticGradientOptimiser",
"class-name" : "org.tribuo.math.optimisers.AdaGrad"
},
"host-short-name" : "Trainer",
"class-name" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
"objective" : {
"host-short-name" : "LabelObjective",
"class-name" : "org.tribuo.classification.sgd.objectives.LogMulticlass"
},
"loggingInterval" : "1000"
},
"trained-at" : "2020-08-31T20:24:37.854775-04:00",
"dataset" : {
"num-features" : "4",
"num-examples" : "105",
"num-outputs" : "3",
"tribuo-version" : "4.0.1",
"datasource" : {
"train-proportion" : "0.7",
"seed" : "1",
"size" : "150",
"source" : {
"resource-hash" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
"path" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
"file-modified-time" : "1999-12-14T15:12:39-05:00",
"quote" : "\"",
"response-name" : "species",
"outputFactory" : {
"class-name" : "org.tribuo.classification.LabelFactory"
},
"separator" : ",",
"class-name" : "org.tribuo.data.csv.CSVLoader"
},
"class-name" : "org.tribuo.evaluation.TrainTestSplitter",
"is-train" : "true"
},
"transformations" : [ ],
"is-sequence" : "false",
"is-dense" : "false",
"class-name" : "org.tribuo.MutableDataset"
},
"class-name" : "org.tribuo.classification.sgd.linear.LinearSGDModel"
}
}
この実績情報には、モデルの実績情報に含まれるすべてのフィールドと、テストデータ、分割されたデータ、CSVが含まれていることがわかります。
この実績情報は、それだけでもモデルを追跡するのに便利ですが、設定チュートリアルで説明されている設定システムと組み合わせることで、モデルや実験を再構築するための強力な方法となり、どのようなMLモデルでもほぼ完璧な再現性を実現することができます。
結論
Tribuoのcsvロードのメカニズム、単純な分類器のトレーニング方法、テストデータ上での分類器の評価方法、さらにTribuoのモデルと評価オブジェクト内に保存されているメタデータと実績情報を見てみました。