LoginSignup
3
1

More than 3 years have passed since last update.

Sparkで機械学習入門 「価格推定」 #1 Apache Spark(Java)でデータセットの読み込み

Last updated at Posted at 2019-08-11

概要

  • JavaApache Sparkを使った機械学習入門編です
  • 「価格推定」を題材にApache Sparkの使い方から実際の機械学習(学習、回帰)までステップアップ方式でハンズオンします
  • 機械学習アルゴリズムには教師あり学習勾配ブースティングツリー1を使います
  • ステップアップ方式に複数回に分けて投稿します
  • 全ソースコードはこちら https://github.com/riversun/spark-gradient-boosting-tree-regression-example

環境

  • Apache Spark 2.4.3
  • Java 8~

対象

  • 難しいことを考えずに機械学習をまずはやってみようと思っている人

やりたいこと

以下のようなアクセサリーの価格リストがあります。

「一番下の行にあるシルバーのブレスレットの値段はいくらでしょうか?」

この値をJavaApache Sparkを使った機械学習によって予測したいとおもいます。

いわゆる価格推定です。

id 材質 形状 重さ(g) ブランド 販売店 価格(円)
0 シルバー ブレスレット 56 海外有名ブランド 百貨店 40864
1 ゴールド 指輪 48 国内有名ブランド 直営店 63055
2 ダイア イアリング 37 国内有名ブランド 直営店 112159
3 ダイア ネックレス 20 海外有名ブランド 直営店 216053
4 ダイア ネックレス 33 海外有名ブランド 百貨店 219666
5 シルバー ブローチ 55 国内有名ブランド 百貨店 16482
6 プラチナ ブローチ 58 海外超有名ブランド 直営店 377919
7 ゴールド イアリング 49 国内有名ブランド 直営店 60484
8 シルバー ネックレス 59 ノーブランド 激安ショップ 6256
9 ゴールド 指輪 13 国内有名ブランド 百貨店 37514



x シルバー ブレスレット 56 海外超有名ブランド 直営店 お値段いくら?

Apache SparkをローカルPCで使う

動作環境

Apache Sparkはスタンドアローン動作で

Apache SparkHadoopの仲間でビッグデータを簡単に扱うことのできる分散処理フレームワークで、複数台のマシンをクラスター構成で使う分散環境で強みを発揮しますが、1台のローカルPCでのスタンドアローン動作も簡単です。

今回は、Apache Sparkをスタンドアローン動作させspark.mlという機械学習機能を使います。

OS

OSはJavaランタイムさえ入っていればWindows2でもLinuxでもMacでもOKです

開発言語はJava

Apache Spark自体はScalaで書かれており、Scalaをはじめ、Java、Python、Rから利用できるAPIが整備されています。

本稿では、Javaを使います。

ライブラリ依存を設定する

JavaのアプリにApache Sparkを組み込んでつかうために依存ライブラリを設定します。

POM.xml(またはbuild.gradle)を準備して、以下のようにApache Spark関連ライブラリをdependenciesに加えるだけです。JacksonApache Spark内で使われるため追加しておきます。

POM.xml(抜粋)
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-core_2.12</artifactId>
    <version>2.4.3</version>
</dependency>
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-mllib_2.12</artifactId>
    <version>2.4.3</version>
</dependency>
<dependency>
    <groupId>com.fasterxml.jackson.module</groupId>
    <artifactId>jackson-module-scala_2.12</artifactId>
    <version>2.9.9</version>
</dependency>

<dependency>
    <groupId>com.fasterxml.jackson.core</groupId>
    <artifactId>jackson-databind</artifactId>
    <version>2.9.9</version>
</dependency>

データセットを準備する

次に、学習に使うデータセットを準備しましょう。

ファイルはこちらにあります。

今回は、ダイアモンドやプラチナなどの原材料でつくったネックレス・指輪などのアクセサリーの価格を機械学習で予測してみたいとおもいます。

機械学習に使うのは以下のようなCSV形式のデータです。

