summary | refs | log | tree | commit | diff
path: root/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py')
-rw-r--r-- examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py | 27
1 files changed, 14 insertions, 13 deletions
diff --git a/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py b/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py
index 12c682a19..71132e15d 100644
--- a/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py
+++ b/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/rl_util.py
@@ -1,18 +1,19 @@
import numpy as np
+
def compute_return(rewards, gamma, td_lambda, val_t):
- # computes td-lambda return of path
- path_len = len(rewards)
- assert len(val_t) == path_len + 1
+ # computes td-lambda return of path
+ path_len = len(rewards)
+ assert len(val_t) == path_len + 1
+
+ return_t = np.zeros(path_len)
+ last_val = rewards[-1] + gamma * val_t[-1]
+ return_t[-1] = last_val
- return_t = np.zeros(path_len)
- last_val = rewards[-1] + gamma * val_t[-1]
- return_t[-1] = last_val
+ for i in reversed(range(0, path_len - 1)):
+ curr_r = rewards[i]
+ next_ret = return_t[i + 1]
+ curr_val = curr_r + gamma * ((1.0 - td_lambda) * val_t[i + 1] + td_lambda * next_ret)
+ return_t[i] = curr_val
- for i in reversed(range(0, path_len - 1)):
- curr_r = rewards[i]
- next_ret = return_t[i + 1]
- curr_val = curr_r + gamma * ((1.0 - td_lambda) * val_t[i + 1] + td_lambda * next_ret)
- return_t[i] = curr_val
-
- return return_t
\ No newline at end of file
+ return return_t