1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

MXNet Tutorialを試す(3): Module - ニューラルネットの学習と推論

Posted at

MXNet Tutorialを順番にやっていくメモ(頑張って最後まで進む……)

Module

ニューラルネットの学習にはたくさんのステップが存在

  • 入力データの調整
  • モデルのパラメータの初期化
  • 順方向、逆方向への伝播
  • 勾配に応じた重みのアップデート
  • チェックポイントの作成

などなど. これらは初心者にも経験者にも面倒.

MXNet では学習と推論でよく使うものを module (省略してmod)パッケージにまとめてある.

Moduleでは高レベルと中レベルのどちらに対しても、定義済のネットワークの利用のためのインターフェースが用意されている.
インターフェースは交換可能.

はじめに

このチュートリアルではUCI文字認識のデータセットに対しての多層パーセプトロンの学習を行う.
Training: Test = 80:20.
イテレータは毎回バッチあたり32のデータを返す.

import logging
logging.getLogger().setLevel(logging.INFO)
import mxnet as mx
import numpy as np

# テスト用のモジュール使ってダウンロード、ここでrequestsが必要
# T,2,8,3,5,1,8,13,0,6,6,10,8,0,8,0,8 のようなデータが20000行並んでいる
fname = mx.test_utils.download('http://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data')
data = np.genfromtxt(fname, delimiter=',')[:,1:] # テキストファイルの内容からnp.ndarrayを生成する
label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')])  # Unicodeの番号を返すメソッド`ord()`を利用しラベルづけ、`A`からの距離を格納

batch_size = 32
ntrain = int(data.shape[0]*0.8)  # 8割Training、2割Test
train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)  # イテレータ
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)
OUT
    INFO:root:letter-recognition.data exists, skip to downloada
# データの確認
print(data)
print(data.shape)
print(label)
print(label.shape)
OUT
    [[  2.   8.   3. ...,   8.   0.   8.]
     [  5.  12.   3. ...,   8.   4.  10.]
     [  4.  11.   6. ...,   7.   3.   9.]
     ..., 
     [  6.   9.   6. ...,  12.   2.   4.]
     [  2.   3.   4. ...,   9.   5.   8.]
     [  4.   9.   6. ...,   7.   2.   8.]]
    (20000, 16)
    [19  8  3 ..., 19 18  0]
    (20000,)

ネットワークの定義

net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)
net = mx.sym.Activation(net, name='relu1', act_type='relu')
net = mx.sym.FullyConnected(net, name='fc2', num_hidden=26)
net = mx.sym.SoftmaxOutput(net, name='softmax')
mx.viz.plot_network(net)

output_5_0.png

Moduleの作成

Moduleは次のパラメータで定義

  • symbol: ネットワークの定義
  • context: 実行環境のデバイス
  • data_names: 入力するデータの変数名のリスト
  • label_names: 入力するラベルの変数名のリスト

netでは、dataという名前のデータ1つと、sofmax_labelと名前のついたラベル1つ.
(SoftmaxOutputsoftmaxと名前をつけたために自動で入力ラベル名は生成)

mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])

中レベルのインターフェース

定義したmodを利用しての、中レベル(Intermediate-level)インターフェースを利用しての学習と推論の実行方法について.
forwardbackwordを走らせることにより柔軟な、ステップバイステップでの計算が可能、デバッグにも有用

学習のためには次のステップを踏む必要

  • bind: メモリを割り当てることで計算の環境の準備
  • init_params: パラメータの割当てと初期化
  • init_optimizer: オプティマイザの初期化、デフォルトはSGD
  • metric.create: 入力のメトリックから評価のメトリックを作成
  • forward: 順伝播計算
  • update_metric: 最後の順方向の計算から、metricの評価と蓄積
  • backward: 逆伝播計算
  • update: 設定されたオプティマイザと、前回の順/逆伝播計算で計算された勾配にもとづきパラメータを更新

下のような感じで利用できる.

# 与えられた入力データトラベルのshapeからメモリの割り当て
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# 一様乱数でパラメータを初期化
mod.init_params(initializer=mx.init.Uniform(scale=.1))
# SGD(学習率0.1)を利用
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# Accuracyをメトリックとして利用
metric = mx.metric.create('acc')
# エポック数5で学習、つまり渡されたイテレータでデータをなめる
for epoch in range(5):
    train_iter.reset()
    metric.reset()
    for batch in train_iter:
        mod.forward(batch, is_train=True)  # 予測の計算
        mod.update_metric(metric, batch.label)  # 予測の蓄積
        mod.backward()  # 勾配の計算
        mod.update()  # パラメータのアップデート
    print('Epoch %d, Training %s' % (epoch, metric.get()))
OUT
    Epoch 0, Training ('accuracy', 0.45587499999999997)
    Epoch 1, Training ('accuracy', 0.65706249999999999)
    Epoch 2, Training ('accuracy', 0.7139375)
    Epoch 3, Training ('accuracy', 0.74568749999999995)
    Epoch 4, Training ('accuracy', 0.77006249999999998)

高レベルインターフェース

学習

学習、予測、評価に高レベルAPIを提供(前節で踏んだStepを省略可能).
fit APIは内部的に前節と同様のことを実行する.

# train_iterをリセット
train_iter.reset()

# modules作成
mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])

# moduleをfit
mod.fit(train_iter,
        eval_data=val_iter,
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.1},
        eval_metric='acc',
        num_epoch=8)
