Diffstat (limited to 'examples/pybullet/gym/pybullet_envs/deep_mimic/learning/solvers/mpi_solver.py')
-rw-r--r-- | examples/pybullet/gym/pybullet_envs/deep_mimic/learning/solvers/mpi_solver.py | 185 |
1 files changed, 93 insertions, 92 deletions
diff --git a/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/solvers/mpi_solver.py b/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/solvers/mpi_solver.py
index f2d18051c..0077f4b11 100644
--- a/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/solvers/mpi_solver.py
+++ b/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/solvers/mpi_solver.py
@@ -8,96 +8,97 @@ from pybullet_utils.logger import Logger
 
 from pybullet_envs.deep_mimic.learning.solvers.solver import Solver
 
+
 class MPISolver(Solver):
-    CHECK_SYNC_ITERS = 1000
-
-    def __init__(self, sess, optimizer, vars):
-        super().__init__(vars)
-        self.sess = sess
-        self.optimizer = optimizer
-        self._build_grad_feed(vars)
-        self._update = optimizer.apply_gradients(zip(self._grad_tf_list, self.vars))
-        self._set_flat_vars = TFUtil.SetFromFlat(sess, self.vars)
-        self._get_flat_vars = TFUtil.GetFlat(sess, self.vars)
-
-        self.iter = 0
-        grad_dim = self._calc_grad_dim()
-        self._flat_grad = np.zeros(grad_dim, dtype=np.float32)
-        self._global_flat_grad = np.zeros(grad_dim, dtype=np.float32)
-
-        return
-
-    def get_stepsize(self):
-        return self.optimizer._learning_rate_tensor.eval()
-
-    def update(self, grads=None, grad_scale=1.0):
-        if grads is not None:
-            self._flat_grad = MathUtil.flatten(grads)
-        else:
-            self._flat_grad.fill(0)
-        return self.update_flatgrad(self._flat_grad, grad_scale)
-
-    def update_flatgrad(self, flat_grad, grad_scale=1.0):
-        if self.iter % self.CHECK_SYNC_ITERS == 0:
-            assert self.check_synced(), Logger.print2('Network parameters desynchronized')
-
-        if grad_scale != 1.0:
-            flat_grad *= grad_scale
-
-        MPI.COMM_WORLD.Allreduce(flat_grad, self._global_flat_grad, op=MPI.SUM)
-        self._global_flat_grad /= MPIUtil.get_num_procs()
-
-        self._load_flat_grad(self._global_flat_grad)
-        self.sess.run([self._update], self._grad_feed)
-        self.iter += 1
-
-        return
-
-    def sync(self):
-        vars = self._get_flat_vars()
-        MPIUtil.bcast(vars)
-        self._set_flat_vars(vars)
-        return
-
-    def check_synced(self):
-        synced = True
-        if self._is_root():
-            vars = self._get_flat_vars()
-            MPIUtil.bcast(vars)
-        else:
-            vars_local = self._get_flat_vars()
-            vars_root = np.empty_like(vars_local)
-            MPIUtil.bcast(vars_root)
-            synced = (vars_local == vars_root).all()
-        return synced
-
-    def _is_root(self):
-        return MPIUtil.is_root_proc()
-
-    def _build_grad_feed(self, vars):
-        self._grad_tf_list = []
-        self._grad_buffers = []
-        for v in self.vars:
-            shape = v.get_shape()
-            grad = np.zeros(shape)
-            grad_tf = tf.placeholder(tf.float32, shape=shape)
-            self._grad_buffers.append(grad)
-            self._grad_tf_list.append(grad_tf)
-
-        self._grad_feed = dict({g_tf: g for g_tf, g in zip(self._grad_tf_list, self._grad_buffers)})
-
-        return
-
-    def _calc_grad_dim(self):
-        grad_dim = 0
-        for grad in self._grad_buffers:
-            grad_dim += grad.size
-        return grad_dim
-
-    def _load_flat_grad(self, flat_grad):
-        start = 0
-        for g in self._grad_buffers:
-            size = g.size
-            np.copyto(g, np.reshape(flat_grad[start:start + size], g.shape))
-            start += size
-        return
\ No newline at end of file
+  CHECK_SYNC_ITERS = 1000
+
+  def __init__(self, sess, optimizer, vars):
+    super().__init__(vars)
+    self.sess = sess
+    self.optimizer = optimizer
+    self._build_grad_feed(vars)
+    self._update = optimizer.apply_gradients(zip(self._grad_tf_list, self.vars))
+    self._set_flat_vars = TFUtil.SetFromFlat(sess, self.vars)
+    self._get_flat_vars = TFUtil.GetFlat(sess, self.vars)
+
+    self.iter = 0
+    grad_dim = self._calc_grad_dim()
+    self._flat_grad = np.zeros(grad_dim, dtype=np.float32)
+    self._global_flat_grad = np.zeros(grad_dim, dtype=np.float32)
+
+    return
+
+  def get_stepsize(self):
+    return self.optimizer._learning_rate_tensor.eval()
+
+  def update(self, grads=None, grad_scale=1.0):
+    if grads is not None:
+      self._flat_grad = MathUtil.flatten(grads)
+    else:
+      self._flat_grad.fill(0)
+    return self.update_flatgrad(self._flat_grad, grad_scale)
+
+  def update_flatgrad(self, flat_grad, grad_scale=1.0):
+    if self.iter % self.CHECK_SYNC_ITERS == 0:
+      assert self.check_synced(), Logger.print2('Network parameters desynchronized')
+
+    if grad_scale != 1.0:
+      flat_grad *= grad_scale
+
+    MPI.COMM_WORLD.Allreduce(flat_grad, self._global_flat_grad, op=MPI.SUM)
+    self._global_flat_grad /= MPIUtil.get_num_procs()
+
+    self._load_flat_grad(self._global_flat_grad)
+    self.sess.run([self._update], self._grad_feed)
+    self.iter += 1
+
+    return
+
+  def sync(self):
+    vars = self._get_flat_vars()
+    MPIUtil.bcast(vars)
+    self._set_flat_vars(vars)
+    return
+
+  def check_synced(self):
+    synced = True
+    if self._is_root():
+      vars = self._get_flat_vars()
+      MPIUtil.bcast(vars)
+    else:
+      vars_local = self._get_flat_vars()
+      vars_root = np.empty_like(vars_local)
+      MPIUtil.bcast(vars_root)
+      synced = (vars_local == vars_root).all()
+    return synced
+
+  def _is_root(self):
+    return MPIUtil.is_root_proc()
+
+  def _build_grad_feed(self, vars):
+    self._grad_tf_list = []
+    self._grad_buffers = []
+    for v in self.vars:
+      shape = v.get_shape()
+      grad = np.zeros(shape)
+      grad_tf = tf.placeholder(tf.float32, shape=shape)
+      self._grad_buffers.append(grad)
+      self._grad_tf_list.append(grad_tf)
+
+    self._grad_feed = dict({g_tf: g for g_tf, g in zip(self._grad_tf_list, self._grad_buffers)})
+
+    return
+
+  def _calc_grad_dim(self):
+    grad_dim = 0
+    for grad in self._grad_buffers:
+      grad_dim += grad.size
+    return grad_dim
+
+  def _load_flat_grad(self, flat_grad):
+    start = 0
+    for g in self._grad_buffers:
+      size = g.size
+      np.copyto(g, np.reshape(flat_grad[start:start + size], g.shape))
+      start += size
+    return
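
The diff above is a re-indentation only; the solver's logic is unchanged. Its core technique is gradient averaging across MPI workers in update_flatgrad: every rank contributes a flat local gradient, Allreduce sums them, and dividing by the process count yields the mean that each rank then applies. A minimal standalone sketch of that pattern, assuming mpi4py and NumPy and using toy values in place of real network gradients (run with something like mpiexec -n 4 python sketch.py; the script name and array size are made up):

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    # Each rank contributes a different toy "local gradient".
    local_grad = np.full(8, comm.Get_rank(), dtype=np.float32)
    global_grad = np.zeros_like(local_grad)

    # Sum the per-rank gradients, then divide by the number of processes,
    # mirroring update_flatgrad's Allreduce followed by /= get_num_procs().
    comm.Allreduce(local_grad, global_grad, op=MPI.SUM)
    global_grad /= comm.Get_size()

    if comm.Get_rank() == 0:
      print(global_grad)  # every entry equals the mean rank index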
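
check_synced uses the complementary collective: the root rank broadcasts its flat parameter vector, and every other rank receives it into a scratch buffer and compares it element-wise with its own copy. A sketch of that handshake, again calling mpi4py directly instead of the MPIUtil.bcast wrapper and with a dummy parameter vector:

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    # Stand-in for this rank's flattened network parameters.
    params = np.ones(16, dtype=np.float32)

    if comm.Get_rank() == 0:
      # The root broadcasts its copy and is trivially in sync with itself.
      comm.Bcast(params, root=0)
      synced = True
    else:
      params_root = np.empty_like(params)
      comm.Bcast(params_root, root=0)
      synced = bool((params == params_root).all())

    print(comm.Get_rank(), synced)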
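
Finally, _calc_grad_dim and _load_flat_grad handle the bookkeeping between the single flat vector that MPI exchanges and the per-variable gradient buffers fed to TensorFlow: the flat vector is sliced in order, reshaped, and copied into each buffer. A TensorFlow-free sketch with made-up variable shapes:

    import numpy as np

    # Made-up variable shapes; in the solver they come from the TF graph.
    grad_buffers = [np.zeros((3, 2)), np.zeros(4), np.zeros((2, 2))]

    # _calc_grad_dim: total number of gradient entries across all variables.
    grad_dim = sum(g.size for g in grad_buffers)

    # Stand-in for the averaged flat gradient produced by the Allreduce.
    flat_grad = np.arange(grad_dim, dtype=np.float64)

    # _load_flat_grad: slice out each variable's segment, reshape, copy it in.
    start = 0
    for g in grad_buffers:
      size = g.size
      np.copyto(g, np.reshape(flat_grad[start:start + size], g.shape))
      start += size

    print(grad_buffers[0])  # first 6 entries of flat_grad as a (3, 2) array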