RNN(LSTM)
RNNは時系列のデータを学習するためのモデルです.
RNNは長期記憶にとても弱い欠点があります.
それを改良したものがLSTMとなっています.
モデルや細かい説明は他のサイトに任せます.
Pythonで実装
TensorFlowはバージョンアップが早く,すぐに過去のプログラムが使えなくなってしまいます.
なのでネットで探してみても自分の環境で動かせるプログラムがなかったので,
自分の環境で動かせるようプログラムを変更していきます.
この記事のプログラムを使って実行できるLSTMを作成していきます.
TensorFlowのRNNを基本的なモデルで試す
詳しい説明などはそちらのページを参考にしてください.
環境
macOS Mojave
Python 3.5.2
TensorFlow 1.11.0
変更箇所
BasicLSTMCell(54行目)
cell = tf.nn.rnn_cell.BasicLSTMCell(num_of_hidden_nodes, forget_bias=forget_bias, state_is_tuple=False)
このままだと
WARNING:tensorflow:From /Users/~~/lstm_q.py:54: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
と警告が出てしまいます.
これはBasicLSTMCellをLSTMCellに書き換えてねという警告なのでそのまま変更するだけです.
cell = tf.nn.rnn_cell.LSTMCell(num_of_hidden_nodes, forget_bias=forget_bias, state_is_tuple=False)
すると同じ行のstate_is_tuple=Falseをstate_is_tuple=Trueに変更してねと警告が出るのですが,
まだ勉強不足でこの警告は取り除けていません.
これは出力をタプル(データ型)で出してくださいという命令なのですがそのまま直してしまうと,
型が違いますとエラーが出てしまいます.
ちなみに,タプルとは更新が不可能なリストのようなものです.
このタプルを使うと計算時間が早くなるので,TensorFlowではタプルを使用することを推奨しているみたいです.
警告の取り除き方がわかる方はぜひコメントなどで教えてもらえればありがたいです.
initialize_all_variables(120行目)
init = tf.initialize_all_variables()
このままだと
WARNING:tensorflow:From /Users/~~/tf_should_use.py:189: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
と警告が出てしまいます.
これも先ほどと同様initialize_all_variablesをglobal_variables_initializerに変更してねという警告なので変更します.
init = tf.global_variables_initializer()
save(sess, "model.ckpt")(145行目)
saver.save(sess, "model.ckpt")
このままだと
ValueError: Parent directory of model.ckpt doesn't exist, can't save.
エラーが出てしまいます.saver.saveの第二引数にはパスを入れなければならないみたいなので,
saver.save(sess, "./model.ckpt")
このように変更しました.
ただしmacOSでは相対パスで対応しているのですが,
Windowsでは絶対パスを与えないといけないみたいなので注意してください.
FileWriter("/tmp/tensorflow_log", graph=sess.graph)(139行目)
上の項目を直すことでエラーと一つ以外の警告は出なくなったのですが,
tensorflow_logが絶対パスの/tmp/tensorflow_logに保存されていたので変更しました.
このままの方が正解だったのかもしれませんがとりあえず作業ディレクトリの中に保存しました.
summary_writer = tf.summary.FileWriter("/tmp/tensorflow_log", graph=sess.graph)
を
summary_writer = tf.summary.FileWriter("./tmp/tensorflow_log", graph=sess.graph)
のように変更しました.
次のステップ
このプログラムの理解度をより上げていき,
前回取得した為替データを使って学習していくことが目的です.