LoginSignup
29
31

More than 5 years have passed since last update.

Deep learning on Spark with DL4J

Last updated at Posted at 2015-12-09

この記事はSpark Advent Calendar 9日目の記事として書きました。

Spark上でDeep Learningのアルゴリズムを走らせるにはいくつか方法があります。

今回は3つ目のdeeplearning4jをSparkから利用する方法を紹介したいと思います。

deeplearning4jとは

Skymindが中心となって開発をしているJVM上で動くDeep Learningのフレームワークです。Deep Learningのアルゴリズムを実装したものはCaffeTorch, Chainerなどがありますが、これらはメインではLuaやC/C++で実装されています。deeplearning4jはJava, Scalaで実装されており元々HadoopやSparkといったすでにあるBig Dataのためのフレームワークとの親和性に重きをおいて作られています。

image
Architecture from deeplearning4j.org

本体であるdeeplearning4jの他にも

  • ベクトル化のためのツールであるCanova
  • JVM上でndarrayライクなベクトルを扱えるnd4j
  • Scalaラッパであるdeeplearning4s, nd4s

などもともに開発されています。

Sparkで利用するには

Sparkから利用するためにはこのdeeplearning4jをSpark pacakge化したdl4j-spark-mlというパッケージを使います。

すでにSparkが利用できる環境が整って入れば以下のコマンドを叩くだけです。

$ $SPARK_HOME/bin/spark-shell \
    --packages deeplearning4j:dl4j-spark-ml:0.4-rc3.4

これでclasspathにdeeplearning4jのクラスたちが通って、Spark上で利用できます。とはいってもどうアルゴリズムを書いていいかわからないのでdl4j-spark-ml-examplesという名前でサンプルコードも作られています。

ここに含まれているIrisデータセットを利用したサンプルを見てみましょう。

ml.JavaIrisClassification

このコードではSparkのDaraFrameを使ってIrisデータセットを読み込んでいます。deeplearning4jの方でそれ用のコネクタ(Relation)が容易されています。

        String path = args.length == 1 ? args[0]
                : "file://" + System.getProperty("user.dir") + "/data/svmLight/iris_svmLight_0.txt";
        DataFrame data = jsql.read()
                .format(DefaultSource.class.getName())
                .load(path);

このDataFrameに対してML Pipelineを構築していきます。Spark ML/MLlibの資産を利用できるのは大変便利です。

        // 訓練用のデータを抜き出す
        DataFrame trainingData = data.sample(false, 0.6, 11L);
        // テスト用のデータを抜き出す
        DataFrame testData = data.except(trainingData);

        // 特徴量をscalingさせる
        StandardScaler scaler = new StandardScaler()
                .setWithMean(true).setWithStd(true)
                .setInputCol("features")
                .setOutputCol("scaledFeatures");

        // deeplearning4jが用意したEstimatorを作成する
        NeuralNetworkClassification classification 
          = new NeuralNetworkClassification()
                .setFeaturesCol("scaledFeatures")
                .setConf(getConfiguration());

        // Pipelineを作る
        Pipeline pipeline = new Pipeline()
          .setStages(new PipelineStage[] { scaler, classification });

NeuralNetworkClassificationというEstimatorを作るための設定は以下のようになります。


new NeuralNetConfiguration.Builder()
  .seed(11L) // 初期重み行列を生成するためのseed
  .iterations(100) // backpropagtionを行うiterationの回数
  .learningRate(1e-3f) // 重み更新を行うときの学習率
  .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) // backpropagationの際の勾配の計算方法
  .momentum(0.9) // 勾配更新の際に利用される値
  .constrainGradientToUnitNorm(true)
  .useDropConnect(true) // 過学習などを避ける学習速度を早めるためにDrop outするかどうか
  .list(2) // 層の数。入力層(データそのもの)は含まなない
  // 第一層 制約付きボルツマンマシン 隠れ層と可視層のユニットのタイプを設定
  .layer(0, new RBM.Builder(RBM.HiddenUnit.RECTIFIED, RBM.VisibleUnit.GAUSSIAN)
              .nIn(4) // 入力の大きさ Irisのfeatureの数
              .nOut(3) // 出力層。今回はそのままカテゴリ数になっている
              .weightInit(WeightInit.XAVIER)
              .activation("relu") // 活性化関数のタイプ Rectified Linear Unit
              .lossFunction(LossFunctions.LossFunction.RMSE_XENT)
              .updater(Updater.ADAGRAD) // RBMの勾配更新はAdaGradを使う
              .k(1) // # Contrastive Divergenceのiterationの数。RBM特有パラメタ
              .dropOut(0.5) // Drop Outをこの層で行う割合
              .build()
  ) 
  // 第二層 出力層
  .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
              .nIn(3) // 入力の数
              .nOut(3) // 出力の数
              .activation("softmax") // 活性化関数。Classificationなのでsoftmax
              .weightInit(WeightInit.XAVIER)
              .updater(Updater.ADAGRAD) // Full connection layerの更新の方法
              .dropOut(0.5) // Drop Outをこの層で行う割合
              .build()
  ) 
  .build();

Deep Learningは設定すべきHyper Parameterが多くtuningやモデルの設定が難しくなってしまうことが多いと言われています。まさにそれを体現したようなConfigurationの生成です。
何かよいモデルのプリセットなどがいくつかあれば十分というケースもあると思うのでそういった場合には使いづらいかもしれませんが、逆に考えるとこれだけのパラメタの自由度があたえられていますのでモデルを選択できる幅は広がりそうです。

このIris Datasetに対するexampleを動かしたければscriptが用意されていますのでそれをたたけば大丈夫です。

$ ./bin/run-example org.deeplearning4j.ml.JavaIrisClassification 

ただデータによっては(Mnistとか)必要な容量が環境に対して大きい場合があるのでrun-exampleに記載されている--driver-memory--executor-memoryを増やす必要があるかもしれません。

29
31
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
29
31