数日前にpiyo7さんが投稿された ScalaからTensorFlowのJava APIを呼びだすぞい をみて JRuby からだとどういうコードになるのか試してみました。
コード
require 'pp'
require 'java'
require './libtensorflow-1.0.0-PREVIEW1.jar'
module TF
include_package 'org.tensorflow'
end
graph = TF::Graph.new
a = graph.opBuilder("Const", "a").
setAttr("dtype", TF::DataType::INT32).
setAttr("value", TF::Tensor.create([1, 2, 3].to_java(:int))).
build().
output(0)
b = graph.opBuilder("Const", "b").
setAttr("dtype", TF::DataType::INT32).
setAttr("value", TF::Tensor.create([4, 5, 6].to_java(:int))).
build().
output(0)
c = graph.opBuilder("Mul", "c").
addInput(a).
addInput(b).
build().
output(0)
session = TF::Session.new(graph)
out = Array.new(3).to_java(:int)
session.runner().fetch("c").run().get(0).copyTo(out)
pp out #=> int[4, 10, 18]@71623278
$ ruby -J-Djava.library.path=./jni example.rb
で実行すると確かに [4, 10, 18] が求まります。
ちょっとハマったのが org.tensorflow を import する処理
import 'org.tensorflow.*'
とは書けず、名前空間 TF を定義して include_package する方式になりました。
比較
Ruby から TensorFlow を実行するライブラリは、昨年6月に公開された
が定番のようです。
このライブラリ付属のサンプルコードと上記の example.rb を比較すると example.rb の方がかなり冗長です。
もし今回のコードを発展させる意味があるとすると、module TF をよりインテリジェントなラッパーに置き換えることくらいでしょう。
(1) Java 側のシンボル情報を用いて module TF をメタプログラミングする。
(2) Kerasライクに、より抽象化された (Ruby らしい)API を提供する。
といったことが考えられます。
Keras は Keras 2 で TensorFlow に統合されるとのこと(→Spring 2017 roadmap: Keras 2, PR freeze, TF integration )で、Java API の仕様はまだまだ不安定と見受けます。残念ながら(1)(2)いずれも時期尚早ですね。