17
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

MXNetでXORを学習する

Posted at

はじめに

DeepLearningのフレームワークであるMXNetが高速だという話を聞いたので勉強のためにXORの学習を行いました。

インストール

MXNetページの"Build and Installation"を参照してください。
私の環境はWindows10なのですが、以下のようにインストールしました。

  • "Installing pre-built packages on Windows"に従ってpre-built package をインストール
  • Python packageをインストール

ソースコード

モジュールのインポート

mxnetをインポートします。
また学習データとしてNumpy配列を使うのでnumpyもインポートします。
loggingはログ出力に使うモジュールです。

import numpy as np
import mxnet as mx
import logging

学習データ

学習データを定義します。
今回はXORの学習なので入力は2値でそれぞれ0または1、全部で4個の入力値となります。
出力は入力をXORした値で0または1となり、クラス分類として学習するので整数として定義します。
ついでにミニバッチサイズも定義します。

x = np.asarray([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
y = np.asarray([0, 1, 1, 0], dtype=np.int32)
batch_size = 4

ログ出力

ログ出力を定義します。
この定義がないと学習中になにも出力されず学習状態がわかりません。
ほぼMNIST学習のサンプルコードのままで、出力内容を変更したい場合にどうしたらよいかわかっていません。

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')

ネットワーク定義

ニューラルネットワークの定義を行います。
今回学習するネットワーク構成は以下の通りです。

  • 入力: 2node
  • 隠れ層: 1層2node
  • 出力: 2node

MXNetではニューラルネットワークの各層をSymbolというクラスによって表現しており、Symbolを接続することでネットワークを構築するようです。
今回は層が直列に並んでいるだけなので対応するSymbolを順に生成していきます。

  • 入力
    symbol.Variable を使います。
    入力値の個数は指定しないで良いようです。
  • 隠れ層
    入力と隠れ層は全結合なので symbol.FullyConnected を使い、その後 symbol.Activation を使ってSigmoid関数を適用します。
  • 出力
    隠れ層と出力も全結合なので symbol.FullyConnected を使い、その後 symbol.SoftmaxOutput を使ってSoftmaxの出力を生成します。
    ネットワークの出力Symbolには symbol.XXXOutputというクラスを使うようで、ほかにも symbol.LinearRegressionOutput や symbol.LogisticRegressionOutput などがあります。
    注意点として、SoftmaxOutput の name を 'softmax' にする必要がありました。
    他の値を指定すると後で説明するeval_metricの計算時に、metric計算の対象がわからないというエラーが発生していました。
    metricの設定方法があるのかもしれませんがわかっていません。
net = mx.symbol.Variable('data')
net = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=2)
net = mx.symbol.Activation(data=net, name='sigmoid1', act_type='sigmoid')
net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=2)
net = mx.symbol.SoftmaxOutput(data=net, name='softmax')

モデルの生成

Symbol から Model を生成します。
学習率などもここで設定します。
デフォルトでは更新アルゴリズムとしてSGDを使います。
ニューラルネットワークの初期値も一様乱数、正規乱数などいくつか種類があるのですが、 init.Xavier を使ったときに収束が早かったのでそれを使っています。
Xavierはデータから初期値のスケールを決定するアルゴリズムのようですが、詳しいことは理解できていません。

model = mx.model.FeedForward(
    ctx              = mx.cpu(),
    symbol           = net,
    numpy_batch_size = batch_size,
    num_epoch        = 200,
    learning_rate    = 1,
    momentum         = 0.9,
    initializer      = mx.init.Xavier(factor_type='in')
)

モデルの学習

model.fit を呼び出してモデルの学習を行います。
eval_metric を指定することで学習中に評価データを使った metrics を出力することができます。
通常は学習データと評価データは異なるのですが、今回は学習データを使って予測精度を出力します。
epoch_end_callback にはepoch終了時のコールバックを設定することができ、例えばモデルの保存を行うことができます。
mx.callback.do_checkpoint('xor') とすると'xor'をprefixとしたファイル名でモデルを保存するコールバックを生成します。(今回は保存しません。epoch_end_callbackの使用例だけ載せます)

