path: root/examples/pybullet/gym/pybullet_envs/minitaur/envs/simple_ppo_agent.py
"""An agent that can restore and run a policy learned by PPO."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import tensorflow as tf
from pybullet_envs.agents.ppo import normalize
from pybullet_envs.agents import utility


class SimplePPOPolicy(object):
  """A simple PPO policy that is independent to the PPO infrastructure.

  This class restores the policy network from a tensorflow checkpoint that was
  learned from PPO training. The purpose of this class is to conveniently
  visualize a learned policy or deploy the learned policy on real robots without
  need to change the PPO evaluation infrastructure:
  https://cs.corp.google.com/piper///depot/google3/robotics/reinforcement_learning/agents/scripts/visualize.py.
  """

  def __init__(self, sess, env, network, policy_layers, value_layers,
               checkpoint):
    self.env = env
    self.sess = sess
    observation_size = len(env.observation_space.low)
    action_size = len(env.action_space.low)
    self.observation_placeholder = tf.placeholder(
        tf.float32, [None, observation_size], name="Input")
    self._observ_filter = normalize.StreamingNormalize(
        self.observation_placeholder[0],
        center=True,
        scale=True,
        clip=5,
        name="normalize_observ")
    self._restore_policy(
        network,
        policy_layers=policy_layers,
        value_layers=value_layers,
        action_size=action_size,
        checkpoint=checkpoint)

  def _restore_policy(self, network, policy_layers, value_layers, action_size,
                      checkpoint):
    """Restore the PPO policy from a TensorFlow checkpoint.

    Args:
      network: The neural network definition.
      policy_layers: A tuple specify the number of layers and number of neurons
        of each layer for the policy network.
      value_layers: A tuple specify the number of layers and number of neurons
        of each layer for the value network.
      action_size: The dimension of the action space.
      checkpoint: The checkpoint path.
    """
    observ = self._observ_filter.transform(self.observation_placeholder)
    with tf.variable_scope("network/rnn"):
      self.network = network(
          policy_layers=policy_layers,
          value_layers=value_layers,
          action_size=action_size)

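    # Hold the recurrent state in a non-trainable variable so it persists
    # across get_action() calls; the "temporary" scope is excluded from the
    # checkpoint restore below.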
    with tf.variable_scope("temporary"):
      self.last_state = tf.Variable(
          self.network.zero_state(1, tf.float32), False)
      self.sess.run(self.last_state.initializer)

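    # Unroll the recurrent policy for a single time step: observ[:, None] adds
    # a time axis, the cell emits the mean action, and the new RNN state is
    # written back into last_state.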
    with tf.variable_scope("network"):
      (mean_action, _, _), new_state = tf.nn.dynamic_rnn(
          self.network,
          observ[:, None],
          tf.ones(1),
          self.last_state,
          tf.float32,
          swap_memory=True)
      self.mean_action = mean_action
      self.update_state = self.last_state.assign(new_state)

    saver = utility.define_saver(exclude=(r"temporary/.*",))
    saver.restore(self.sess, checkpoint)

  def get_action(self, observation):
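    """Returns the policy action for a raw observation.

    The recurrent state kept by this class has batch size one, so the
    observation is expected as a batch containing a single observation,
    shaped [1, observation_size].

    Args:
      observation: The raw (unscaled) observation from the environment,
        wrapped in a batch of size one.

    Returns:
      The action in the environment's native action scale, shaped
      [1, action_size].
    """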
    normalized_observation = self._normalize_observ(observation)
    normalized_action, _ = self.sess.run(
        [self.mean_action, self.update_state],
        feed_dict={self.observation_placeholder: normalized_observation})
    action = self._denormalize_action(normalized_action)
    return action[:, 0]

  def _denormalize_action(self, action):
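    # Linearly map an action from the normalized range [-1, 1] back to the
    # environment's action bounds [low, high].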
    min_ = self.env.action_space.low
    max_ = self.env.action_space.high
    action = (action + 1) / 2 * (max_ - min_) + min_
    return action

  def _normalize_observ(self, observ):
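    # Linearly map an observation from the environment's bounds [low, high]
    # into the normalized range [-1, 1] expected by the policy network.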
    min_ = self.env.observation_space.low
    max_ = self.env.observation_space.high
    observ = 2 * (observ - min_) / (max_ - min_) - 1
    return observ
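

# Usage sketch (not part of the original module): a minimal, commented example
# of how this class is typically driven. The gym environment id, the network
# constructor, the layer sizes, and the checkpoint directory below are
# assumptions for illustration only; pass the same network function and layer
# sizes that were used during PPO training so the checkpoint variables match.
#
#   import gym
#   import pybullet_envs  # noqa: F401 (registers the Bullet gym environments)
#   import tensorflow as tf
#
#   with tf.Session() as sess:
#     env = gym.make("MinitaurBulletEnv-v0")  # assumed environment id
#     policy = SimplePPOPolicy(
#         sess,
#         env,
#         network=my_recurrent_network,  # hypothetical: the trained network
#         policy_layers=(200, 100),      # assumed layer sizes
#         value_layers=(200, 100),
#         checkpoint=tf.train.latest_checkpoint("/path/to/train_dir"))
#     observation = env.reset()
#     done = False
#     while not done:
#       # get_action expects a batch of one observation and returns a batch of
#       # one action in the environment's native scale.
#       action = policy.get_action([observation])
#       observation, _, done, _ = env.step(action[0])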