最近、PyBrainというPython用の機械学習ライブラリを使用して、ニューラルネットを構築して学習させることがあった。
PyBrainを使えば、結構簡単にニューラルネットを構築できる上に、訓練したモデルをxmlファイルとして書き出しておくことができる。もちろん、これを読みこめばモデルの再利用ができる。
ので、今回はインストール方法とモデルの書き出し・読み込みをまとめておく。
1. PyBrainのインストール
インストールにはpipを使用する。
pip install pybrain
2. モデルの書き出し
from pybrain.tools.xml import NetworkWriter
を使用する。
from pybrain.tools.shortcuts import buildNetwork
# ~ 中略(訓練データの読み込みなど) ~
# ニューラルネットワークモデルの定義
network = buildNetwork(64, 19, 2)
# ~ 中略(データを使用して訓練を実行)~
# モデルを書き出す
NetworkWriter.writeToFile(network, 'model.xml')
これで、訓練済みのモデルの設定がmodel.xmlというファイルが出力される。
3. モデルの読み込み
from pybrain.tools.xml import NetworkReader
を使用する。
from pybrain.tools.xml import NetworkReader
network = NetworkReader.readFrom('model.xml')
先ほど出力したmodel.xmlを読み込むことで、学習済みモデルを再現することができる。
あとは、認識するなり追加学習するなり、好きにしてあげてください。