model.fit(
    X                  = x,
    y                  = y,
    eval_data          = (x, y),
    eval_metric        = ['accuracy']
#    epoch_end_callback = mx.callback.do_checkpoint('xor')
)

モデルを使った予測

model.predict で予測を行うことができます。

print model.predict(x)

ソースコード全体

xor.py
import numpy as np
import mxnet as mx
import logging

x = np.asarray([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.float32)
y = np.asarray([0, 1, 1, 0], dtype=np.int32)
batch_size = 4

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')

net = mx.symbol.Variable('data')
net = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=2)
net = mx.symbol.Activation(data=net, name='sigmoid1', act_type='sigmoid')
net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=2)
net = mx.symbol.SoftmaxOutput(data=net, name='softmax')

model = mx.model.FeedForward(
    ctx              = mx.cpu(),
    symbol           = net,
    numpy_batch_size = batch_size,
    num_epoch        = 200,
    learning_rate    = 1,
    momentum         = 0.9,
    initializer      = mx.init.Xavier(factor_type='in')
)

model.fit(
    X                  = x,
    y                  = y,
    eval_data          = (x, y),
    eval_metric        = ['accuracy']
#    epoch_end_callback = mx.callback.do_checkpoint('xor')
)

print model.predict(x)

実行結果

ログのValidation-accuracyと、最後の予測結果からうまく学習できていることがわかります。

> python xor.py
2016-04-16 23:32:10,831 Start training with [cpu(0)]
2016-04-16 23:32:10,834 Epoch[0] Resetting Data Iterator
2016-04-16 23:32:10,834 Epoch[0] Train-accuracy=0.500000
2016-04-16 23:32:10,835 Epoch[0] Time cost=0.003
2016-04-16 23:32:10,835 Epoch[0] Validation-accuracy=0.500000
2016-04-16 23:32:10,838 Epoch[1] Resetting Data Iterator
2016-04-16 23:32:10,838 Epoch[1] Train-accuracy=0.500000
2016-04-16 23:32:10,838 Epoch[1] Time cost=0.002
2016-04-16 23:32:10,839 Epoch[1] Validation-accuracy=0.500000
2016-04-16 23:32:10,841 Epoch[2] Resetting Data Iterator
2016-04-16 23:32:10,842 Epoch[2] Train-accuracy=0.500000
2016-04-16 23:32:10,842 Epoch[2] Time cost=0.002
2016-04-16 23:32:10,842 Epoch[2] Validation-accuracy=0.500000

(中略)

2016-04-16 23:32:11,877 Epoch[198] Resetting Data Iterator
2016-04-16 23:32:11,877 Epoch[198] Train-accuracy=1.000000
2016-04-16 23:32:11,878 Epoch[198] Time cost=0.002
2016-04-16 23:32:11,878 Epoch[198] Validation-accuracy=1.000000
2016-04-16 23:32:11,880 Epoch[199] Resetting Data Iterator
2016-04-16 23:32:11,881 Epoch[199] Train-accuracy=1.000000
2016-04-16 23:32:11,881 Epoch[199] Time cost=0.002
2016-04-16 23:32:11,882 Epoch[199] Validation-accuracy=1.000000
[[ 0.99476057  0.00523944]
 [ 0.00427657  0.99572337]
 [ 0.00426387  0.99573612]
 [ 0.99578202  0.00421802]]

感想

fit メソッドで簡単に学習が行えるので、単純なクラス分類やregressionであれば使いやすそうに思いました。
ただ損失関数が複雑なケースでは損失関数の微分計算を実装する必要がありそうで、難しくなりそうな印象を受けました。

参考

17
19
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
19

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?