diff options
-rw-r--r-- | examples/pybullet/gym/pybullet_utils/arg_parser.py | 124 | ||||
-rw-r--r-- | examples/pybullet/gym/pybullet_utils/examples/testlog.py | 9 | ||||
-rw-r--r-- | examples/pybullet/gym/pybullet_utils/logger.py | 128 | ||||
-rw-r--r-- | examples/pybullet/gym/pybullet_utils/math_util.py | 18 | ||||
-rw-r--r-- | examples/pybullet/gym/pybullet_utils/mpi_util.py | 52 | ||||
-rw-r--r-- | examples/pybullet/gym/pybullet_utils/util.py | 13 |
6 files changed, 344 insertions, 0 deletions
diff --git a/examples/pybullet/gym/pybullet_utils/arg_parser.py b/examples/pybullet/gym/pybullet_utils/arg_parser.py new file mode 100644 index 000000000..5461a5bf2 --- /dev/null +++ b/examples/pybullet/gym/pybullet_utils/arg_parser.py @@ -0,0 +1,124 @@ +import re as RE + +class ArgParser(object): + global_parser = None + + def __init__(self): + self._table = dict() + return + + def clear(self): + self._table.clear() + return + + def load_args(self, arg_strs): + succ = True + vals = [] + curr_key = '' + + for str in arg_strs: + if not (self._is_comment(str)): + is_key = self._is_key(str) + if (is_key): + if (curr_key != ''): + if (curr_key not in self._table): + self._table[curr_key] = vals + + vals = [] + curr_key = str[2::] + else: + vals.append(str) + + if (curr_key != ''): + if (curr_key not in self._table): + self._table[curr_key] = vals + + vals = [] + + return succ + + def load_file(self, filename): + succ = False + with open(filename, 'r') as file: + lines = RE.split(r'[\n\r]+', file.read()) + file.close() + + arg_strs = [] + for line in lines: + if (len(line) > 0 and not self._is_comment(line)): + arg_strs += line.split() + + succ = self.load_args(arg_strs) + return succ + + def has_key(self, key): + return key in self._table + + def parse_string(self, key, default=''): + str = default + if self.has_key(key): + str = self._table[key][0] + return str + + def parse_strings(self, key, default=[]): + arr = default + if self.has_key(key): + arr = self._table[key] + return arr + + def parse_int(self, key, default=0): + val = default + if self.has_key(key): + val = int(self._table[key][0]) + return val + + def parse_ints(self, key, default=[]): + arr = default + if self.has_key(key): + arr = [int(str) for str in self._table[key]] + return arr + + def parse_float(self, key, default=0.0): + val = default + if self.has_key(key): + val = float(self._table[key][0]) + return val + + def parse_floats(self, key, default=[]): + arr = default + if self.has_key(key): + arr = [float(str) for str in self._table[key]] + return arr + + def parse_bool(self, key, default=False): + val = default + if self.has_key(key): + val = self._parse_bool(self._table[key][0]) + return val + + def parse_bools(self, key, default=[]): + arr = default + if self.has_key(key): + arr = [self._parse_bool(str) for str in self._table[key]] + return arr + + def _is_comment(self, str): + is_comment = False + if (len(str) > 0): + is_comment = str[0] == '#' + + return is_comment + + def _is_key(self, str): + is_key = False + if (len(str) >= 3): + is_key = str[0] == '-' and str[1] == '-' + + return is_key + + def _parse_bool(self, str): + val = False + if (str == 'true' or str == 'True' or str == '1' + or str == 'T' or str == 't'): + val = True + return val
\ No newline at end of file diff --git a/examples/pybullet/gym/pybullet_utils/examples/testlog.py b/examples/pybullet/gym/pybullet_utils/examples/testlog.py new file mode 100644 index 000000000..9a275160a --- /dev/null +++ b/examples/pybullet/gym/pybullet_utils/examples/testlog.py @@ -0,0 +1,9 @@ +from pybullet_utils.logger import Logger +logger = Logger() +logger.configure_output_file("e:/mylog.txt") +for i in range (10): + logger.log_tabular("Iteration", 1) +Logger.print2("hello world") + +logger.print_tabular() +logger.dump_tabular()
\ No newline at end of file diff --git a/examples/pybullet/gym/pybullet_utils/logger.py b/examples/pybullet/gym/pybullet_utils/logger.py new file mode 100644 index 000000000..ceb2605ba --- /dev/null +++ b/examples/pybullet/gym/pybullet_utils/logger.py @@ -0,0 +1,128 @@ +import pybullet_utils.mpi_util as MPIUtil + +""" + +Some simple logging functionality, inspired by rllab's logging. +Assumes that each diagnostic gets logged each iteration + +Call logz.configure_output_file() to start logging to a +tab-separated-values file (some_file_name.txt) + +To load the learning curves, you can do, for example + +A = np.genfromtxt('/tmp/expt_1468984536/log.txt',delimiter='\t',dtype=None, names=True) +A['EpRewMean'] + +""" + +import os.path as osp, shutil, time, atexit, os, subprocess + +class Logger: + def print2(str): + if (MPIUtil.is_root_proc()): + print(str) + return + + def __init__(self): + self.output_file = None + self.first_row = True + self.log_headers = [] + self.log_current_row = {} + self._dump_str_template = "" + return + + def reset(self): + self.first_row = True + self.log_headers = [] + self.log_current_row = {} + if self.output_file is not None: + self.output_file = open(output_path, 'w') + return + + def configure_output_file(self, filename=None): + """ + Set output directory to d, or to /tmp/somerandomnumber if d is None + """ + self.first_row = True + self.log_headers = [] + self.log_current_row = {} + + output_path = filename or "output/log_%i.txt"%int(time.time()) + + out_dir = os.path.dirname(output_path) + if not os.path.exists(out_dir) and MPIUtil.is_root_proc(): + os.makedirs(out_dir) + + if (MPIUtil.is_root_proc()): + self.output_file = open(output_path, 'w') + assert osp.exists(output_path) + atexit.register(self.output_file.close) + + Logger.print2("Logging data to " + self.output_file.name) + return + + def log_tabular(self, key, val): + """ + Log a value of some diagnostic + Call this once for each diagnostic quantity, each iteration + """ + if self.first_row and key not in self.log_headers: + self.log_headers.append(key) + else: + assert key in self.log_headers, "Trying to introduce a new key %s that you didn't include in the first iteration"%key + self.log_current_row[key] = val + return + + def get_num_keys(self): + return len(self.log_headers) + + def print_tabular(self): + """ + Print all of the diagnostics from the current iteration + """ + if (MPIUtil.is_root_proc()): + vals = [] + Logger.print2("-"*37) + for key in self.log_headers: + val = self.log_current_row.get(key, "") + if isinstance(val, float): + valstr = "%8.3g"%val + elif isinstance(val, int): + valstr = str(val) + else: + valstr = val + Logger.print2("| %15s | %15s |"%(key, valstr)) + vals.append(val) + Logger.print2("-" * 37) + return + + def dump_tabular(self): + """ + Write all of the diagnostics from the current iteration + """ + if (MPIUtil.is_root_proc()): + if (self.first_row): + self._dump_str_template = self._build_str_template() + + vals = [] + for key in self.log_headers: + val = self.log_current_row.get(key, "") + vals.append(val) + + if self.output_file is not None: + if self.first_row: + header_str = self._dump_str_template.format(*self.log_headers) + self.output_file.write(header_str + "\n") + + val_str = self._dump_str_template.format(*map(str,vals)) + self.output_file.write(val_str + "\n") + self.output_file.flush() + + self.log_current_row.clear() + self.first_row=False + return + + def _build_str_template(self): + num_keys = self.get_num_keys() + template = "{:<25}" * num_keys + return template diff --git a/examples/pybullet/gym/pybullet_utils/math_util.py b/examples/pybullet/gym/pybullet_utils/math_util.py new file mode 100644 index 000000000..d0f333b3f --- /dev/null +++ b/examples/pybullet/gym/pybullet_utils/math_util.py @@ -0,0 +1,18 @@ +import numpy as np + +RAD_TO_DEG = 57.2957795 +DEG_TO_RAD = 1.0 / RAD_TO_DEG +INVALID_IDX = -1 + +def lerp(x, y, t): + return (1 - t) * x + t * y + +def log_lerp(x, y, t): + return np.exp(lerp(np.log(x), np.log(y), t)) + +def flatten(arr_list): + return np.concatenate([np.reshape(a, [-1]) for a in arr_list], axis=0) + +def flip_coin(p): + rand_num = np.random.binomial(1, p, 1) + return rand_num[0] == 1
\ No newline at end of file diff --git a/examples/pybullet/gym/pybullet_utils/mpi_util.py b/examples/pybullet/gym/pybullet_utils/mpi_util.py new file mode 100644 index 000000000..30ca6aed0 --- /dev/null +++ b/examples/pybullet/gym/pybullet_utils/mpi_util.py @@ -0,0 +1,52 @@ +import numpy as np +from mpi4py import MPI + +ROOT_PROC_RANK = 0 + +def get_num_procs(): + return MPI.COMM_WORLD.Get_size() + +def get_proc_rank(): + return MPI.COMM_WORLD.Get_rank() + +def is_root_proc(): + rank = get_proc_rank() + return rank == ROOT_PROC_RANK + +def bcast(x): + MPI.COMM_WORLD.Bcast(x, root=ROOT_PROC_RANK) + return + +def reduce_sum(x): + return reduce_all(x, MPI.SUM) + +def reduce_prod(x): + return reduce_all(x, MPI.PROD) + +def reduce_avg(x): + buffer = reduce_sum(x) + buffer /= get_num_procs() + return buffer + +def reduce_min(x): + return reduce_all(x, MPI.MIN) + +def reduce_max(x): + return reduce_all(x, MPI.MAX) + +def reduce_all(x, op): + is_array = isinstance(x, np.ndarray) + x_buf = x if is_array else np.array([x]) + buffer = np.zeros_like(x_buf) + MPI.COMM_WORLD.Allreduce(x_buf, buffer, op=op) + buffer = buffer if is_array else buffer[0] + return buffer + +def gather_all(x): + is_array = isinstance(x, np.ndarray) + x_buf = np.array([x]) + buffer = np.zeros_like(x_buf) + buffer = np.repeat(buffer, get_num_procs(), axis=0) + MPI.COMM_WORLD.Allgather(x_buf, buffer) + buffer = list(buffer) + return buffer
\ No newline at end of file diff --git a/examples/pybullet/gym/pybullet_utils/util.py b/examples/pybullet/gym/pybullet_utils/util.py new file mode 100644 index 000000000..c11ccb811 --- /dev/null +++ b/examples/pybullet/gym/pybullet_utils/util.py @@ -0,0 +1,13 @@ +import random +import numpy as np + +def set_global_seeds(seed): + try: + import tensorflow as tf + except ImportError: + pass + else: + tf.set_random_seed(seed) + np.random.seed(seed) + random.seed(seed) + return
\ No newline at end of file |