Tensorflowで凝った処理を行う際に、非同期で変数を読み書く必要が出てくるケースがあります。
例えば非同期的にモデルの学習と評価を行う場合、学習器がパラメーターを更新していく横で、評価器がそのパラメーターを評価することになります。
では、学習器がパラメーターを更新中に評価器がそのパラメーターにアクセスしてしまった場合、何が起こるでしょうか? というのが本記事のテーマです。
変数が一つの例
以下は、長さ10000の変数$a$に対して、各要素を1ずつ足していくプログラムです。
1ずつ足していく横で別スレッドから繰り返し$a$の値を読み込んでおり、その合計を計算しています。
もし、$a$の全要素が足し終わったあとに、読み込めていれば合計は常に10000の倍数となるはずです。
size = 10000
a = tf.Variable(np.zeros(size), dtype=tf.int32)
add = a.assign_add(np.ones(size, dtype=np.int32))
sess.run(tf.global_variables_initializer())
def reader():
while True:
value = sess.run(a)
s = np.sum(value)
if s % size != 0:
print (s)
th1 = threading.Thread(name="reader", target=reader)
th1.start()
for i in range(100000):
sess.run(add)
この出力は以下のようになり、$a$への書き込み中に変数を読み込んでしまっていることがわかります。
3196076
5008044
6012928
7932040
対策
この例はVariableをResourceVariableに切り替えるだけで対応できます。tf.Variableのコンストラクタにuse_resource=Trueを渡す方法と
tf.enable_resource_variables()
を呼び出す方法があります。これを行うことで、Variable宣言時に指定しなければResourceVariableが得られます。
なお、assign_addにはuse_lockingという引数がありますが、同時書き込みからの保護しか行わないらしく、このケースでは役に立ちません。(参考記事)
変数が複数の例
これで一変数のケースでは書き込みと読み込みが同時に起こらないようにできました。しかし、変数が複数のときはこうは行きません。3層のCNNに対して、2層更新した状態で評価を行うと、不正な状態のNetworkを評価してしまうことになります。
この場合は、tf.contrib.framework.CriticalSectionを使います。Tensorflowの2つのGraphを同時に実行しないように制御してくれる関数です。以下のプログラムでは、2つの変数を同時にインクリメントするwriterと、2つの変数の差を評価するreaderを別々のスレッドで動作させています。
size = 10000
use_critical = False
with tf.Graph().as_default() as g:
# なくとも正常に動作する
tf.enable_resource_variables()
sess = tf.Session()
a = tf.Variable(np.zeros(size), dtype=tf.int32)
b = tf.Variable(np.zeros(size), dtype=tf.int32)
def reader():
value_a, value_b = a.read_value(), b.read_value()
diff_value = tf.reduce_sum((value_a - value_b)**2)
return diff_value
def writer():
add_a = a.assign_add(np.ones(size, dtype=np.int32))
add_b = b.assign_add(np.ones(size, dtype=np.int32))
group_add = tf.group([add_a, add_b])
return group_add
if use_critical:
cs = tf.contrib.framework.CriticalSection()
read_op = cs.execute(reader)
write_op = cs.execute(writer)
else:
read_op = reader()
write_op = writer()
def read_caller():
while True:
diff_value = sess.run(read_op)
if diff_value != 0:
print (diff_value)
def write_caller():
while True:
sess.run(write_op)
sess.run(tf.global_variables_initializer())
th1 = threading.Thread(name="reader", target=read_caller)
th2 = threading.Thread(name="writer", target=write_caller)
th1.start()
th2.start()
use_critical=Falseのときは、変数aが更新され、変数bが更新されていないタイミングでreaderに値を読み込まれてしまいますが、use_critical=Trueのときはこの問題が回避できます。