Java
Scala
機械学習
MachineLearning
matplotlib

Matplotlib4jでJavaやScalaからMatplotlibを使う

JavaやScalaで機械学習をやろうとしたときに、イケてるグラフツールがない、PythonだったらMatplotlibがあるのに・・・という経験をした方も多いのではないかと思います。

そこで、MatplotlibをJavaから呼べるようにするライブラリ、Matplotlib4jを作ったので、紹介したいと思います。

使い方

ライブラリの追加

ここではJavaの例を紹介します。もちろんScalaやKotlinなどの他のJVM言語からも使えます。例については後述。

まずはMatplotlibを利用したいJavaプロジェクトにMatplotlib4jを追加します。

Mavenの場合、以下のdependencyを追加します。

Maven
<dependency>
    <groupId>com.github.sh0nk</groupId>
    <artifactId>matplotlib4j</artifactId>
    <version>0.4.0</version>
</dependency>

同様に、Gradleの場合は以下のようになります。

Gradle
compile 'com.github.sh0nk:matplotlib4j:0.4.0'

グラフの描画

使い方はMatplotlibのAPIに近いため直感的に書くことができます。はじめにPlotオブジェクトを生成し、それに対してpyployのメソッドを呼ぶことで任意のグラフを追加し、最後にshow()メソッドを呼びます。Builderパターンになっているので、IDEの補完を使って後ろにオプションを追加していきます。

散布図

手始めに、散布図を描いてみます。

ScatterPlot
List<Double> x = NumpyUtils.linspace(-3, 3, 100);
List<Double> y = x.stream().map(xi -> Math.sin(xi) + Math.random()).collect(Collectors.toList());

Plot plt = Plot.create();
plt.plot().add(x, y, "o").label("sin");
plt.legend().loc("upper right");
plt.title("scatter");
plt.show();

linspacemeshgridなど、グラフ描画の助けになるように一部のNumpyメソッドがNumpyUtilsクラスとして準備されています。はじめのブロックでプロットするためのxとyのデータを生成しています。ここではsinカーブにランダムな値を与えています。その後、Plotオブジェクトを生成、plot()メソッドに対して生成したxとyを追加し、最後にshow()を呼ぶとグラフが描画されます。

これは以下のPythonの実装とほぼ等価になります(ほぼ、というのはnumpyのデータ生成の部分が厳密には異なるため)。メソッドの呼び方が似ていて、Pythonistaの方にも使いやすくなっています。

PythonScatterPlot
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-3, 3, 100)
y = np.sin(x) + np.random.rand(100)

plt.plot(x, y, "o", label="sin")
plt.legend(loc="upper right")
plt.title("scatter")
plt.show()

上述のJavaコードによって、以下のようなグラフが描けます。

scatter.png

コンター図

次に、コンター図(等高線)を描いてみます。

ContourPlot
List<Double> x = NumpyUtils.linspace(-1, 1, 100);
List<Double> y = NumpyUtils.linspace(-1, 1, 100);
NumpyUtils.Grid<Double> grid = NumpyUtils.meshgrid(x, y);

List<List<Double>> zCalced = grid.calcZ((xi, yj) -> Math.sqrt(xi * xi + yj * yj));

Plot plt = Plot.create();
ContourBuilder contour = plt.contour().add(x, y, zCalced);
plt.clabel(contour)
    .inline(true)
    .fontsize(10);
plt.title("contour");
plt.show();

contour.png

ヒストグラム

ヒストグラムも同様に描けます。

HistogramPlot
Random rand = new Random();
List<Double> x1 = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
        .collect(Collectors.toList());
List<Double> x2 = IntStream.range(0, 1000).mapToObj(i -> 4.0 + rand.nextGaussian())
        .collect(Collectors.toList());

Plot plt = Plot.create();
plt.hist()
    .add(x1).add(x2)
    .bins(20)
    .stacked(true)
    .color("#66DD66", "#6688FF");
plt.xlim(-6, 10);
plt.title("histogram");
plt.show();

histogram.png

画像をファイルに保存

Matplotlib4jはファイルへの保存もサポートしています。サーバ上での機械学習の定期処理など、GUIを持たないユースケースでは画像のファイル保存が便利です。

