3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Ray+keras(Eager Mode)を用いたDerivative Free最適化の例について

3
Last updated at Posted at 2018-04-02

※ 著者の勘違いが孕んでいる可能性があるので注意して読んでほしいです。

強化学習の今日日の流れとして、分散処理が切っても切れない関係となっています。そして、分散処理のコーディングはなかなか骨が折れることはPythonのmultiprocessingなどを書いたことのある方はよくご存知だと思います。ray-projectが提供するrayというライブラリはとてもイージーなインターフェースになっており、デコレータ1つで分散処理用の関数に置き換えることができます。

少し話題が変わりますが、前々からTensorFlowのデバッギングが大変だと考えがえられており、少しずつDefine by Run(Eager Mode)に移行しているように感じます。

特に次のChollet氏のツイートでkeraseager modeがこれほど相性良く書けるのかと驚きました。

Rayについて

今回は、ray-project/tutorialのDerivative Free Optimizationを例にとって、ray+keras(eager mode)を紹介したいと思います。実行環境を整えるためにまず、pipでインストールしていきます。

pip install tensorflow-gpu==1.7
pip install 'ray[rllib]'

rayのAPIは基本的にThe Ray APIで参照することができます。

今回のプログラムで用いたrayのメソッドはray.init()ray.remote()ray.get()だけです。ray.initからrayプロセスと呼ばれる、スケジューラ、オブジェクトマネージャとサーバの一式を立ち上げて、分散処理の準備をします。ここでは、CPU, GPUの数を指定したり、もともと立ち上げてたredisサーバのIPアドレスを指定すればそこでアクセスするようになります。@ray.remoteデコレータにより、リモート関数に再定義してくれます。再定義されたリモート関数は、どうように呼び出せず、

@ray.remote
def f(x):
    return x + 1

In [0]: x_id
Out[0]: ObjectID(658726d0075506ba5367e615b6df68241546e148)

とデコレータによって生えた.remoteメソッドによって呼び出すことができるようになります。この時点では、まだ実行されておらず、rayクラスタが管理するオブジェクトのID(object ID)が返ってきます。multiprocessingでもWorkerが並列に呼び出すように、実行結果ではなくcallableな関数を返すようなイメージだと思います。

実の値を取り出すにはray.get()を用います。

In [1]: ray.get(x_id)
Out[1]: 1

Eager Modeのkerasを併用した解答

先述した通り、ray-project/tutorialのDerivative Free Optimizationを例にあげたいのですが、Derivative Free最適化自体は、強化学習で微分(勾配)を用いない最適化を指しており、この演習ではランダム初期値のニューラルネットワークを政策(Policy)としたモデルで、100回または1000回実行したときの平均のreturnと最大のreturnを求めてます。もちろん、逐次的に100回や1000回を求めるには莫大な時間がかかりますが、rayを用いると高速でできますよ、というのが演習の主目的だと考えて問題ないと思います。ちなみに、強化学習におけるreturnとは総和報酬で、論文によって様々なノーテーションがあるのですが、個人的に気に入ってるもので、

\begin{align}
r_{t}^{\gamma} &= \sum_{l = t}^{\infty} \gamma^{l-t} r(s_l, a_l) \\
&= \sum_{l = 0}^{\infty} \gamma^{t} r(s_{l+t}, a_{l+t})
\end{align}

があります。右辺の$r(\cdot, \cdot)$はreward functionと呼ばれる可測関数で、$s_t, a_t$はそれぞれ時刻$t$における状態と行動を表した確率過程です。$\gamma$はdiscounted factorと呼ばわれる割引率です。他の論文では、軌跡(trajectory)を$\tau$として、returnを単に$r(\tau)$と書くものもあれば、割引率を書かないで$r_t$と書く方もいらっしゃいます。

解答の重要な部分を載せました。フルコードはsolve2.pyを参照していただければありがたいです。

ray.init()

class TwoLayerPolicy(keras.Model):
  def __init__(self, num_inputs, num_hiddens, num_outputs=1):
    super(TwoLayerPolicy, self).__init__(name='two_layer_policy')
    self.dense1 = keras.layers.Dense(num_hiddens, activation=tf.nn.relu)
    self.dense2 = keras.layers.Dense(num_outputs)
  
  def call(self, observ, training=None, mask=None):
    hidden = self.dense1(observ)
    output = self.dense2(hidden)
    assert output.shape.as_list()[1] == 1
    # return 0 if np.all(output.numpy() < 0) else 1
    return tf.cond(tf.reduce_all(output < 0),
                   lambda: tf.zeros_like(output),
lambda: tf.ones_like(output))


@ray.remote
def evaluate_random_policy(num_rollouts):
  # Generate a random policy.
  policy = TwoLayerPolicy(4, 5)
  
  # Create an environment.
  env = gym.make('CartPole-v0')
  
  # Evaluate the same policy multiple times
  # and then take the average in order to evaluate the policy more accurately
  returns = [rollout_policy(env, policy) for _ in range(num_rollouts)]
  return np.mean(returns), np.max(returns)


def main():
  tf.enable_eager_execution()
  
  # Evaluate 100 randomaly generated policies.
  average_100_rewards, best_100_rewards = ray.get(evaluate_random_policy.remote(100))
  # Print the best score obtained.
  print("100 Policy:\n\tAverage return(total return): {0}\n\tBest return: {1}".format(
    average_100_rewards, best_100_rewards))

参照

3
2
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?