LoginSignup
0
2

More than 5 years have passed since last update.

MXNet で扱う主なオブジェクトと使い方

Posted at

MXNet で扱う主なオブジェクトの種類

MXNet で扱うオブジェクトとして、主に以下の4種類があります。

オブジェクト 概要 Deep Learning における役割
NDArray 多次元配列 データセットの保持
Symbol 多次元配列に対する操作 ネットワークの定義
Executor Symbol に NDArray を結び付けたもの forward/backward の実行
Model 学習済みモデル 予測

上記のオブジェクトを利用して、以下の流れでモデルを作るのが(いろいろいじる場合の)基本となります。

  1. NDArray にデータセットを格納し、
  2. ネットワーク構造を Symbol で定義し、
  3. データセットとネットワーク構造を結び付け Executor を作成し、
  4. forward/backward を繰り返し学習を行い、
  5. 得られた学習済みモデル Model を利用して予測を行う。

各オブジェクトを使ってみる

データセットを作成する(NDArray)

ソースコード
a.data <- mx.rnorm(shape = c(2, 3), mean = 0, sd = 1, ctx = mx.cpu())
b.data <- mx.rnorm(shape = c(2, 3), mean = 0, sd = 1, ctx = mx.cpu())
data <- list(a = a.data, b = b.data)
結果の確認
> print(data)
$a
         [,1]      [,2]      [,3]
[1,] 2.212206 0.7740038 1.0434403
[2,] 1.163079 0.4838046 0.2995635

$b
          [,1]      [,2]      [,3]
[1,] 1.1839255  1.891711 -1.234741
[2,] 0.1530255 -1.168815  1.558071

データセットは NDArray の名前付きリストで表します。リストの各要素の名前は、この後に mx.symbol.Variable で作成する Symbol の名前と一致させます。

データセットに対する操作を定義する(Symbol)

ソースコード
a.sym <- mx.symbol.Variable("a")
b.sym <- mx.symbol.Variable("b")
c.sym <- a.sym * b.sym
d.sym <- a.sym + c.sym
結果の確認
> graph.viz(d.sym$as.json(), graph.title = NULL, graph.width.px = 300, graph.height.px = 300)

network.png

mx.symbol.Variable で多次元配列のプレースホルダを作成し、それに対する操作を定義していきます。この段階では入力となる多次元配列の形状を指定する必要はありません。Symbol は多次元配列のように扱うことができるので、Symbol を仮の多次元配列とみなして +* で式を書けば操作を定義することができます。

二項演算子の他に、mx.symbol.FullyConnected のようなネットワーク構造を定義する関数や、mx.symbol.Activation のような活性化関数なども用意されており、通常はそれらを組み合わせてネットワークを構築していきます。

Symbol は as.json で json 形式に変換できます。また、それを graph.viz に与えることで、計算グラフを可視化できます。上記の図で a だけ表示されて b が表示されないのは謎です……

図中の a _Mul _plus0 は各操作に対する名前です。名前を明示的につけなかった場合には、重複しないよう名前が自動的に割り振られます。上記の図では、_Mul_plus0 が自動的に割り振られた名前です。

データセットと操作を結び付ける(Executor)

ソースコード
exec <- mx.simple.bind(symbol = d.sym, ctx = mx.cpu(), grad.req = "write", a = dim(a.data), b = dim(b.data))
mx.exec.update.arg.arrays(exec = exec, arg.arrays = data, match.name = T)
結果の確認
> print(exec$arg.arrays)
$a
         [,1]      [,2]      [,3]
[1,] 2.212206 0.7740038 1.0434403
[2,] 1.163079 0.4838046 0.2995635

$b
          [,1]      [,2]      [,3]
[1,] 1.1839255  1.891711 -1.234741
[2,] 0.1530255 -1.168815  1.558071

mx.simple.bind で入力データの形状を指定し、さらに mx.exec.update.arg.arrays で入力データを結び付けます。mx.simple.bindgrad.req には { "write", "add", "null" } の3種類を指定できるようですが、どう違うのかわかりませんでした…… null を指定すると微分できなくなります。普通は "write" を指定するっぽい。grad.req 以降には、mx.symbol.Variable で作成した Symbol の名前とその Symbol に結び付けるデータセットの形状をペアにして指定します。

計算する(forward)

ソースコード
mx.exec.forward(exec = exec, is.train = T)
結果の確認
> print(exec$outputs)
$`_plus0_output`
         [,1]        [,2]       [,3]
[1,] 4.831294  2.23819566 -0.2449387
[2,] 1.341059 -0.08167335  0.7663047

mx.exec.forward で実際に計算を行います。mx.exec.forwardis.train は、勾配を求めるため計算過程を保持する必要がある場合に true にするものだと理解しました(要確認)。

微分する(backward)

ソースコード
mx.exec.backward(exec = exec, out_grad = mx.nd.ones(shape = c(2, 3), ctx = mx.cpu()))
結果の確認
> print(exec$grad.arrays)
$a
         [,1]       [,2]       [,3]
[1,] 2.183926  2.8917112 -0.2347414
[2,] 1.153026 -0.1688148  2.5580711

$b
         [,1]      [,2]      [,3]
[1,] 2.212206 0.7740038 1.0434403
[2,] 1.163079 0.4838046 0.2995635

mx.exec.backward で微分します。mx.exec.backwardout_grad には dd/dd = 1 を指定します。ちなみに、引数の out_grad という名前は僕が適当につけました。exec 以外なら、なんでもいいっぽいです(要確認)。結果を見ると、dd/da = b + 1、dd/db = a となっており、正しく微分できていることが確認できます。

Model の作り方、使い方

ToDo: 後で調べる。

参考文献

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2