本家のMatplotlib同様、.show()の代わりに.savefig()メソッドを用いることで、プロットウィンドウがポップアップする代わりに画像がファイルに保存されます。唯一の違いとして、.savefig()のあとにplt.executeSilently()をコールします。(これはsavefigコマンドもメソッドチェーンで繋げられるので、終端処理として必要になります。)

Random rand = new Random();
List<Double> x = IntStream.range(0, 1000).mapToObj(i -> rand.nextGaussian())
        .collect(Collectors.toList());

Plot plt = Plot.create();
plt.hist().add(x).orientation(HistBuilder.Orientation.horizontal);
plt.ylim(-5, 5);
plt.title("histogram");
plt.savefig("/tmp/histogram.png").dpi(200);

// ファイルを出力するために必要
plt.executeSilently();

これにより、以下のような画像が出力されます。

histogram.png

pyenv, pyenv-virtualenvによるPythonの切り替え

Matplotlib4jを使用するにはMatplotlibがインストールされているPythonを使用する必要があります。Matplotlib4jではデフォルトではパスが通っているPythonが使われますが、システムデフォルトのPythonにはMatplotlibをインストールしていないというケースも多いかと思います。

そうした場合、pyenvやpyenv-virtualenvを使っていれば、AnacondaのようなMatplotlibがインストールされているPython環境に切り替えることができます。

Pyenvの環境に応じたPythonを利用するには、Plotオブジェクトを生成する際に以下のようにPythonConfigを指定します。

pyenv
Plot plot = Plot.create(PythonConfig.pyenvConfig("任意のpyenv環境名"));

同様に、pyenv-virtualenvの環境名も指定できます。

pyenv-virtualenv
Plot plot = Plot.create(PythonConfig.pyenvVirtualenvConfig("任意のpyenv環境名", "任意のvirtualenv環境名"));

Scala

Scalaから利用する場合、上記の散布図の例は以下のように書くことができます。その際、Box/UnboxとListのクラスの違いに注意します。

ScalaScatter
import scala.collection.JavaConverters._

val x = NumpyUtils.linspace(-3, 3, 100).asScala.toList
val y = x.map(xi => Math.sin(xi) + Math.random()).map(Double.box)

val plt = Plot.create()
plt.plot().add(x.asJava, y.asJava, "o")
plt.title("scatter")
plt.show()

おまけ

きっかけ

最近「ゼロから作るDeep Learning ――Pythonで学ぶディープラーニングの理論と実装」を読み始めたのですが、そのままPythonで写経しても面白くないので、最近よく触っているScalaで実装することにしました。Scalaで関数型っぽく書けて、大満足で進めていたのですが、最急降下法によるバックプロパゲーションに差し掛かったところではじめて、あれ、Lossが全然下がらない、どこかバグってるんじゃない?という状況にぶち当たりました。

もちろんテストを厚くすればわかるだろう、というのがこうした場合の常套手段かと思いますが、まずは手っ取り早く本にあるようにグラフを表示して何が起こっているかを見てみたいものです。でもScalaだとイケてるグラフツールがない・・・。とはいえ一からグラフツールを実装するのはさすがに酷過ぎる・・・ということで、Pythonで馴染みのあるMatplotlibを使えるようにしようと考えたのが、ライブラリ製作のきっかけです。

設計

Matplotlib4jでは、JNIやJythonを使わずにPythonコードを生成する形でMatplotlibを呼んでいます。最初はJythonを使って実装しようと思ったのですが、そもそもPythonのバージョンが2.7までしかサポートしていない、またnumpyが使えないのでそれに依存しているMatplotlibも動かないためこの道は断念することに。

世の中にはCPythonをJavaコードから使えるようにするライブラリもあって、こちらはPython3もnumpyも使えるため候補にあがりました。しかし、JNIを使うために別途環境依存のライブラリをインストールしたり、Python側でのもpipからライブラリのインストールが必要だったりして、グラフを描くだけなのに使うための手間が大きすぎるために、結局こうしたライブラリに依存せずに実装することにしました。

もちろんファイルを介して実行するので変数の渡し方や返り値の利用に工夫必要があったり、さらにはパフォーマンスも大丈夫なの?と気になるところです。幸い、グラフを描くことだけが目的なので、基本的な機能は一方的にファイルに出力することで満たせますし、パフォーマンスについても多少の待ち時間は許容範囲に収まっているのではないかと思います。