gem_price_ja.csv
id,material,shape,weight,brand,shop,price
0,シルバー,ブレスレット,56,海外有名ブランド,百貨店,40864
1,ゴールド,指輪,48,国内有名ブランド,直営店,63055
2,ダイア,イアリング,37,国内有名ブランド,直営店,112159
3,ダイア,ネックレス,20,海外有名ブランド,直営店,216053
4,ダイア,ネックレス,33,海外有名ブランド,百貨店,219666
5,シルバー,ブローチ,55,国内有名ブランド,百貨店,16482
6,プラチナ,ブローチ,58,海外超有名ブランド,直営店,377919
7,ゴールド,イアリング,49,国内有名ブランド,直営店,60484
8,シルバー,ネックレス,59,ノーブランド,激安ショップ,6256
・・・・

このCSV形式データはここにのせましたが、id、material(材質)、shape(形状)、weight(重さ)、brand(ブランド)、shop(販売店)、price(価格)の順番で並んでおり、全部で500件あります。

つまり、変数としては6種類(idを除くmaterial,shape,weight,brand,shop,priceの6種類)あり、それが500レコードということになります。

Sparkでデータを読み込んでみる

ワークディレクトリの直下にdatasetというディレクトリを作り学習用のデータ gem_price_ja.csvをそこに置きます。

このCSVファイルを読み込んでSparkが扱えるようにするコードは以下の通り。


import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class GBTRegressionStep1 {

  public static void main(String[] args) {

    SparkSession spark = SparkSession
        .builder()
        .appName("GradientBoostingTreeGegression")
        .master("local[*]")// (1)
        .getOrCreate();

    spark.sparkContext().setLogLevel("OFF");// (2)

    Dataset<Row> dataset = spark
        .read()
        .format("csv")// (3)
        .option("header", "true")// (4)
        .option("inferSchema", "true")// (5)
        .load("dataset/gem_price_ja.csv");// (6)

    dataset.show();// (7)

    dataset.printSchema();// (8)

  }

}

コード解説

(1) .master("local[*]")・・・Sparkをローカルモードで動作させる。「*」はワーカースレッドをCPUの論理コア数だけ割り当てる。.master("local")としたらワーカースレッドの数は1固定。.master("local[2]")としたらワーカースレッドの数は2となる。

(2)・・・ログをオフにする。

またはlog4jのほうで直接設定するには以下のようにします。

org.apache.log4j.Logger.getLogger("org").setLevel(org.apache.log4j.Level.ERROR);
org.apache.log4j.Logger.getLogger("akka").setLevel(org.apache.log4j.Level.ERROR);

(3) .format("csv") ・・・データファイルをCSV形式として読み込む

(4) .option("header", "true")・・・これをtrueにすると、CSVファイルの先頭の行をカラム名として使う
SparkではRDBMSのように、学習につかうためのデータを表形式のデータとして扱っていきます。学習データをCSV形式で準備する場合は、以下のように先頭の行にある定義をカラム名にマッピングできるので便利です。

CSVファイルの先頭の行
id,material,shape,weight,brand,shop,price

(5) .option("inferSchema", "true")・・・入力データのスキーマ(schema)を推定する
これをtrueにすると、CSVファイルを読み込むときに1つ1つのデータのをSparkが自動的に推定してセットしてくれます。便利ですね。
たとえば、以下の行ならシンプルなので、各データのinteger,string,string,integer,string,string,integerのように推定されます。もし推定が間違えていたら、自分でスキーマを厳密に定義することも可能です。

(6) .load("dataset/gem_price_ja.csv");・・・データファイルを読み込む

このようにシンプルなデータならinferSchemaで何とかなる
0,シルバー,ブレスレット,56,海外有名ブランド,百貨店,40864

(7) dataset.show();・・・読み込んだデータセットを表示する

これを実行すると、以下のようになる。

