LoginSignup
6
3

More than 5 years have passed since last update.

マルチスレッドの各スレッドでnumpyの乱数seedを固定する

Last updated at Posted at 2018-06-09

環境:python 3.6, numpy 1.14

DeepLearning関係で、マルチスレッドで学習を回すアプリを書いていたのですが、実験結果の再現性を確保するために各スレッドで乱数シードを固定しようとしました。
Threadのクラスのコンストラクタでnp.random.seed(N)をやってあげたのですが、うまくいきませんでした。

イメージ↓

import threading
import numpy as np
import time


class learningThread(threading.Thread):
    def __init__(self, tag):
        super(learningThread, self).__init__()
        self.tag = tag
        np.random.seed(5) # 乱数固定したつもり

    def run(self):
        for i in range(5):
            # ここで学習してるイメージ
            r = np.random.randint(100)
            print(self.tag, "loop:", i, "rand:", r)
            time.sleep(1)


def main():
    th1 = learningThread("thread-0")
    th2 = learningThread("thread-1")
    th1.start()
    th2.start()


if __name__ == '__main__':
    main()

# 実行結果
# thread-0 loop: 0 rand: 44
# thread-1 loop: 0 rand: 47
# thread-0 loop: 1 rand: 64
# thread-1 loop: 1 rand: 67
# thread-0 loop: 2 rand: 67
# thread-1 loop: 2 rand: 9
# thread-0 loop: 3 rand: 83
# thread-1 loop: 3 rand: 21
# thread-0 loop: 4 rand: 36
# thread-1 loop: 4 rand: 87

推測するに、マルチスレッド上では乱数生成器はグローバルなもの1つしかなく、np.random.seedではそのSeedを初期化するので、結局各スレッドで乱数が揃うことはない、という挙動です。
(マルチプロセスだったらうまくいくんだろうな)

じゃあどうすればいいのかと調べたら、乱数生成器自体を作る方法があるんですね。

numpy.random.RandomState(seed=None)

RandomStateのオブジェクトを作って各スレッドに渡してあげれば、スレッドそれぞれが乱数生成器を持ってくれます。
RandomStateのコンストラクタ引数にseedがあるので、これを指定してあげれば乱数固定できます。
そして各スレッドでは渡されたRandomStateを使って乱数生成すれば、乱数が揃いますね。
→ 2018/6/11 LearningThreadクラスにRandomStateを引数で渡す必要はなかったので修正しました。

import threading
import numpy as np
import time


class learningThread(threading.Thread):
    def __init__(self, tag):
        super(learningThread, self).__init__()
        self.tag = tag

    def run(self):
        rst = np.random.RandomState(5) # スレッド専用の乱数生成器をseedを指定して作る

        for i in range(5):
            # ここで学習してるイメージ
            r = rst.randint(100) # 専用の乱数生成器で乱数を生成
            print(self.tag, "loop:", i, "rand:", r)
            time.sleep(1)


def main():
    # RandomStateオブジェクトをseedを指定して作成し、各スレッドに渡す
    th1 = learningThread("thread-0")
    th2 = learningThread("thread-1")
    th1.start()
    th2.start()


if __name__ == '__main__':
    main()

# 実行結果
# thread-0 loop: 0 rand: 99
# thread-1 loop: 0 rand: 99
# thread-1 loop: 1 rand: 78
# thread-0 loop: 1 rand: 78
# thread-1 loop: 2 rand: 61
# thread-0 loop: 2 rand: 61
# thread-1 loop: 3 rand: 16
# thread-0 loop: 3 rand: 16
# thread-1 loop: 4 rand: 73
# thread-0 loop: 4 rand: 73

めでたしめでたし。

6
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3