More than 5 years have passed since last update.

Building and Running XGBoost4J on Ubuntu

Posted at 2016-04-13

Prerequisites

  • Install the following:
      • Java 1.7 or later
      • Git
      • g++
      • Maven
  • Set the JAVA_HOME environment variable

Building XGBoost4J

$ git clone --recursive https://github.com/dmlc/xgboost
$ cd xgboost/jvm-packages
$ mvn package install

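Data formats

XGBoost4J can handle files in LIBSVM format, sparse matrices in CSR/CSC format, and dense matrices. Datasets in LIBSVM format can be obtained from the following.

For classification, labels must run from 0 to (number of classes - 1): 0, 1 for binary classification and 0, 1, 2 for three-class classification. The datasets above therefore need their labels adjusted before use.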
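To make the link between LIBSVM text and a CSR sparse matrix concrete, here is a minimal sketch that parses LIBSVM lines (`label index:value index:value ...`) into the usual CSR triple of row offsets, column indices, and values. The class `LibsvmToCsr` and its field names are hypothetical and not part of XGBoost4J. The sketch assumes whitespace-separated tokens and keeps the feature indices exactly as written in the file.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper (not part of XGBoost4J): LIBSVM rows held as CSR arrays. */
final class LibsvmToCsr {
    final float[] labels;     // one label per row
    final long[] rowHeaders;  // entries of row i are rowHeaders[i] .. rowHeaders[i + 1] - 1
    final int[] colIndices;   // feature index of each stored entry, as written in the file
    final float[] values;     // value of each stored entry

    private LibsvmToCsr(float[] labels, long[] rowHeaders, int[] colIndices, float[] values) {
        this.labels = labels;
        this.rowHeaders = rowHeaders;
        this.colIndices = colIndices;
        this.values = values;
    }

    static LibsvmToCsr parse(List<String> lines) {
        List<Float> labels = new ArrayList<>();
        List<Long> headers = new ArrayList<>();
        List<Integer> indices = new ArrayList<>();
        List<Float> values = new ArrayList<>();
        headers.add(0L); // the first row starts at entry 0

        for (String line : lines) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue; // skip blank lines
            }
            String[] tokens = trimmed.split("\\s+");
            labels.add(Float.parseFloat(tokens[0])); // first token is the label
            for (int i = 1; i < tokens.length; i++) {
                int colon = tokens[i].indexOf(':');
                if (colon < 0) {
                    throw new IllegalArgumentException("Expected index:value but got " + tokens[i]);
                }
                indices.add(Integer.parseInt(tokens[i].substring(0, colon)));
                values.add(Float.parseFloat(tokens[i].substring(colon + 1)));
            }
            headers.add((long) indices.size()); // end of this row = start of the next one
        }

        // Copy the boxed lists into primitive arrays.
        float[] labelArr = new float[labels.size()];
        for (int i = 0; i < labelArr.length; i++) {
            labelArr[i] = labels.get(i);
        }
        long[] headerArr = new long[headers.size()];
        for (int i = 0; i < headerArr.length; i++) {
            headerArr[i] = headers.get(i);
        }
        int[] indexArr = new int[indices.size()];
        float[] valueArr = new float[values.size()];
        for (int i = 0; i < indexArr.length; i++) {
            indexArr[i] = indices.get(i);
            valueArr[i] = values.get(i);
        }
        return new LibsvmToCsr(labelArr, headerArr, indexArr, valueArr);
    }
}
```

For the two rows `1 1:0.5 3:2.0` and `0 2:1.5`, the row offsets are `[0, 2, 3]`, the column indices are `[1, 3, 2]`, and the values are `[0.5, 2.0, 1.5]`. How these arrays map onto a particular XGBoost4J DMatrix constructor should be checked against the version you built.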

Example 1 - Classification & cross validation

Download the iris data from here.
Its labels are 1, 2, 3, so change them to 0, 1, 2.
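The label fix can be done by hand, but a small program is less error-prone. The following is a minimal sketch, assuming integer labels in the first whitespace-separated token of each line. The class `LibsvmRelabeler` and its method names are hypothetical, and the input and output paths are taken from the command line.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: shifts the integer label of every LIBSVM line by a fixed offset. */
final class LibsvmRelabeler {

    /** Returns the line with its label shifted; the features are left untouched. */
    static String shiftLabel(String line, int offset) {
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            return trimmed; // nothing to relabel on a blank line
        }
        // Split into the label and the rest of the line (the features).
        String[] parts = trimmed.split("\\s+", 2);
        int label = Integer.parseInt(parts[0]) + offset;
        return parts.length > 1 ? label + " " + parts[1] : String.valueOf(label);
    }

    static List<String> relabel(List<String> lines, int offset) {
        List<String> result = new ArrayList<>();
        for (String line : lines) {
            result.add(shiftLabel(line, offset));
        }
        return result;
    }

    // Usage: java LibsvmRelabeler <input file> <output file>
    public static void main(String[] args) throws IOException {
        List<String> lines = Files.readAllLines(Paths.get(args[0]), StandardCharsets.UTF_8);
        // Labels 1, 2, 3 become 0, 1, 2.
        Files.write(Paths.get(args[1]), relabel(lines, -1), StandardCharsets.UTF_8);
    }
}
```

Writing to a separate output file keeps the downloaded data intact, so the conversion can be rerun if something goes wrong.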

build.gradle
repositories {
    mavenLocal()
}

dependencies {
    compile 'ml.dmlc:xgboost4j:0.5'
}
XGBoost4JClassificationTest.java
import java.util.HashMap;
import java.util.Map;

import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class XGBoost4JClassificationTest {

	private static final String trainMatPath = "data/iris.scale";

	public static void main(String[] args) throws XGBoostError {
		DMatrix trainMat = new DMatrix(trainMatPath);

		Map<String, Object> params = new HashMap<>();
		params.put("objective", "multi:softmax");
		params.put("num_class", 3);     // for multiclass classification, set the number of classes via "num_class"

		int round = 10;
		int nfold = 5;	// 5-fold
		// cross validation
		String[] evalHists = XGBoost.crossValidation(trainMat, params, round, nfold, null, null, null);
		for (String evalHist : evalHists) {
			System.out.println(evalHist);
		}
	}
}
Result
[0]	cv-test-merror:0.066667	cv-train-merror:0.023333
[1]	cv-test-merror:0.066667	cv-train-merror:0.023333
[2]	cv-test-merror:0.053333	cv-train-merror:0.020000
[3]	cv-test-merror:0.060000	cv-train-merror:0.018333
[4]	cv-test-merror:0.053333	cv-train-merror:0.016667
[5]	cv-test-merror:0.053333	cv-train-merror:0.015000
[6]	cv-test-merror:0.053333	cv-train-merror:0.015000
[7]	cv-test-merror:0.046667	cv-train-merror:0.011667
[8]	cv-test-merror:0.046667	cv-train-merror:0.010000
[9]	cv-test-merror:0.046667	cv-train-merror:0.010000
  • XGBoost.crossValidation() performs cross validation
  • The output reports the error rate for classification and the RMSE for regression
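The `merror` values in the log are multiclass error rates, that is, the fraction of samples whose predicted class differs from the true class. As a minimal sketch of what that number means (the class `MultiClassError` is hypothetical and is not how XGBoost computes it internally):

```java
/** Hypothetical helper: multiclass error rate = misclassified samples / all samples. */
final class MultiClassError {

    static double errorRate(int[] labels, int[] predicted) {
        if (labels.length != predicted.length) {
            throw new IllegalArgumentException("labels and predictions differ in length");
        }
        if (labels.length == 0) {
            throw new IllegalArgumentException("at least one sample is required");
        }
        int wrong = 0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != predicted[i]) {
                wrong++; // count every misclassified sample
            }
        }
        return (double) wrong / labels.length;
    }
}
```

With true labels `{0, 1, 2, 2}` and predictions `{0, 1, 1, 2}`, one of four samples is wrong, so the error rate is 0.25.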

Example 2 - Regression with explicit training and prediction data

Download the housing data from here.
Split it into a training set and a prediction set.
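One way to do the split is sketched below: shuffle the lines with a fixed seed and cut at a given ratio. The class `TrainTestSplitter`, the 80/20 ratio, and the seed are assumptions made for illustration. The output names simply append `.train` and `.test` to the input path.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: splits the lines of a LIBSVM file into a training and a test part. */
final class TrainTestSplitter {

    /** Returns two lists: index 0 is the training part, index 1 is the test part. */
    static List<List<String>> split(List<String> lines, double trainRatio, long seed) {
        if (!(trainRatio >= 0.0 && trainRatio <= 1.0)) {
            throw new IllegalArgumentException("trainRatio must be between 0 and 1");
        }
        // Shuffle a copy so the caller's list is not modified; the seed makes it repeatable.
        List<String> shuffled = new ArrayList<>(lines);
        Collections.shuffle(shuffled, new Random(seed));

        int trainSize = (int) Math.round(shuffled.size() * trainRatio);
        List<List<String>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, trainSize)));
        parts.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
        return parts;
    }

    // Usage: java TrainTestSplitter <input file>
    public static void main(String[] args) throws IOException {
        String input = args[0];
        List<String> lines = Files.readAllLines(Paths.get(input), StandardCharsets.UTF_8);
        List<List<String>> parts = split(lines, 0.8, 42L); // 80% train, 20% test (assumed)
        Files.write(Paths.get(input + ".train"), parts.get(0), StandardCharsets.UTF_8);
        Files.write(Paths.get(input + ".test"), parts.get(1), StandardCharsets.UTF_8);
    }
}
```

Shuffling before cutting avoids a biased split when the rows of the file are ordered, and the fixed seed keeps the two files reproducible between runs.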

XGBoost4JRegressionTest.java
import java.util.HashMap;
import java.util.Map;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class XGBoost4JRegressionTest {

	private static final String trainMatPath = "data/housing.scale.train";
	private static final String testMatPath = "data/housing.scale.test";

	public static void main(String[] args) throws XGBoostError {
		DMatrix trainMat = new DMatrix(trainMatPath);
		DMatrix testMat = new DMatrix(testMatPath);
        
		Map<String, Object> params = new HashMap<>();
		params.put("objective", "reg:linear");

		Map<String, DMatrix> watches = new HashMap<>();
		watches.put("train", trainMat);
		watches.put("test", testMat);

		int round = 10;
		// train
		Booster booster = XGBoost.train(trainMat, params, round, watches, null, null);
		// predict
		float[][] predicts = booster.predict(testMat);

		// ground-truth values of the evaluation data
		float[] answers = testMat.getLabel();

		System.out.println("RMSE of predicts: " + calcRMSE(answers, predicts));
	}

	// compute the RMSE between the ground-truth values and the predictions
	private static double calcRMSE(float[] answers, float[][] predicts) {
		double sse = 0.0;
		int num = answers.length;

		for (int i = 0; i < num; i++) {
			sse += Math.pow(answers[i] - predicts[i][0], 2);
		}

		return Math.sqrt(sse / num);
	}
}
  • XGBoost.train() trains the model and Booster.predict() makes predictions
