summaryrefslogtreecommitdiff
path: root/examples/pybullet/gym/pybullet_envs/deep_mimic/learning/agent_builder.py
blob: ca54f46d591a2d86f53d3fa92a9670b55c83741e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import json
import numpy as np
from learning.ppo_agent import PPOAgent
import pybullet_data

AGENT_TYPE_KEY = "AgentType"

def build_agent(world, id, file):
    """Build a learning agent from a JSON config in the pybullet data directory.

    The config path is resolved relative to ``pybullet_data.getDataPath()``.
    The parsed JSON must be an object containing an ``AgentType`` entry whose
    value selects the agent class. Only ``PPOAgent`` (matched against
    ``PPOAgent.NAME``) is currently supported.

    Args:
        world: Object passed through as the first argument of the agent
            constructor (presumably the owning RL world -- confirm at caller).
        id: Value passed through as the agent constructor's ``id`` argument
            (presumably the agent's index in ``world`` -- verify against caller).
        file: Path of the JSON config, relative to the pybullet data directory.

    Returns:
        The newly constructed agent.

    Raises:
        OSError: If the config file cannot be opened.
        ValueError: If the file is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass), is not a JSON object, lacks ``AgentType``,
            or names an unsupported agent type.
    """
    # Parse the config first so the file handle is released before the agent
    # is constructed.
    with open(pybullet_data.getDataPath() + "/" + file) as data_file:
        json_data = json.load(data_file)

    # Explicit checks rather than ``assert``: asserts are removed under
    # ``python -O``, which would turn a bad config into a bare KeyError or a
    # silent ``None`` return.
    if not isinstance(json_data, dict) or AGENT_TYPE_KEY not in json_data:
        raise ValueError("Agent config '{}' must be a JSON object with a '{}' entry"
                         .format(file, AGENT_TYPE_KEY))
    agent_type = json_data[AGENT_TYPE_KEY]

    if agent_type == PPOAgent.NAME:
        return PPOAgent(world, id, json_data)

    # format() instead of '+' so a non-string AgentType (e.g. a number) still
    # yields this message rather than a TypeError.
    raise ValueError('Unsupported agent type: {}'.format(agent_type))