環境:python 3.6, numpy 1.14
DeepLearning関係で、マルチスレッドで学習を回すアプリを書いていたのですが、実験結果の再現性を確保するために各スレッドで乱数シードを固定しようとしました。
Threadのクラスのコンストラクタでnp.random.seed(N)をやってあげたのですが、うまくいきませんでした。
イメージ↓
import threading
import numpy as np
import time
class learningThread(threading.Thread):
def __init__(self, tag):
super(learningThread, self).__init__()
self.tag = tag
np.random.seed(5) # 乱数固定したつもり
def run(self):
for i in range(5):
# ここで学習してるイメージ
r = np.random.randint(100)
print(self.tag, "loop:", i, "rand:", r)
time.sleep(1)
def main():
th1 = learningThread("thread-0")
th2 = learningThread("thread-1")
th1.start()
th2.start()
if __name__ == '__main__':
main()
# 実行結果
# thread-0 loop: 0 rand: 44
# thread-1 loop: 0 rand: 47
# thread-0 loop: 1 rand: 64
# thread-1 loop: 1 rand: 67
# thread-0 loop: 2 rand: 67
# thread-1 loop: 2 rand: 9
# thread-0 loop: 3 rand: 83
# thread-1 loop: 3 rand: 21
# thread-0 loop: 4 rand: 36
# thread-1 loop: 4 rand: 87
推測するに、マルチスレッド上では乱数生成器はグローバルなもの1つしかなく、np.random.seedではそのSeedを初期化するので、結局各スレッドで乱数が揃うことはない、という挙動です。
(マルチプロセスだったらうまくいくんだろうな)
じゃあどうすればいいのかと調べたら、乱数生成器自体を作る方法があるんですね。
numpy.random.RandomState(seed=None)
RandomStateのオブジェクトを作って各スレッドに渡してあげれば、スレッドそれぞれが乱数生成器を持ってくれます。
RandomStateのコンストラクタ引数にseedがあるので、これを指定してあげれば乱数固定できます。
そして各スレッドでは渡されたRandomStateを使って乱数生成すれば、乱数が揃いますね。
→ 2018/6/11 LearningThreadクラスにRandomStateを引数で渡す必要はなかったので修正しました。
import threading
import numpy as np
import time
class learningThread(threading.Thread):
def __init__(self, tag):
super(learningThread, self).__init__()
self.tag = tag
def run(self):
rst = np.random.RandomState(5) # スレッド専用の乱数生成器をseedを指定して作る
for i in range(5):
# ここで学習してるイメージ
r = rst.randint(100) # 専用の乱数生成器で乱数を生成
print(self.tag, "loop:", i, "rand:", r)
time.sleep(1)
def main():
# RandomStateオブジェクトをseedを指定して作成し、各スレッドに渡す
th1 = learningThread("thread-0")
th2 = learningThread("thread-1")
th1.start()
th2.start()
if __name__ == '__main__':
main()
# 実行結果
# thread-0 loop: 0 rand: 99
# thread-1 loop: 0 rand: 99
# thread-1 loop: 1 rand: 78
# thread-0 loop: 1 rand: 78
# thread-1 loop: 2 rand: 61
# thread-0 loop: 2 rand: 61
# thread-1 loop: 3 rand: 16
# thread-0 loop: 3 rand: 16
# thread-1 loop: 4 rand: 73
# thread-0 loop: 4 rand: 73
めでたしめでたし。