Source
- Learning Latent Dynamics for Planning from Pixels
- Introducing PlaNet: A Deep Planning Network for Reinforcement Learning
- Github Code
Introduction
Reinforcement learning can be categorized into two groups: model-free and model-based learning. In model-free methods, the agent has no knowledge of the underlying mechanics of the environment and must learn to map observations directly to good actions. While this is a perfectly feasible approach, it generally takes a long time to train the agent before it is suitable for practical use.
In contrast, model-based methods learn the dynamics of the environment and can be more data-efficient, since they can select actions based on their long-term outcomes.
Recently, Google has collaborated with DeepMind to develop a new Deep Planning Network (PlaNet) agent, which solves a variety of image-based tasks with up to 5,000% (50x) better data efficiency while remaining competitive with advanced model-free agents.
Implementation
Latent Space Planning
The problem is defined as a partially observable Markov decision process (POMDP) as an image does not contain the full state of the environment.
The goal is to implement a policy $p(a_t \mid o_{\le t}, a_{<t})$ that maximizes the expected sum of rewards $\operatorname{E}_p\!\left[\sum_{\tau=t+1}^{T} r_\tau\right]$.
Model-Based Planning
PlaNet learns four models from previously experienced episodes:
1. Transition model $p(s_t \mid s_{t-1}, a_{t-1})$: Gaussian, with mean and variance parameterized by a feed-forward neural network.
2. Observation model $p(o_t \mid s_t)$: Gaussian, with mean parameterized by a deconvolutional neural network and identity covariance.
3. Reward model $p(r_t \mid s_t)$: scalar Gaussian, with mean parameterized by a feed-forward neural network and unit variance.
4. Variational encoder $q(s_t \mid o_{\le t}, a_{<t})$: Gaussian, with mean and variance parameterized by a convolutional neural network followed by a feed-forward neural network.
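As a rough sketch of this parameterization pattern (not the authors' code; `state_dim`, `hidden_dim`, and the softplus standard-deviation head below are illustrative assumptions), each Gaussian component boils down to a network that outputs a mean and, where needed, a positive standard deviation:

```python
import numpy as np

def softplus(x):
    # numerically stable softplus, used to keep standard deviations positive
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)

class GaussianTransition:
    """Toy stand-in for the transition model p(s_t | s_{t-1}, a_{t-1}):
    a one-layer feed-forward net outputs the mean and std of a diagonal Gaussian."""
    def __init__(self, state_dim, action_dim, hidden_dim=32, seed=0):
        rng = np.random.default_rng(seed)
        in_dim = state_dim + action_dim
        self.w1 = rng.normal(0, 0.1, (in_dim, hidden_dim))
        self.w_mean = rng.normal(0, 0.1, (hidden_dim, state_dim))
        self.w_std = rng.normal(0, 0.1, (hidden_dim, state_dim))

    def __call__(self, state, action):
        h = np.tanh(np.concatenate([state, action]) @ self.w1)
        mean = h @ self.w_mean
        std = softplus(h @ self.w_std) + 1e-4  # strictly positive
        return mean, std

    def sample(self, state, action, rng):
        mean, std = self(state, action)
        return mean + std * rng.normal(size=mean.shape)
```

The reward model would follow the same pattern with a scalar mean and fixed unit variance, and the observation model with a deconvolutional mean and identity covariance.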
The agent re-plans at each step, and there is no policy network involved.
Experience Collection
Planning Algorithm
The algorithm uses the cross-entropy method (CEM) to search for the best action sequence, i.e., the one that maximizes the planning objective (the expected reward). The belief over the optimal action sequence is initialized as a Gaussian; at each iteration, the planner samples $J$ candidate action sequences and refits the belief to the top $K$ of them. After $I$ iterations, the planner returns the action mean $\mu_t$ for the current time step, i.e., the first action of the best action sequence.
It is important to reinitialize the belief to zero mean and unit variance after receiving the next observation, to avoid getting stuck in local optima.
When evaluating a candidate sequence, only one trajectory is sampled; this lets the algorithm spend its computation on many different sequences rather than on many trajectories of the same sequence. And because the reward is modeled as a function of the latent state, the planner can plan entirely in latent space, without the expensive decoder for image generation.
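The CEM loop above can be sketched in a few lines of numpy. This is a generic CEM planner over open-loop action sequences, not the paper's implementation; the toy 1D reward function is invented for the example:

```python
import numpy as np

def cem_plan(reward_fn, horizon, action_dim, iters=10, candidates=100, top_k=10, seed=0):
    """Cross-entropy method over open-loop action sequences.

    reward_fn maps an action sequence of shape (horizon, action_dim) to a scalar
    return. The belief over the optimal sequence is a diagonal Gaussian, refit
    each iteration to the top_k candidates (iters, candidates, top_k correspond
    to I, J, K in the text)."""
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(iters):
        samples = mean + std * rng.normal(size=(candidates, horizon, action_dim))
        returns = np.array([reward_fn(seq) for seq in samples])
        elite = samples[np.argsort(returns)[-top_k:]]       # top K sequences
        mean, std = elite.mean(axis=0), elite.std(axis=0) + 1e-6
    return mean[0]  # first action of the best sequence

# Toy example: 1D point mass rewarded for its actions summing to 1.0.
def toy_return(seq):
    return -abs(1.0 - np.clip(seq.sum(), -5, 5))

first_action = cem_plan(toy_return, horizon=5, action_dim=1)
```

In PlaNet, `reward_fn` would roll the candidate sequence forward through the learned latent dynamics and sum the predicted rewards.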
Recurrent State Space Model
The proposed recurrent state-space model (RSSM) can predict forward purely in latent space.
The training objective for a stochastic model (Figure 4b) is:
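Schematically, this is the standard evidence lower bound for a state-space model: reconstruction terms minus KL regularizers toward the learned prior. The form below is a reconstruction in the paper's notation, so treat the exact conditioning sets as a sketch:

$$
\ln p(o_{1:T} \mid a_{1:T}) \;\ge\; \sum_{t=1}^{T} \Big( \operatorname{E}_{q(s_t \mid o_{\le t}, a_{<t})}\!\big[\ln p(o_t \mid s_t)\big] \;-\; \operatorname{E}_{q(s_{t-1} \mid o_{\le t-1}, a_{<t-1})}\!\big[\operatorname{KL}\!\big[\, q(s_t \mid o_{\le t}, a_{<t}) \;\|\; p(s_t \mid s_{t-1}, a_{t-1}) \,\big]\big] \Big)
$$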
However, purely stochastic transitions make it difficult for the transition model to reliably remember information across multiple time steps, so a deterministic component is introduced to give the model access to all previous states.
The combination of stochastic and deterministic components results in the model shown in Figure 4c. The model can be trained using equation (3), with additional KL-divergence terms toward a fixed global prior that prevent the posteriors from collapsing in near-deterministic environments.
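A minimal numpy sketch of one RSSM step (illustrative shapes and a simplified recurrent update, not the paper's exact architecture): the deterministic path carries history forward, and the stochastic state is sampled from a Gaussian conditioned on it, so the model can roll forward purely in latent space.

```python
import numpy as np

rng = np.random.default_rng(0)
H, S, A = 16, 8, 2  # deterministic, stochastic, and action dims (illustrative)

# random weights stand in for learned parameters
Wg = rng.normal(0, 0.1, (H + S + A, H))  # deterministic (recurrent) update
Wm = rng.normal(0, 0.1, (H, S))          # prior mean head
Ws = rng.normal(0, 0.1, (H, S))          # prior std head

def rssm_step(h, s, a):
    """One forward step: h_t = f(h_{t-1}, s_{t-1}, a_{t-1}); s_t ~ p(s_t | h_t)."""
    h_next = np.tanh(np.concatenate([h, s, a]) @ Wg)   # deterministic path
    mean = h_next @ Wm
    std = np.log1p(np.exp(h_next @ Ws)) + 1e-4         # softplus keeps std > 0
    s_next = mean + std * rng.normal(size=S)           # stochastic path
    return h_next, s_next

# roll forward several steps in latent space; no image decoder is needed
h, s = np.zeros(H), np.zeros(S)
for _ in range(5):
    h, s = rssm_step(h, s, np.zeros(A))
```

The real model uses a GRU for the deterministic update and an additional posterior head that also conditions on the encoded observation.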
Latent Overshooting
The transition function is trained via the KL-divergence regularizers on one-step predictions. However, a model trained only on one-step predictions is not guaranteed to produce accurate multi-step predictions. Therefore, PlaNet trains on multi-step predictions as well.
The multi-step predictive distribution is computed by repeatedly applying the transition model and integrating out the intermediate states, where $d$ is the fixed multi-step distance.
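Concretely, the $d$-step transition can be written as a recursion that marginalizes one intermediate state at a time (a reconstruction in the paper's notation):

$$
p_d(s_t \mid s_{t-d}) \;=\; \operatorname{E}_{p_{d-1}(s_{t-1} \mid s_{t-d})}\!\big[\, p(s_t \mid s_{t-1}, a_{t-1}) \,\big], \qquad p_1(s_t \mid s_{t-1}) := p(s_t \mid s_{t-1}, a_{t-1})
$$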
A variational bound on the multi-step predictive distribution $p_d$ can then be derived analogously to the one-step case.
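Omitting the action conditioning for brevity, the bound keeps the same reconstruction terms as the one-step case, but the KL now regularizes the posterior toward the $d$-step prior (again a reconstruction, not a verbatim copy of the paper's equation):

$$
\ln p_d(o_{1:T}) \;\ge\; \sum_{t=1}^{T}\Big( \operatorname{E}_{q(s_t \mid o_{\le t})}\!\big[\ln p(o_t \mid s_t)\big] \;-\; \operatorname{E}_{p(s_{t-1} \mid s_{t-d})\, q(s_{t-d} \mid o_{\le t-d})}\!\big[\operatorname{KL}\!\big[\, q(s_t \mid o_{\le t}) \;\|\; p(s_t \mid s_{t-1}) \,\big]\big]\Big)
$$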
The bound is limited to a fixed distance, but we need to generate accurate predictions up to the planning horizon. The authors proposed latent overshooting, an objective function for latent sequence models that generalizes the standard variational bound to train the model on multi-step predictions of all distances.
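Schematically, latent overshooting averages the multi-step bounds over all distances $1 \le d \le D$, with per-distance weights $\beta_d$ on the KL terms (a reconstruction of the objective's structure; the inner expectation is taken as in the fixed-distance bound):

$$
\frac{1}{D}\sum_{d=1}^{D} \ln p_d(o_{1:T}) \;\ge\; \sum_{t=1}^{T}\Big( \operatorname{E}_{q(s_t \mid o_{\le t})}\!\big[\ln p(o_t \mid s_t)\big] \;-\; \frac{1}{D}\sum_{d=1}^{D} \beta_d \operatorname{E}\big[\operatorname{KL}\!\big[\, q(s_t \mid o_{\le t}) \;\|\; p(s_t \mid s_{t-1}) \,\big]\big]\Big)
$$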
Equation (7) is the final objective function used to train the dynamics model of the agent.
Result
The evaluation is performed on six image-based continuous control tasks, with the only observations being images of 64x64x3 pixels. PlaNet outperforms A3C while using 1/50 of the episodes and achieves performance similar to that of D4PG, the top model-free algorithm. Its training time of one day on a single Nvidia V100 GPU is comparable to that of D4PG.
Training evolution over 2000 episodes
Comparison to model-free methods
*Numbers indicate mean performance over 4 seeds.
Stochastic and Deterministic Dynamic Model
Latent overshooting effectiveness
One Agent for All Tasks
The agent is trained in different environments without prior knowledge of the current task. To achieve this, the action spaces are padded with unused elements to make them compatible. The agent reaches the same average performance over tasks as individually trained agents.
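As a toy illustration of the padding trick (the task names and action sizes below are invented for the example), each task's action is embedded into a single shared action vector, with the unused trailing elements set to zero:

```python
import numpy as np

# illustrative per-task action sizes (not the paper's exact values)
task_action_dims = {"cartpole": 1, "reacher": 2, "walker": 6}
max_dim = max(task_action_dims.values())

def pad_action(action, max_dim):
    """Pad a task-specific action with zeros so every task shares one action space."""
    padded = np.zeros(max_dim)
    padded[: action.shape[0]] = action
    return padded

def unpad_action(padded, task):
    """Recover the task's native action by dropping the unused trailing elements."""
    return padded[: task_action_dims[task]]

a = pad_action(np.array([0.5]), max_dim)  # cartpole action in the shared space
```

The planner then always optimizes over the shared `max_dim`-dimensional space, and each environment simply ignores (or strips) the padded entries.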
Below are predictions of the PlaNet agent trained on multiple tasks. Holdout episodes are shown on top, with the agent's video predictions below. The agent observes the first 5 frames as context to infer the task and state, and accurately predicts ahead for 50 steps given a sequence of actions.
Conclusion
- Very data efficient.
- Multi-task agent.
- Distributed training?
- Adaptable for real world tasks?