1. Qiita
  2. 投稿
  3. Scala

ScalaからTensorFlowのJava APIを呼びだすぞい

  • 23
    いいね
  • 0
    コメント

Googleが誇る機械学習ライブラリTensorFlow 1.0がリリースされましたね。
何が更新されたか覗いてみま…… Java API !?

Experimental APIs for Java and Go

Announcing TensorFlow 1.0

Java APIが使えるということは、JVM言語であるScalaからも呼びだせるはずなので、早速やってみました :raised_hands:

準備

tensorflow/tensorflow/java/README.md
に従って、JARとネイティブライブラリをダウンロードして解凍して、配置するだけです。
Mac OS Xだとこんな感じ。

$ tree
.
├── build.sbt
├── jni
│   └── libtensorflow_jni.dylib
├── lib
│   └── libtensorflow-1.0.0-PREVIEW1.jar
└── src
    └── main
        └── scala
            └── Main.scala

コード

ベクトルA (1, 2, 3)とベクトルB (4, 5 ,6)の要素積を計算するだけのテストコードです。

src/main/scala/Main.scala
import org.tensorflow._

object Main extends App {
  val graph = new Graph()
  val a = graph.opBuilder("Const", "a").
    setAttr("dtype", DataType.INT32).
    setAttr("value", Tensor.create(Array(1, 2, 3))).
    build().
    output(0)

  val b = graph.opBuilder("Const", "b").
    setAttr("dtype", DataType.INT32).
    setAttr("value", Tensor.create(Array(4, 5, 6))).
    build().
    output(0)

  val c = graph.opBuilder("Mul", "c").
    addInput(a).
    addInput(b).
    build().
    output(0)

  val session = new Session(graph)
  val out = new Array[Int](3)
  session.runner().fetch("c").run().get(0).copyTo(out)

  println(out.mkString(", "))
}

なお、GraphSessionTensorは明示的にclose()を呼ばないとリソース解放されないようなので、きちんとしたコードを書くときは注意してください。

実行

$ sbt run -Djava.library.path=./jni
...
4, 10, 18

はい、ベクトルC (4, 10, 18)が計算できました。

感想

機械学習、とくにディープラーニングといえばPythonの文化が強く、静的型付け言語が好きな自分としてはちょっと歯がゆい思いをしていたのですが、こうやって手に馴染んだ言語を使えるとテンション上がります :heart_eyes:
もちろん、Deeplearning4jなど良いJavaライブラリもあるのですが、最新の学習モデルはTensorFlowで実装されることが多いので……。

機械学習の面倒さは、まずは前処理のデータ整形に依るところが大きくて、そこも合わせて一つのコードで書けると敷居がグッと下がる気がします。
Java APIはまだ最低限のものしか用意されていないみたいですが、これからどんどんリッチになってくれると嬉しいですね :joy: :joy: :joy:

環境

build.sbt
name := "tensorflow-scala"
scalaVersion := "2.12.1"
Comments Loading...