+---+--------+------------+------+------------------+------------+------+
| id|material|       shape|weight|             brand|        shop| price|
+---+--------+------------+------+------------------+------------+------+
|  0|シルバー|ブレスレット|    56|  海外有名ブランド|      百貨店| 40864|
|  1|ゴールド|        指輪|    48|  国内有名ブランド|      直営店| 63055|
|  2|  ダイア|  イアリング|    37|  国内有名ブランド|      直営店|112159|
|  3|  ダイア|  ネックレス|    20|  海外有名ブランド|      直営店|216053|
|  4|  ダイア|  ネックレス|    33|  海外有名ブランド|      百貨店|219666|
|  5|シルバー|    ブローチ|    55|  国内有名ブランド|      百貨店| 16482|
|  6|プラチナ|    ブローチ|    58|海外超有名ブランド|      直営店|377919|
|  7|ゴールド|  イアリング|    49|  国内有名ブランド|      直営店| 60484|
|  8|シルバー|  ネックレス|    59|      ノーブランド|激安ショップ|  6256|
|  9|ゴールド|        指輪|    13|  国内有名ブランド|      百貨店| 37514|
| 10|プラチナ|  ネックレス|    23|  国内有名ブランド|激安ショップ| 48454|
| 11|  ダイア|  イアリング|    28|  海外有名ブランド|      直営店|233614|
| 12|シルバー|  ネックレス|    54|  国内有名ブランド|激安ショップ| 12235|
| 13|プラチナ|    ブローチ|    28|      ノーブランド|      百貨店| 34285|
| 14|シルバー|  ネックレス|    49|      ノーブランド|激安ショップ|  5970|
| 15|プラチナ|ブレスレット|    40|  国内有名ブランド|      百貨店| 82960|
| 16|シルバー|  ネックレス|    21|  海外有名ブランド|      百貨店| 28852|
| 17|ゴールド|        指輪|    11|  国内有名ブランド|      百貨店| 34980|
| 18|プラチナ|ブレスレット|    44|海外超有名ブランド|      百貨店|340849|
| 19|シルバー|        指輪|    11|海外超有名ブランド|      直営店| 47053|
+---+--------+------------+------+------------------+------------+------+
only showing top 20 rows

(8) dataset.printSchema();・・・スキーマを表示する

さきほど.option("inferSchema", "true")しましたが、以下のようにSparkはちゃんと型が推定してくれました。

root
 |-- id: integer (nullable = true)
 |-- material: string (nullable = true)
 |-- shape: string (nullable = true)
 |-- weight: integer (nullable = true)
 |-- brand: string (nullable = true)
 |-- shop: string (nullable = true)
 |-- price: integer (nullable = true)

次回「#2 データの前処理(カテゴリ変数の取り扱い)」へと続く

補足:WindowsでApache Sparkを実行する場合

SparkをWindows環境で実行するときに以下のようなエラーメッセージがでる場合はwinutils.exeをダウンロードして適当なディレクトリに配置する。

Could not locate executable null\bin\winutils.exe in the Hadoop binaries.

winutils.exeはUnixコマンドをWindowsでエミュレーションする為のユーティリティでHadoopが使う。

これを http://public-repo-1.hortonworks.com/hdp-win-alpha/winutils.exe からダウンロードしてきて、例えば、c:/Temp/以下にc:/Temp/winutil/bin/winutil.exeとなるようにディレクトリをほって配置する。

そしてコードの先頭で以下のようにセットする。

System.setProperty("hadoop.home.dir", "c:\\Temp\\winutil\\");

  1. Sparkのspark.mlにある勾配ブースティングツリーは、いま流行ってるLightGBMXGBoost(treeモデルを選択した場合)などと勾配ブースティングという基本的な理論は同じですが、設計面(並行性など)や性能面(過学習防止の為のハイパーパラメーターチューニングなど)はだいぶ異なっており同一のものではありません。Spark開発チームでもこれらをウォッチしていたり、XGBoost本家によるxgboost4j-sparkなどの提供もあります。 

  2. SparkをWindows環境で実行するときに以下のようなエラーメッセージがでる場合はwinutils.exeをダウンロードして適当なディレクトリに配置する。文末の補足にて説明 

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1