LoginSignup
11
1

More than 5 years have passed since last update.

tf.while_loop分かったかも

Last updated at Posted at 2018-07-18

tf.while_loopとは

tf.while_loopとは、TensorFlowの中でwhile loopを回すAPIです。https://www.tensorflow.org/api_docs/python/tf/while_loop

よくある画像の分類とかだと、TensorFlowの中でwhile loopを回すことはあまりなく、(少なくとも自分には)ぱっと見て簡単に使えないこのAPIが分かった気になったので、説明を書きます。

tf.while_loopの使い方

概略

tf.while_loopの引数は、

def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name=None,
               maximum_iterations=None): ...

のようにやたらと引数があって困ります。必須のcond, body, loop_vars以外は、細かいことなので、無視してこの3つの引数についてだけ考えます。

while loopは、条件を満たすまである処理を繰り返すという処理です。ある処理というと雑ですが、「状態の更新を条件を満たす間ずっと行う」と解釈することもできます。tf.while_loopには、この解釈で考えると分かりやすいと思います。

「状態の更新を条件を満たす間ずっと行う」の内、条件の部分をcond、状態の更新をbody、初期状態をloop_varsの引数で指定します。返り値は、終了状態を意味するTensor(のtupleなど)になります。

詳細

cond, body, loop_varsに何を指定すればいいかを説明します。
手っ取り早く、簡単な例を挙げます。

import tensorflow as tf

N = 1000


def condition(i, x):
    print(i, x)
    # Tensor("while/Merge:0", shape=(), dtype=int32) Tensor("while/Merge_1:0", shape=(), dtype=int32)
    return i < N


def update(i, x):
    print(i, x)
    # => Tensor("while/Identity:0", shape=(), dtype=int32) Tensor("while/Identity_1:0", shape=(), dtype=int32)
    return i + 1, i + x


init_val = (0, 0)
loop = tf.while_loop(cond=condition, body=update, loop_vars=init_val)

with tf.Session() as sess:
    print(sess.run(loop))  # => (1000, 499500)

この例は、

i, x, N = 0, 0, 1000
while i < N:
    i, x = i + 1, x + i
print((i, x))

に相当します。このloopでは、(i,x)が状態に相当するものです。初期状態を(0,0)として、終了状態が(1000, 499500)です。condとbodyには、この状態を引数にする関数を指定します。正確には、状態を表す値のリストがloop_varsに指定され、その要素が一つずつに対応する引数を持つ関数です。

condには例のcondition関数のように条件を書くわけですが、途中でprintしてみると、condition関数が呼ばれるときiやxはTensorになっています。同様に、返り値も条件を表すtf.bool型のTensorになります。

bodyについても同様に途中でprintすると、i,xはTensorで、返り値もそれぞれはTensorです。bodyに指定する関数には、このように引数を更新した結果に相当するTensorのlist(もしくはtuple)を返り値にします。

あとは、初期状態を指定してtf.while_loopを呼び、その返り値をsession.runするという手順です。

応用

初期状態にplaceholderを使う

初期状態にplaceholderを使うことができます。次のようにします。

import tensorflow as tf

N = 1000

x_in = tf.placeholder(tf.float32, [])


def condition(i, x):
    return i < N


def update(i, x):
    return i + 1, x + tf.cast(i, tf.float32)


init_val = (0, x_in)
i_out, x_out = loop = tf.while_loop(condition, update, init_val)

with tf.Session() as sess:
    print(sess.run(loop, {x_in: -499500.}))  # => (1000, 0.0)
    print(sess.run(x_out, {x_in: 0}))  # => 0.0

初期状態にtf.float32型のtf.placeholderを使いました。計算の途中でcastしているのは、状態の型が変化してはいけないからです。同様にVariableなども指定できます。ただし、Variableは更新されず単に初期値として使われるだけなので、注意が必要です。

Variableを更新しながらloopさせる方法

tf.assign(x, y)はxにyを代入することに相当する操作を意味します。これはsession.runすると代入された値が得られます。これを応用して、Variableに代入しながらloopすることができます。

import tensorflow as tf

N = 1000
v = tf.Variable(0.)


def condition(i, x):
    return i < N


def update(i, x):
    return i + 1, tf.assign(v, x + i)


init_val = (0., 0.)
loop = tf.while_loop(condition, update, init_val)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(loop)
    print(sess.run(v))  # => 499500.0

確かにVariableの値が更新されました。もっとも、loopの中で何度も代入するより最後に一度だけ代入する方が良いでしょうが・・・

おまけ

ニュートン法

昔、TensorFlowが出た頃、わざわざOptimizerを使って、$\sqrt 2$を計算する記事を書きました。https://qiita.com/n_kats_/items/73538c7c66559d09f35d

今回はニュートン法で$\sqrt 2$を計算してみましょう。ただし、tf.while_loopを使ってpython側でloopを書かないものとします。次のようになります。

import tensorflow as tf

epsi = 1e-6


def func(x):
    return x * x - 2


def condition(x, dx):
    return tf.abs(dx) > epsi


def update(x, dx):
    y = x + dx
    f = func(y)
    df = tf.gradients(f, y)[0]
    dx_ = -(f / df)
    return y, dx_


init_val = 100., 100.

result_x, _ = tf.while_loop(condition, update, init_val)
with tf.Session() as sess:
    print(sess.run(result_x))  # => 1.4142135

NamedTuple

python3.6からあるNamedTupleを状態として使うことができる。ただし、loop_varsはリストなどにしないといけないので、次のようにします。

from typing import NamedTuple

import tensorflow as tf

N = 1000


class State(NamedTuple):
    i: int
    x: float


def condition(state: State):
    return state.i < N


def update(state: State):
    ii = state.i + 1
    return [State(ii, state.x + 1. / tf.cast((ii * ii), tf.float32))]


init_val = [State(i=0, x=0.)]
loop = tf.while_loop(condition, update, init_val)

with tf.Session() as sess:
    print(sess.run(loop))  # => [State(i=1000, x=1.6439348)]

確認環境

python3.6.6
tensorflow1.9.0

11
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
1