diff options
Diffstat (limited to 'examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py')
-rw-r--r-- | examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py | 27 |
1 file changed, 14 insertions, 13 deletions
import numpy as np


def compute_return(rewards, gamma, td_lambda, val_t):
  """Compute the TD(lambda) return for each step of a trajectory.

  Works backward from the end of the path: the last step's return is the
  one-step bootstrapped value, and each earlier step blends the one-step
  TD target with the already-computed lambda-return of the next step.

  Args:
    rewards: sequence of per-step rewards, length path_len.
    gamma: discount factor in [0, 1].
    td_lambda: TD(lambda) mixing coefficient in [0, 1]; 0 gives pure
      one-step TD targets, 1 gives the Monte Carlo return.
    val_t: sequence of value estimates, length path_len + 1 (includes the
      bootstrap value for the state after the final step).

  Returns:
    np.ndarray of length path_len with the lambda-return of each step.
    An empty path yields an empty array.
  """
  # computes td-lambda return of path
  path_len = len(rewards)
  assert len(val_t) == path_len + 1

  return_t = np.zeros(path_len)
  # Empty path: nothing to compute (the original code would raise an
  # IndexError on rewards[-1] here).
  if path_len == 0:
    return return_t

  # Last step bootstraps directly off the terminal value estimate.
  last_val = rewards[-1] + gamma * val_t[-1]
  return_t[-1] = last_val

  # Backward recursion: G_i = r_i + gamma * ((1-lambda) * V_{i+1}
  #                                          + lambda * G_{i+1})
  for i in reversed(range(0, path_len - 1)):
    curr_r = rewards[i]
    next_ret = return_t[i + 1]
    curr_val = curr_r + gamma * ((1.0 - td_lambda) * val_t[i + 1] + td_lambda * next_ret)
    return_t[i] = curr_val

  return return_t