0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Oracleから公開されたTribuoをさわってみた。ドキュメント Tribuo - A Java prediction library (v4.0)

Last updated at Posted at 2020-09-26

追記 2020/09/28
この記事は古いです。こちらにまとめなおしました。
https://qiita.com/jashika/items/d7c86dd8053379fd909f

※ 原文は Tribuo - A Java prediction library (v4.0) を参照してください。

#Introduction
Tribuoは、機械学習モデルを構築・展開するためのJavaライブラリです。中心となる開発チームはOracle Labsの機械学習研究グループであり、このライブラリはApache 2.0ライセンスのもと、Github上で公開されています。

・API は強く型付けされており、モデル、予測値、データセット、例題のためのクラスがパラメータ化されています。

・APIは高レベルで、モデルは例題を消費し、予測値を生成しますが、float配列ではありません。

・APIは統一されており、すべての予測タイプは同じ(よく型付けされた)APIを持ち、Tribuoのクラスは予測タイプによってパラメータ化されています(例:分類はLabelを使用し、回帰はRegressorを使用します)。

・APIは再利用可能で、モジュール化されており、必要なものだけを小分けにしてパッケージ化されているので、必要なものだけを導入することができます。

Tribuoは、同じAPIで幅広いMLアルゴリズムと特徴量を提供しています。

・分類:線形モデル、SVM、ツリー、アンサンブル、ディープラーニング

・回帰:線形モデル、罰則付き線形回帰、SVM、ツリー、アンサンブル、深層学習

・クラスタリング:K-Means

・異常検出:SVM

私たちは、時間の経過とともに利用可能なアルゴリズムを増やしていく予定です。

Tribuoは、データセットをロードし、モデルを訓練し、テストデータ上でモデルを評価することを簡単にします。例えば、このコードはロジスティック回帰モデルを学習し、評価します。

var trainSet = new MutableDataset<>(new LibSVMDataSource("train-data",new LabelFactory()));
var model    = new LogisticRegressionTrainer().train(trainSet);
var eval     = new LabelEvaluator().evaluate(new LibSVMDataSource("test-data",trainSet.getOutputFactory()));

#Getting Started
Tribuoをプロジェクトでりようするために、Mavenでは下記のように設定します。

<dependency>
    <groupId>org.tribuo</groupId>
    <artifactId>tribuo-all</artifactId>
    <version>4.0.0</version>
    <type>pom</type>
</dependency>

tribuo-allモジュールは、Tribuoのすべてを取り込みます。特定のユースケースのサブセットを後で選択することができます。
ここでは、分類システムを構築して評価する方法を示す簡単な例を示します。これには4つのステップがあります。

1.アヤメの種を分類するためのデータセットをCSVから読み込む。
2.そのデータセットを学習用データセットとテスト用データセットに分割する。
3.異なるトレーナーを用いて2種類のモデルを学習する。
4.モデルを使ってテストセットの予測を行い、テストセット全体の性能を評価する。

// ラベル付きアヤメ(アイリス)データを読み込む
var irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
DataSource<Label> irisData =
        new CSVLoader<>(new LabelFactory()).loadDataSource(Paths.get("bezdekIris.data"),
                                     /* Output column   */ irisHeaders[4],
                                     /* Column headers  */ irisHeaders);

// アヤメ(アイリス)データをトレーニングセット(70%)とテストセット(30%)に分割
var splitIrisData = new TrainTestSplitter<>(irisesSource,
                       /* Train fraction */ 0.7,
                             /* RNG seed */ 1L);
var trainData = new MutableDataset<>(splitIrisData.getTrain());
var testData = new MutableDataset<>(splitIrisData.getTest());

// 決定木を学習する
var cartTrainer = new CARTClassificationTrainer();
Model<Label> tree = cartTrainer.train(trainData);

// ロジスティック回帰
var linearTrainer = new LogisticRegressionTrainer();
Model<Label> linear = linearTrainer.train(trainData);

