LoginSignup
4
5

More than 5 years have passed since last update.

ROOTで機械学習してみる〜Multi Layer Perceptron(ANN)

Last updated at Posted at 2019-03-02

概要

欧州原子核研究機構CERNが開発しているデータ分析ツールROOTを使って簡単に機械学習をしてみる。

ROOTで用意されているチュートリアルのサンプルコードを参考に、
Multi Layer Perceptron(MLP)で回帰曲線を出すところまでやってみる。

ROOTのインストール方法や基本的な使い方は過去に書いた記事を参考までに。
https://qiita.com/dyamaguc/items/2f723cbc304c4debd82e
https://qiita.com/dyamaguc/items/397121b303e26f8286cf

環境

  • Mac Book Air
  • Mac OS X 10.13.6
  • ROOT 6.14/06

サンプルコードと実行結果

ROOTのサンプルコード(C++)

ROOTが提供しているサンプルコードはこちら
https://root.cern.ch/root/html608/mlpRegression_8C.html

ROOTをすでにインストールしている場合、サンプルコードは
<path to root>/tutorials/mlp/mlpRegression.C
ここにおいてある。
これを動かすには
root <path to root>/tutorials/mlp/mlpRegression.C
とターミナルでコマンドを打てば、いくつかプロットが出てきて、
うまく予測できていることがわかる。
C++のコードだが、コンパイルは不要。

自分のサンプルコード(Python)

ROOTが提供しているサンプルコードだとあまり面白くないので、
自分でサンプルコード少し変えて走らせてみた。
まず目的変数zに対し、説明変数x、yが次のような関係を持っているとする。

def theUnknownFunction(x, y):
    return TMath.Sin( (1.7+x)*(x-0.3) - 2.3*(y+0.7))

(x、y、z)の組を500個学習させて、得た回帰曲線の予想z_predと、正解z_trueを比較する。

from ROOT import TNtuple, TRandom, TMultiLayerPerceptron, TMLPAnalyzer, TMath, TCanvas, TGraph, TGraph2D, TTree, gROOT
from array import array
import numpy as np

def createData():
    N = 1000
    r = TRandom()
    r.SetSeed(0)

    data_tree = TTree("tree", "tree")
    x = np.empty((1), dtype="float32")
    y = np.empty((1), dtype="float32")
    z = np.empty((1), dtype="float32")
    data_tree.Branch("x", x, "x/f")
    data_tree.Branch("y", y, "y/f")
    data_tree.Branch("z", z, "z/f")

    # fill data
    for i in range(0, N):
        x[0] = r.Rndm()
        y[0] = r.Rndm()
        z[0] = theUnknownFunction(x, y)
        data_tree.Fill()

    del x,y,z

    return data_tree

if __name__ == '__main__':

    # fill data
    data_tree = createData()

    # create ANN
    mlp = TMultiLayerPerceptron("x,y:10:8:z", data_tree,  "Entry$%2","(Entry$%2)==0")
    mlp.Train(150, "graph update=10")

    mlpa = TMLPAnalyzer(mlp)
    mlpa.GatherInformations()
    mlpa.CheckNetwork()
    mlpa.DrawDInputs()

TMultiLayerPerceptronの最初の引数"x,y:10:8:z"はニューラルネットワークのレイヤーを記述している。
今回は、xとyが入力の層、その次の10と8はhidden layerのニューロンの数、zが出力。
今回はもとのサンプルと同じ設定にしている。
ROOTでは、Draw()で設定したニューラルネットワークの構造を可視化できる(下図)。
structure.png

第2引数はデータ。ROOTのデータ構造である、TTreeで渡している。
第3引数は学習用データの条件。今回はEntry(データのID)が奇数の場合を指定している。
第4引数はテスト(検証用)データの条件。Entryが偶数の場合を指定している。

その次にTrain(150, "graph update=10")で学習をしている。
第1引数はイテレーションの回数。
第2引数は学習のオプション。10イテレーション毎にグラフで学習曲線を描くように指定している。

TMLPAnalyzerで学習したMLPの結果を簡単に表示できる。
結果を可視化する。

#draw statistics shows the quality of the ANN's approximation
    canvas = TCanvas("TruthDeviation", "TruthDeviation")
    canvas.Divide(2,2)
    canvas.cd(1)
    mlp.Draw()

    canvas.cd(2)
    # draw the difference between the ANN's output for (x,y) and
    # the true value f(x,y), vs. f(x,y), as TProfiles
    mlpa.DrawTruthDeviations()

    canvas.cd(3)
    # draw the difference between the ANN's output for (x,y) and
    # the true value f(x,y), vs. x, and vs. y, as TProfiles
    mlpa.DrawTruthDeviationInsOut()

    canvas.cd(4)
    graph_truth_y05 = TGraph()
    graph_predi_y05 = TGraph()
    graph_truth_y05.SetMarkerStyle(20)
    graph_predi_y05.SetMarkerStyle(21)
    graph_predi_y05.SetMarkerColor(2)
    for ix in range(0, 15 ):
        v = array( 'd', [0,0])
        v[0] = ix / 10.0
        v[1] = 0.5
        graph_truth_y05.SetPoint( graph_truth_y05.GetN(), v[0], theUnknownFunction(v[0], v[1]))
        graph_predi_y05.SetPoint( graph_predi_y05.GetN(), v[0], mlp.Evaluate(0, v))
    graph_truth_y05.Draw("AP1")
    graph_predi_y05.Draw("P1 SAME")
    graph_truth_y05.SetTitle("y = 0.5;x;z")


    canvas.SaveAs("output.png")

    # To avoid error
    del mlp

実行結果は下図の通り。
x,y(右上),z(左下)を横軸にとって、縦軸に正解と予測の差をとると、いい精度で予測できていることがわかる。
右下は、y=0.5のときの正解(黒)と予想(赤)のプロット。横軸がx、縦軸がz。
xが1以上で予想と正解がずれているが、学習データの範囲外なので、まあそんなところかなというところ。
0 <= x <= 1では正解(黒)と予測(赤)がよく一致している。

output.png

まとめ

ROOTのフレームワークを使うことで簡単に、ニューラルネットワークの機械学習をしてみた。
チュートリアルのコードを参考にしたので、簡単に(数行で)回帰曲線を求められた。

参考

TMultiLayerPerceptron
https://root.cern.ch/doc/master/classTMultiLayerPerceptron.html#a9262eee03feb52487900c31f279691c2

TMLPAnalyzer
https://root.cern.ch/doc/master/classTMLPAnalyzer.html

実行後に出力されるエラーを消す方法
https://root-forum.cern.ch/t/output-could-be-produced-but-here-is-a-long-list-of-errors/10620/4

その他
https://www-he.scphys.kyoto-u.ac.jp/member/n.kamo/wiki/doku.php?id=study:software:root:pyroot

4
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
5