summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/pybullet/gym/pybullet_utils/arg_parser.py124
-rw-r--r--examples/pybullet/gym/pybullet_utils/examples/testlog.py9
-rw-r--r--examples/pybullet/gym/pybullet_utils/logger.py128
-rw-r--r--examples/pybullet/gym/pybullet_utils/math_util.py18
-rw-r--r--examples/pybullet/gym/pybullet_utils/mpi_util.py52
-rw-r--r--examples/pybullet/gym/pybullet_utils/util.py13
6 files changed, 344 insertions, 0 deletions
diff --git a/examples/pybullet/gym/pybullet_utils/arg_parser.py b/examples/pybullet/gym/pybullet_utils/arg_parser.py
new file mode 100644
index 000000000..5461a5bf2
--- /dev/null
+++ b/examples/pybullet/gym/pybullet_utils/arg_parser.py
@@ -0,0 +1,124 @@
+import re as RE
+
+class ArgParser(object):
+ global_parser = None
+
+ def __init__(self):
+ self._table = dict()
+ return
+
+ def clear(self):
+ self._table.clear()
+ return
+
+ def load_args(self, arg_strs):
+ succ = True
+ vals = []
+ curr_key = ''
+
+ for str in arg_strs:
+ if not (self._is_comment(str)):
+ is_key = self._is_key(str)
+ if (is_key):
+ if (curr_key != ''):
+ if (curr_key not in self._table):
+ self._table[curr_key] = vals
+
+ vals = []
+ curr_key = str[2::]
+ else:
+ vals.append(str)
+
+ if (curr_key != ''):
+ if (curr_key not in self._table):
+ self._table[curr_key] = vals
+
+ vals = []
+
+ return succ
+
+ def load_file(self, filename):
+ succ = False
+ with open(filename, 'r') as file:
+ lines = RE.split(r'[\n\r]+', file.read())
+ file.close()
+
+ arg_strs = []
+ for line in lines:
+ if (len(line) > 0 and not self._is_comment(line)):
+ arg_strs += line.split()
+
+ succ = self.load_args(arg_strs)
+ return succ
+
+ def has_key(self, key):
+ return key in self._table
+
+ def parse_string(self, key, default=''):
+ str = default
+ if self.has_key(key):
+ str = self._table[key][0]
+ return str
+
+ def parse_strings(self, key, default=[]):
+ arr = default
+ if self.has_key(key):
+ arr = self._table[key]
+ return arr
+
+ def parse_int(self, key, default=0):
+ val = default
+ if self.has_key(key):
+ val = int(self._table[key][0])
+ return val
+
+ def parse_ints(self, key, default=[]):
+ arr = default
+ if self.has_key(key):
+ arr = [int(str) for str in self._table[key]]
+ return arr
+
+ def parse_float(self, key, default=0.0):
+ val = default
+ if self.has_key(key):
+ val = float(self._table[key][0])
+ return val
+
+ def parse_floats(self, key, default=[]):
+ arr = default
+ if self.has_key(key):
+ arr = [float(str) for str in self._table[key]]
+ return arr
+
+ def parse_bool(self, key, default=False):
+ val = default
+ if self.has_key(key):
+ val = self._parse_bool(self._table[key][0])
+ return val
+
+ def parse_bools(self, key, default=[]):
+ arr = default
+ if self.has_key(key):
+ arr = [self._parse_bool(str) for str in self._table[key]]
+ return arr
+
+ def _is_comment(self, str):
+ is_comment = False
+ if (len(str) > 0):
+ is_comment = str[0] == '#'
+
+ return is_comment
+
+ def _is_key(self, str):
+ is_key = False
+ if (len(str) >= 3):
+ is_key = str[0] == '-' and str[1] == '-'
+
+ return is_key
+
+ def _parse_bool(self, str):
+ val = False
+ if (str == 'true' or str == 'True' or str == '1'
+ or str == 'T' or str == 't'):
+ val = True
+ return val \ No newline at end of file
diff --git a/examples/pybullet/gym/pybullet_utils/examples/testlog.py b/examples/pybullet/gym/pybullet_utils/examples/testlog.py
new file mode 100644
index 000000000..9a275160a
--- /dev/null
+++ b/examples/pybullet/gym/pybullet_utils/examples/testlog.py
@@ -0,0 +1,9 @@
+from pybullet_utils.logger import Logger
+logger = Logger()
+logger.configure_output_file("e:/mylog.txt")
+for i in range (10):
+ logger.log_tabular("Iteration", 1)
+Logger.print2("hello world")
+
+logger.print_tabular()
+logger.dump_tabular() \ No newline at end of file
diff --git a/examples/pybullet/gym/pybullet_utils/logger.py b/examples/pybullet/gym/pybullet_utils/logger.py
new file mode 100644
index 000000000..ceb2605ba
--- /dev/null
+++ b/examples/pybullet/gym/pybullet_utils/logger.py
@@ -0,0 +1,128 @@
+import pybullet_utils.mpi_util as MPIUtil
+
+"""
+
+Some simple logging functionality, inspired by rllab's logging.
+Assumes that each diagnostic gets logged each iteration
+
+Call logz.configure_output_file() to start logging to a
+tab-separated-values file (some_file_name.txt)
+
+To load the learning curves, you can do, for example
+
+A = np.genfromtxt('/tmp/expt_1468984536/log.txt',delimiter='\t',dtype=None, names=True)
+A['EpRewMean']
+
+"""
+
+import os.path as osp, shutil, time, atexit, os, subprocess
+
+class Logger:
+ def print2(str):
+ if (MPIUtil.is_root_proc()):
+ print(str)
+ return
+
+ def __init__(self):
+ self.output_file = None
+ self.first_row = True
+ self.log_headers = []
+ self.log_current_row = {}
+ self._dump_str_template = ""
+ return
+
+ def reset(self):
+ self.first_row = True
+ self.log_headers = []
+ self.log_current_row = {}
+ if self.output_file is not None:
+ self.output_file = open(output_path, 'w')
+ return
+
+ def configure_output_file(self, filename=None):
+ """
+ Set output directory to d, or to /tmp/somerandomnumber if d is None
+ """
+ self.first_row = True
+ self.log_headers = []
+ self.log_current_row = {}
+
+ output_path = filename or "output/log_%i.txt"%int(time.time())
+
+ out_dir = os.path.dirname(output_path)
+ if not os.path.exists(out_dir) and MPIUtil.is_root_proc():
+ os.makedirs(out_dir)
+
+ if (MPIUtil.is_root_proc()):
+ self.output_file = open(output_path, 'w')
+ assert osp.exists(output_path)
+ atexit.register(self.output_file.close)
+
+ Logger.print2("Logging data to " + self.output_file.name)
+ return
+
+ def log_tabular(self, key, val):
+ """
+ Log a value of some diagnostic
+ Call this once for each diagnostic quantity, each iteration
+ """
+ if self.first_row and key not in self.log_headers:
+ self.log_headers.append(key)
+ else:
+ assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration"%key
+ self.log_current_row[key] = val
+ return
+
+ def get_num_keys(self):
+ return len(self.log_headers)
+
+ def print_tabular(self):
+ """
+ Print all of the diagnostics from the current iteration
+ """
+ if (MPIUtil.is_root_proc()):
+ vals = []
+ Logger.print2("-"*37)
+ for key in self.log_headers:
+ val = self.log_current_row.get(key, "")
+ if isinstance(val, float):
+ valstr = "%8.3g"%val
+ elif isinstance(val, int):
+ valstr = str(val)
+ else:
+ valstr = val
+ Logger.print2("| %15s | %15s |"%(key, valstr))
+ vals.append(val)
+ Logger.print2("-" * 37)
+ return
+
+ def dump_tabular(self):
+ """
+ Write all of the diagnostics from the current iteration
+ """
+ if (MPIUtil.is_root_proc()):
+ if (self.first_row):
+ self._dump_str_template = self._build_str_template()
+
+ vals = []
+ for key in self.log_headers:
+ val = self.log_current_row.get(key, "")
+ vals.append(val)
+
+ if self.output_file is not None:
+ if self.first_row:
+ header_str = self._dump_str_template.format(*self.log_headers)
+ self.output_file.write(header_str + "\n")
+
+ val_str = self._dump_str_template.format(*map(str,vals))
+ self.output_file.write(val_str + "\n")
+ self.output_file.flush()
+
+ self.log_current_row.clear()
+ self.first_row=False
+ return
+
+ def _build_str_template(self):
+ num_keys = self.get_num_keys()
+ template = "{:<25}" * num_keys
+ return template
diff --git a/examples/pybullet/gym/pybullet_utils/math_util.py b/examples/pybullet/gym/pybullet_utils/math_util.py
new file mode 100644
index 000000000..d0f333b3f
--- /dev/null
+++ b/examples/pybullet/gym/pybullet_utils/math_util.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+RAD_TO_DEG = 57.2957795
+DEG_TO_RAD = 1.0 / RAD_TO_DEG
+INVALID_IDX = -1
+
+def lerp(x, y, t):
+ return (1 - t) * x + t * y
+
+def log_lerp(x, y, t):
+ return np.exp(lerp(np.log(x), np.log(y), t))
+
+def flatten(arr_list):
+ return np.concatenate([np.reshape(a, [-1]) for a in arr_list], axis=0)
+
+def flip_coin(p):
+ rand_num = np.random.binomial(1, p, 1)
+ return rand_num[0] == 1 \ No newline at end of file
diff --git a/examples/pybullet/gym/pybullet_utils/mpi_util.py b/examples/pybullet/gym/pybullet_utils/mpi_util.py
new file mode 100644
index 000000000..30ca6aed0
--- /dev/null
+++ b/examples/pybullet/gym/pybullet_utils/mpi_util.py
@@ -0,0 +1,52 @@
+import numpy as np
+from mpi4py import MPI
+
+ROOT_PROC_RANK = 0
+
+def get_num_procs():
+ return MPI.COMM_WORLD.Get_size()
+
+def get_proc_rank():
+ return MPI.COMM_WORLD.Get_rank()
+
+def is_root_proc():
+ rank = get_proc_rank()
+ return rank == ROOT_PROC_RANK
+
+def bcast(x):
+ MPI.COMM_WORLD.Bcast(x, root=ROOT_PROC_RANK)
+ return
+
+def reduce_sum(x):
+ return reduce_all(x, MPI.SUM)
+
+def reduce_prod(x):
+ return reduce_all(x, MPI.PROD)
+
+def reduce_avg(x):
+ buffer = reduce_sum(x)
+ buffer /= get_num_procs()
+ return buffer
+
+def reduce_min(x):
+ return reduce_all(x, MPI.MIN)
+
+def reduce_max(x):
+ return reduce_all(x, MPI.MAX)
+
+def reduce_all(x, op):
+ is_array = isinstance(x, np.ndarray)
+ x_buf = x if is_array else np.array([x])
+ buffer = np.zeros_like(x_buf)
+ MPI.COMM_WORLD.Allreduce(x_buf, buffer, op=op)
+ buffer = buffer if is_array else buffer[0]
+ return buffer
+
+def gather_all(x):
+ is_array = isinstance(x, np.ndarray)
+ x_buf = np.array([x])
+ buffer = np.zeros_like(x_buf)
+ buffer = np.repeat(buffer, get_num_procs(), axis=0)
+ MPI.COMM_WORLD.Allgather(x_buf, buffer)
+ buffer = list(buffer)
+ return buffer \ No newline at end of file
diff --git a/examples/pybullet/gym/pybullet_utils/util.py b/examples/pybullet/gym/pybullet_utils/util.py
new file mode 100644
index 000000000..c11ccb811
--- /dev/null
+++ b/examples/pybullet/gym/pybullet_utils/util.py
@@ -0,0 +1,13 @@
+import random
+import numpy as np
+
+def set_global_seeds(seed):
+ try:
+ import tensorflow as tf
+ except ImportError:
+ pass
+ else:
+ tf.set_random_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ return \ No newline at end of file