// 最終的には、目に見えないデータから予測を行う
// 各予測は、出力名(ラベル)からスコア/確率へのマップ
Prediction<Label> prediction = linear.predict(testData.get(0));

// 完全なテストデータセットを評価して、精度、F1などを計算してもよい。
Evaluation<Label> evaluation = new LabelEvaluation().evaluate(linear,testData);

// 手動での評価を検査する。
double acc = evaluation.accuracy();

// フォーマットされた評価文字列を表示する。
System.out.println(evaluation.toString());

フォーマットされた評価出力は以下のようになります。

Class                           n          tp          fn          fp      recall        prec          f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022

この例の詳細については、同じアヤメ(アイリス)データセットを使用した分類チュートリアルをご覧ください。
後で翻訳する。
翻訳した。

#Documentation Overview
機能一覧では、Tribuoでできることや、ネイティブでもサードパーティ製ライブラリへのインタフェースを介してもサポートしているアルゴリズムの概要を説明しています。Tribuoを理解するための最良の方法は、Tribuoのアーキテクチャ・ドキュメントを読むことです。基本的な定義、データフロー、ライブラリ構造、設定(オプションと実績を含む)、データロード、変換、サンプルの詳細、入力機能を隠すために利用できる難読化機能について説明しています。パッケージ構造の概要では、Tribuoのパッケージが、それぞれがサポートする機械学習タスクを中心にどのように構成されているかを説明しています。これらのパッケージはモジュールにグループ化されているので、Tribuoのユーザは実装に必要な部分だけに依存することができます。Tribuoを使用する上でのセキュリティ上の注意事項や、ユーザが期待することを必ずお読みください。その他の問題や一般的な質問については、FAQを参照してください。すべてのクラスとパッケージの詳細については、TribuoのJavaDocを参照してください。

#Tutorials
分類、クラスタリング、回帰、異常検出、設定システムのチュートリアルノートを用意しています。これらはJava Jupyterノートブックカーネルを使用しており、Java 10+で動作します。varキーワードを適切な型に置き換えることで、チュートリアルのコードをJava 8のコードに戻すのは簡単なはずです。

#Configuration and Provenance
Tribuoのトレーナは、OLCUT設定システムを介して完全に設定することができます。これにより、XML(またはJSONやEDN)ファイルにトレーナーを一度定義しておけば、全く同じパラメータで繰り返しモデルを構築することができます。各パッケージのconfigフォルダには、提供されているトレーナーの設定例があります。モデルは、データセット自体と同様にJavaシリアライズを使用してシリアライズ可能で、使用されたコンフィグレーションはどのモデルにも保存されます。すべてのモデルと評価には、モデルや評価がいつ作成されたか、使用されたデータは何か、データに適用された変換は何か、トレーナーのハイパーパラメータは何か、評価の場合はどのモデルが使用されたかを記録する、シリアライズ可能な証明書オブジェクトが含まれています。この情報はJSONに抽出することもできますし、Javaシリアライズを使って直接シリアライズすることもできます。本番環境では、この実績情報は、外部システムを介してモデルのトラッキングを提供するために、ハッシュに置き換えて編集することができます。設定、オプション、証明書についての詳細はこちらをご覧ください。

#Platform Support & Requirements
TribuoはJava 8+で動作し、JavaのLTSバージョンと最新のリリースでテストを行っています。Tribuo自体はJavaライブラリであり、すべてのJavaプラットフォームでサポートされていますが、一部のインタフェースはネイティブコードを必要とし、ネイティブライブラリがある場所でのみサポートされます。Windows 10、macOS、Linux (RHEL/OL/CentOS 7+)上のx86_64アーキテクチャでテストしています。別のプラットフォームに興味があり、ネイティブライブラリのインターフェース(ONNXランタイム、TensorFlow、XGBoost)のいずれかを使用したい場合は、それらのライブラリの開発者に連絡することをお勧めします。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?