「kerasの学習が再現できない!」という時の解決法を書きます!
実行環境は下記の通りです。
- python3.6.1
- Keras==2.0.9
- tensorflow==1.4.0
導入(飛ばしていいけど読んでほしい!)
簡単なモデルを作る時にkerasを使いがちな自分ですが、今回研究でマルチレイヤーパーセプトロンを組んだ時に「学習の再現ができない!」という状態でしばらく格闘しました。
その時に発見した解決法を伝授いたします。
ちなみにブログ書くのは人生2回目の未熟者なのでお手柔らかに。。。
(1回目はこちらのkaggleの記事です!)
読み飛ばさないでほしいというのは、1回目のブログとかも読んで欲しかったからです。。。笑
ダメだった例
import numpy as np
import tensorflow as tf
np.random.seed(7)
tf.set_random_seed(7)
これでnumpyとtensorflowの乱数シードは固定していたのですが学習過程でずれが出てしまい、再現性が担保できませんでした。
結論(もう!?)
import os
import numpy as np
import random as rn
import tensorflow as tf
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(7)
rn.seed(7)
session_conf = tf.ConfigProto(
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1
)
from keras import backend as K
tf.set_random_seed(7)
sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
K.set_session(sess)
これ書いておけば再現性担保できました!
調べていた中でもnp.random.seedだけの記述も多かったんですが、それだけだと微妙に学習過程でずれが出ちゃいましたね。
PYTHONHASHSEEDでpythonの乱数固定して、random関数のシードも固定して、sessionも定義してようやくありつけました。
いらない部分もあるかもなので、コメントいただけるとありがたいです。
まとめ
今回はすごく短いですが、困っている人の助けになればなと思います。
近々othlo techという名古屋の学生クリエイティブ団体のAdvent Calendarでブログを書くことになると思うので、その時によろしくです!
また見てください〜