OUT
    INFO:root:Epoch[0] Train-accuracy=0.366063
    INFO:root:Epoch[0] Time cost=0.639
    INFO:root:Epoch[0] Validation-accuracy=0.595250
    INFO:root:Epoch[1] Train-accuracy=0.631875
    INFO:root:Epoch[1] Time cost=0.617
    INFO:root:Epoch[1] Validation-accuracy=0.679500
    INFO:root:Epoch[2] Train-accuracy=0.703562
    INFO:root:Epoch[2] Time cost=0.610
    INFO:root:Epoch[2] Validation-accuracy=0.727750
    INFO:root:Epoch[3] Train-accuracy=0.745188
    INFO:root:Epoch[3] Time cost=0.613
    INFO:root:Epoch[3] Validation-accuracy=0.769500
    INFO:root:Epoch[4] Train-accuracy=0.768312
    INFO:root:Epoch[4] Time cost=0.616
    INFO:root:Epoch[4] Validation-accuracy=0.798250
    INFO:root:Epoch[5] Train-accuracy=0.784500
    INFO:root:Epoch[5] Time cost=0.610
    INFO:root:Epoch[5] Validation-accuracy=0.806000
    INFO:root:Epoch[6] Train-accuracy=0.795813
    INFO:root:Epoch[6] Time cost=0.610
    INFO:root:Epoch[6] Validation-accuracy=0.798250
    INFO:root:Epoch[7] Train-accuracy=0.803500
    INFO:root:Epoch[7] Time cost=0.614
    INFO:root:Epoch[7] Validation-accuracy=0.792500

fit

  • eval_metric: accuracy
  • optimizer: sgd
  • optimizer_params: (('learning_rate', 0.01),)

がデフォルトになっている.

optimizer_params この書き方だとtupleでないだろうか

予測と評価

predict()で予測の実行. 全ての予測の結果について返す.

y = mod.predict(val_iter)
assert y.shape == (4000, 26)

予測の出力はいらないが、テストセットの評価が欲しい時はscore関数が利用可能.
入力のvalidationデータセットに対して予測を実行し、与えられたメトリックに応じパフォーマンスを評価する.

score = mod.score(val_iter, ['acc'])
print('Accuracy score is %f' % (score[0][1]))
OUT
    Accuracy score is 0.792500

他のメトリックについても利用可能. top_k_acc(top-k-accuracy)、F1, RMSE, MSE, MAE, ce(Cross entropy)など

SaveとLoad

checkpointのコールバックを利用することで毎回のトレーニングのパラメータを保存することができる

# チェックポイントで保存するためのコールバック関数
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix)

mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)
OUT
    INFO:root:Epoch[0] Train-accuracy=0.089438
    INFO:root:Epoch[0] Time cost=0.614
    INFO:root:Saved checkpoint to "mx_mlp-0001.params"
    INFO:root:Epoch[1] Train-accuracy=0.244375
    INFO:root:Epoch[1] Time cost=0.631
    INFO:root:Saved checkpoint to "mx_mlp-0002.params"
    INFO:root:Epoch[2] Train-accuracy=0.434250
    INFO:root:Epoch[2] Time cost=0.615
    INFO:root:Saved checkpoint to "mx_mlp-0003.params"
    INFO:root:Epoch[3] Train-accuracy=0.535250
    INFO:root:Epoch[3] Time cost=0.615
    INFO:root:Saved checkpoint to "mx_mlp-0004.params"
    INFO:root:Epoch[4] Train-accuracy=0.608187
    INFO:root:Epoch[4] Time cost=0.618
    INFO:root:Saved checkpoint to "mx_mlp-0005.params"

保存されたパラメータのロードにはload_checkpoint関数を利用.
Symbolをロードしパラメータと紐づけ、ロードしたパラメータをmoduleにセットできる.

sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
assert sym.tojson() == net.tojson()

# ロードしたパラメータをモジュールへと関連づける
mod.set_params(arg_params, aux_params)

保存した状態から学習を再開したい時は、set_paramsを呼ぶ代わりに直接fitにパラメータを渡すことができる.begin_epochパラメータによりどのポイントから再開すればいいかをfitに伝える.

mod = mx.mod.Module(symbol=sym)
mod.fit(train_iter,
        num_epoch=8,
        arg_params=arg_params,
        aux_params=aux_params,
        begin_epoch=3)
OUT
    INFO:root:Epoch[3] Train-accuracy=0.535250
    INFO:root:Epoch[3] Time cost=0.626
    INFO:root:Epoch[4] Train-accuracy=0.608187
    INFO:root:Epoch[4] Time cost=0.613
    INFO:root:Epoch[5] Train-accuracy=0.652062
    INFO:root:Epoch[5] Time cost=0.612
    INFO:root:Epoch[6] Train-accuracy=0.683438
    INFO:root:Epoch[6] Time cost=0.594
    INFO:root:Epoch[7] Train-accuracy=0.702313
    INFO:root:Epoch[7] Time cost=0.582

  • Python 3 + Ubuntu, GPU環境で学習中
  • 全文和訳しているわけではないです、あくまでメモ程度に
  • Jupyterからの出力に手を入れただけなのでレイアウト崩れるやも…

Keras みたいな感じですね.

次はBasicセクションの最後、Iteratorの予定(もう出て来てる気がしますが)

1
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?