LoginSignup
7
11

More than 5 years have passed since last update.

2つのtf.Sessionを並列で実行する

Last updated at Posted at 2017-08-02

tf.Sessionを複数,それも並列に実行させたい場合にはオフィシャルにもある通り,tf.train.Serverを使えばよい.この方法を使うとSessionを複数定義することができ,並列に計算を行うことができる.multiprocessingと一緒に使う場合のサンプルは以下の通りになる:

multiprocessing_with_tensorflow.py
import numpy as np
import tensorflow as tf
import multiprocessing as mp
import time

class Worker(object):
    def __init__(self, queue, target):
        self.sess = tf.Session(target)
        self.queue = queue
        self.a_plh = tf.placeholder(tf.float32, [2,2])
        self.b_plh = tf.placeholder(tf.bool)

    def main(self):
        while True:
            queue = self.queue.get()

            a, b = queue
            _a, _b = self.sess.run([self.a_plh, self.b_plh], \
                    {self.a_plh: a, self.b_plh: b})
            time.sleep(1)

            print("get queue : \n {} \n {}".format(_a, _b))

    def __call__(self):
        self.main()

cluster = tf.train.ClusterSpec({'local': ['localhost:2222', 'localhost:2223']})

server_host = tf.train.Server(cluster, job_name='local', task_index=0)
server_worker = tf.train.Server(cluster, job_name='local', task_index=1)

queue = mp.Queue()
worker = Worker(queue, server_worker.target)

p = mp.Process(target=worker)
p.daemon = True
p.start()

sess = tf.Session(server_host.target)

a = tf.random_uniform([2, 2])
b = tf.random_uniform([1])

while True:
    now = time.time()

    _a, _b = sess.run([a, b])
    time.sleep(3)

    queue.put([_a, bool(_b)])

    new_now = time.time()
    diff = new_now - now
    print("{0:3.3f} [sec]".format(diff))

Workerが並列に動き,Queueを通して得られた値を使って演算する(ここではただ中身を返すだけ).実行結果は


3.034 [sec]
get queue :
 [[ 0.84218264  0.23204899]
 [ 0.18645024  0.48593903]]
 True
3.006 [sec]
get queue :
 [[ 0.62776446  0.90778756]
 [ 0.06188095  0.51110578]]
 True
3.002 [sec]
get queue :
 [[ 0.69087613  0.15397286]
 [ 0.90308774  0.81757426]]
 True
3.003 [sec]
get queue :
 [[ 0.36197078  0.58955324]
 [ 0.98018742  0.43241131]]
 True

のようになる. メイン(server_host)の計算がserver_workerの計算を行なっている途中でも実行される.

ただし計算を実際に実行するリソースは限られているので,単純に並列化すればよいわけではない.

7
11
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
11