path: root/examples/pybullet/gym/pybullet_envs/minitaur/envs/minitaur_reactive_env_example.py
r"""Running a pre-trained ppo agent on minitaur_reactive_env."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import os
import sys
import time

# Make pybullet_envs importable when this example is run directly from the
# source tree.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(os.path.dirname(os.path.dirname(currentdir)))
print("parentdir=", parentdir)
sys.path.insert(0, parentdir)

import tensorflow as tf

import pybullet_data
from pybullet_envs.minitaur.agents.scripts import utility
from pybullet_envs.minitaur.envs import simple_ppo_agent


flags = tf.app.flags
FLAGS = flags.FLAGS

# Pre-trained reactive-control policy shipped with pybullet_data, and the
# checkpoint file to restore it from.
LOG_DIR = os.path.join(pybullet_data.getDataPath(), "policies/ppo/minitaur_reactive_env")
CHECKPOINT = "model.ckpt-14000000"


def main(argv):
  del argv  # Unused.
  tf.logging.set_verbosity(tf.logging.INFO)  # Otherwise the reward log below is suppressed.
  # Load the saved training config so the policy and value networks are
  # rebuilt with the same architecture the checkpoint was trained with.
  config = utility.load_config(LOG_DIR)
  policy_layers = config.policy_layers
  value_layers = config.value_layers
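  # render=True connects to the pybullet GUI so the rollout can be watched.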
  env = config.env(render=True)
  network = config.network

  with tf.Session() as sess:
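    # Build the PPO policy network and restore the trained weights from the
    # checkpoint.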
    agent = simple_ppo_agent.SimplePPOPolicy(
        sess,
        env,
        network,
        policy_layers=policy_layers,
        value_layers=value_layers,
        checkpoint=os.path.join(LOG_DIR, CHECKPOINT))

    sum_reward = 0
    observation = env.reset()
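    # Roll out a single episode, querying the policy for an action each step.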
    while True:
      action = agent.get_action([observation])
      observation, reward, done, _ = env.step(action[0])
      time.sleep(0.002)  # Brief pause so the rendered rollout stays watchable.
      sum_reward += reward
      if done:
        break
    tf.logging.info("reward: %s", sum_reward)


if __name__ == "__main__":
  tf.app.run(main)