Source code for eagerx.core.env

from eagerx.core.space import Space
from eagerx.core.entities import Backend
from eagerx.core.specs import NodeSpec, EngineSpec, BackendSpec, BaseNodeSpec
from eagerx.core.entities import Node
from eagerx.core.graph import Graph
from eagerx.utils.node_utils import (
    initialize_nodes,
    wait_for_node_initialization,
)
from eagerx.core.executable_node import RxNode
from eagerx.core.executable_engine import RxEngine
from eagerx.core.supervisor import Supervisor, SupervisorNode
from eagerx.core.rx_message_broker import RxMessageBroker
from eagerx.core.constants import process
from eagerx.core.constants import ENVIRONMENT, NEW_PROCESS, EXTERNAL, BackendException

# OTHER IMPORTS
import copy
import atexit
import abc
import numpy as np
import functools
from typing import List, Union, Dict, Tuple, Optional, Any
import gymnasium as gym


[docs]class BaseEnv(gym.Env): """The base class for all EAGERx environments that follows the OpenAI gym's Env API. - Be sure to call :func:`super().__init__` inside the subclass' constructor with the required arguments (name, graph, etc...). A subclass should implement the following methods: - :func:`~eagerx.core.env.BaseEnv.step`: Be sure to call :func:`~eagerx.core.env.BaseEnv._step` inside this method to perform the step. - :func:`~eagerx.core.env.BaseEnv.reset`: Be sure to call :func:`~eagerx.core.env.BaseEnv._reset` inside this method to perform the reset. A subclass can optionally overwrite the following properties: - :attr:`~eagerx.core.env.BaseEnv.observation_space`: Per default, the observations, registered in the graph, are taken. - :attr:`~eagerx.core.env.BaseEnv.action_space`: Per default, the actions, registered in the graph, are taken. """
[docs] def __init__( self, name: str, rate: float, graph: Graph, engine: EngineSpec, backend: BackendSpec = None, force_start: bool = True, render_mode: str = None, ) -> None: """Initializes an environment with EAGERx dynamics. :param name: The name of the environment. Everything related to this environment (parameters, topics, nodes, etc...) will be registered under namespace: "`/name`". :param rate: The rate (Hz) at which the environment will run. :param graph: The graph consisting of nodes and objects that describe the environment's dynamics. :param engine: The physics engine that will govern the environment's dynamics. For every :class:`~eagerx.core.entities.Object` in the graph, the corresponding engine implementations is chosen. :param backend: The backend that will govern the communication for this environment. Per default, the :class:`~eagerx.backends.single_process.SingleProcess` backend is used. :param force_start: If there already exists an environment with the same name, the existing environment is first shutdown by calling the :func:`~eagerx.core.env.BaseEnv` method before initializing this environment. :param render_mode: The render mode that will be used for rendering the environment. """ assert "/" not in name, 'Environment name "%s" cannot contain the reserved character "/".' % name self.name = name self.ns = "/" + name self.initialized = False self.has_shutdown = False self.render_mode = render_mode self._is_initialized = dict() self._launch_nodes = dict() self._sp_nodes = dict() # Register graph (returns unlinked specs with original graph & reloads entities). if engine is None: environment, engine, nodes, render = Graph._get_all_node_specs(graph._state) self.graph = graph else: self.graph, environment, engine, nodes, self.render_node = graph.register(rate, engine) self.rate = rate if rate else environment.config.rate # Initialize backend if backend is None: from eagerx.backends.single_process import SingleProcess backend = SingleProcess.make() self.backend = Backend.from_cmd( self.ns, backend.config.entity_id, backend.config.log_level, main=True, real_time_factor=engine.config.real_time_factor, sync=engine.config.sync, simulate_delays=engine.config.sync, ) # Check if there already exists an environment self._shutdown_srv = self.backend.register_environment(self.ns, force_start, self.shutdown) # Delete pre-existing parameters self.backend.delete_param(f"/{self.name}", level=2) # Upload relevant run-time settings secs, nsecs = self.backend.serialize_time(self.backend.ts_init) self.backend.upload_params( self.ns, { "log_level": self.backend.log_level, "ts_init_secs": secs, "ts_init_nsecs": nsecs, "real_time_factor": self.backend.real_time_factor, "sync": self.backend.sync, "simulate_delays": self.backend.simulate_delays, }, ) # Initialize message broker self.mb = RxMessageBroker(owner="%s/%s" % (self.ns, "env"), backend=self.backend) # Create supervisor spec (adds addresses of (engine)states to supervisor) supervisor = self._create_supervisor(environment, engine, nodes) # Upload parameters self._upload_params(self.ns, self.backend, [supervisor, environment, engine] + nodes) # Initialize environment self.environment_node = RxNode(name="%s/%s" % (self.ns, environment.config.name), message_broker=self.mb) self.environment_node.node_initialized() self.environment = self.environment_node.node # Initialize nodes initialize_nodes(nodes, ENVIRONMENT, self.ns, self.mb, self._is_initialized, self._sp_nodes, self._launch_nodes) # Initialize engine initialize_nodes( engine, ENVIRONMENT, self.ns, self.mb, self._is_initialized, self._sp_nodes, self._launch_nodes, rxnode_cls=RxEngine, ) # Initialize supervisor node self.supervisor_node = Supervisor("%s/%s" % (self.ns, supervisor.config.name), self.mb, self.environment) self.supervisor_node.node_initialized() self.supervisor = self.supervisor_node.node self.mb.connect_io() # Implement clean up atexit.register(self.shutdown)
@staticmethod def _create_supervisor(environment: NodeSpec, engine: EngineSpec, nodes: List[NodeSpec]) -> BaseNodeSpec: entity_type = f"{SupervisorNode.__module__}/{SupervisorNode.__name__}" spec = Node.pre_make("N/a", entity_type) spec.add_output("step", space=Space(shape=(), dtype="int64")) spec.config.rate = environment.config.rate spec.config.name = "env/supervisor_node" spec.config.color = "yellow" spec.config.process = process.ENVIRONMENT spec.config.outputs = ["step"] # Get all states from all nodes for i in [environment, engine] + nodes: for cname in i.params["config"]["states"]: entity_name = i.config.name name = f"{entity_name}/{cname}" address = f"{entity_name}/states/{cname}" processor = None # Only add processor (i.params["states"][cname]["processor"]) once at input side. space = i.params["states"][cname]["space"] assert ( name not in spec.params["states"] ), f'Cannot have duplicate states. State "{name}" is defined multiple times.' mapping = dict(address=address, processor=processor, space=space) with spec.states as d: d[name] = mapping spec.config.states.append(name) # Get all engine states for entity_name, obj in engine.objects.items(): for cname, state in obj.engine_states.items(): name = f"{entity_name}/{cname}" address = f"{entity_name}/states/{cname}" processor = None # Only add processor (i.params["states"][cname]["processor"]) once at input side. space = state["space"].to_dict() assert ( name not in spec.params["states"] ), f'Cannot have duplicate states. State "{name}" is defined multiple times.' mapping = dict(address=address, processor=processor, space=space) with spec.states as d: d[name] = mapping spec.config.states.append(name) return spec @staticmethod def _upload_params(ns: str, backend: Backend, nodes: List[BaseNodeSpec]) -> None: for node in nodes: name = node.config.name msg = f"Node name '{ns}/{node.config.name}' already exists on the parameter server. Node names must be unique." assert backend.get_param(f"{ns}/rate/{name}", None) is None, msg # If no backend multiprocessing support, overwrite NEW_PROCESS to ENVIRONMENT if node.config.process == NEW_PROCESS and not backend.MULTIPROCESSING_SUPPORT: backend.logwarn_once( f"Backend '{backend.BACKEND}' does not support multiprocessing, " "so all nodes are launched in the ENVIRONMENT process." ) node.config.process = ENVIRONMENT # Raise error if external process is not supported elif node.config.process == EXTERNAL and not backend.DISTRIBUTED_SUPPORT: raise BackendException( f"Backend '{backend.BACKEND}' does not support distributed computation. " f"Therefore, this backend is incompatible with node '{name}', " f"because {name}.config.process=EXTERNAL." ) params = node.build(ns=ns) backend.upload_params(ns, params) @property def observation_space(self) -> gym.spaces.Space: """The Space object corresponding to valid observations. Per default, the observation space of all registered observations in the graph is used. """ return self._observation_space @property def action_space(self) -> gym.spaces.Space: """The Space object corresponding to valid actions Per default, the action space of all registered actions in the graph is used. """ return self._action_space @property def state_space(self) -> gym.spaces.Dict: """Infers the state space from the space of every state. This space defines the format of valid states that can be set before the start of an episode. :returns: A dictionary with *key* = *state* and *value* = :class:`Space`. """ state_space = dict() for name, buffer in self.supervisor.state_buffer.items(): state_space[name] = buffer["space"] return gym.spaces.Dict(spaces=state_space) @property def _observation_space(self) -> gym.spaces.Dict: """Infers the observation space from the space of every observation. This space defines the format of valid observations. .. note:: Observations with :attr:`~eagerx.core.specs.RxInput.window` = 0 are excluded from the observation space. For observations with :attr:`~eagerx.core.specs.RxInput.window` > 1, the observation space is duplicated :attr:`~window` times. :returns: A dictionary with *key* = *observation* and *value* = :class:`Space`. """ assert not self.has_shutdown, "This environment has been shutdown." observation_space = dict() for name, buffer in self.environment.observation_buffer.items(): space = buffer["space"] if not buffer["window"] > 0: continue if isinstance(space, gym.spaces.Discrete): stacked_space = gym.spaces.MultiDiscrete([space.n] * buffer["window"]) else: low = np.repeat(space.low[np.newaxis, ...], buffer["window"], axis=0) high = np.repeat(space.high[np.newaxis, ...], buffer["window"], axis=0) stacked_space = gym.spaces.Box(low=low, high=high, dtype=space.dtype) observation_space[name] = stacked_space return gym.spaces.Dict(spaces=observation_space) @property def _action_space(self) -> gym.spaces.Dict: """Infers the action space from the space of every action. This space defines the format of valid actions. :returns: A dictionary with *key* = *action* and *value* = :class:`Space`. """ assert not self.has_shutdown, "This environment has been shutdown." action_space = dict() for name, buffer in self.environment.action_buffer.items(): action_space[name] = buffer["space"] return gym.spaces.Dict(spaces=action_space) def _set_action(self, action) -> None: # Set actions in buffer for name, buffer in self.environment.action_buffer.items(): assert not self.supervisor.sync or name in action, ( 'Action "%s" not specified. Must specify all actions in action_space if running reactive.' % name ) if name in action: buffer["msg"] = action[name] def _set_state(self, state) -> None: # Set states in buffer for name, msg in state.items(): assert name in self.supervisor.state_buffer, 'Cannot set unknown state "%s".' % name self.supervisor.state_buffer[name]["msg"] = msg def _get_observation(self) -> Dict: # Get observations from buffer observation = dict() for name, buffer in self.environment.observation_buffer.items(): observation[name] = buffer["msgs"] return observation def _initialize(self, states: Dict) -> None: assert not self.initialized, "Environment already initialized. Cannot re-initialize pipelines. " # Set desired reset states self._set_state(states) # Wait for nodes to be initialized [node.node_initialized() for name, node in self._sp_nodes.items()] wait_for_node_initialization(self._is_initialized, self.backend) # Initialize communication within this process self.mb.connect_io(print_status=True) self.backend.logdebug("Nodes initialized.") # Perform first reset self.supervisor.reset() # Nodes initialized self.initialized = True self.backend.loginfo("Communication initialized.") def _shutdown(self): if not self.has_shutdown: self._shutdown_srv.unregister() for address, node in self._launch_nodes.items(): self.backend.logdebug(f"[{self.name}] Send termination signal to '{address}'.") node.terminate() for _, rxnode in self._sp_nodes.items(): rxnode: RxNode if not rxnode.has_shutdown: self.backend.logdebug(f"[{self.name}][{rxnode.name}] Shutting down.") rxnode.node_shutdown() if not self.supervisor_node.has_shutdown: self.supervisor_node.node_shutdown() if not self.environment_node.has_shutdown: self.environment_node.node_shutdown() self.mb.shutdown() self.backend.delete_param(f"/{self.name}", level=1) self.backend.shutdown() self.has_shutdown = True
[docs] def _reset(self, states: Dict) -> Dict: """A private method that should be called within :func:`~eagerx.core.env.BaseEnv.reset()`. :param states: The desired states to be set before the start an episode. May also be an (empty) subset of registered states if not all states require a reset. :returns: The initial observation. """ assert not self.has_shutdown, "This environment has been shutdown." # Initialize environment if not self.initialized: self._initialize(states) # Set desired reset states self._set_state(states) # Perform reset self.supervisor.reset() obs = self._get_observation() return obs
[docs] @abc.abstractmethod def reset(self, seed: int = None, options: Dict[str, Any] = None) -> Tuple[Union[Dict, np.ndarray], Dict]: """An abstract method that resets the environment to an initial state and returns an initial observation. .. note:: To reset the graph, the private method :func:`~eagerx.core.env.BaseEnv._reset` must be called with the desired initial states. The spaces of all states (of Objects and Nodes in the graph) are stored in :func:`~eagerx.core.env.BaseEnv.state_space`. :returns: The initial observation that is complies with the :func:`~eagerx.core.env.BaseEnv.observation_space`. """ pass
[docs] def _step(self, action: Dict) -> Dict: """A private method that should be called within :func:`~eagerx.core.env.BaseEnv.step()`. :param action: The actions to be applied in the next timestep. Should include all registered actions. :returns: The observation of the current timestep that comply with the graph's observation space. """ # Check that nodes were previously initialized. assert self.initialized, "Not yet initialized. Call .reset() before calling .step()." assert not self.has_shutdown, "This environment has been shutdown." # Set actions in buffer self._set_action(action) # Call step self.supervisor.step() return self._get_observation()
[docs] @abc.abstractmethod def step(self, action: Union[Dict, np.ndarray]) -> Tuple[Union[Dict, np.ndarray], float, bool, bool, Dict]: """An abstract method that runs one timestep of the environment's dynamics. .. note:: To run one timestep of the graph dynamics (that essentially define the environment dynamics), this method must call the private method :func:`~eagerx.core.BaseEnv._step` with the actions that comply with :attr:`~eagerx.core.BaseEnv._action_space`. When the end of an episode is reached, the user is responsible for calling :func:`~eagerx.core.BaseEnv.reset` to reset this environment's state. :params action: Actions provided by the agent. Should comply with the :func:`~eagerx.core.env.BaseEnv.action_space`. :returns: A tuple (observation, reward, terminated, truncated, info). - observation: Observations of the current timestep that comply with the :func:`~eagerx.core.env.BaseEnv.observation_space`. - reward: amount of reward returned after previous action - terminated: whether the episode has ended due to a terminal state, in which case further step() calls will return undefined results - truncated: whether the episode has ended due to a time limit, in which case further step() calls will return undefined results - info: contains auxiliary diagnostic information (helpful for debugging, and sometimes learning) """ pass
@functools.wraps(Graph.gui) def gui( self, *args, interactive: Optional[bool] = True, resolution: Optional[List[int]] = None, filename: Optional[str] = None, **kwargs, ) -> Union[None, np.ndarray]: """Opens a graphical user interface of the graph that was used to initialize this environment. .. note:: Requires `eagerx-gui`: .. highlight:: python .. code-block:: python pip3 install eagerx-gui :param interactive: If `True`, an interactive application is launched. Otherwise, an RGB render of the GUI is returned. This could be useful when using a headless machine. :param resolution: Specifies the resolution of the returned render when `interactive` is `False`. If `interactive` is `True`, this argument is ignored. :param filename: If provided, the GUI is rendered to an svg file with this name. If `interactive` is `True`, this argument is ignored. :return: RGB render of the GUI if `interactive` is `False`. """ return self.graph.gui(*args, interactive=interactive, resolution=resolution, filename=filename, **kwargs)
[docs] def save(self, file: str) -> None: """Saves the (engine-specific) graph state, that includes the engine & environment nodes. The state is saved in *.yaml* format and contains the state of every added node, action, and observation and the connections between them. :param file: A string giving the name (and the file if the file isn't in the current working directory). """ return self.graph.save(file)
[docs] @classmethod def load(cls, name: str, file: str, backend: BackendSpec = None, force_start: bool = True): """Loads an environment corresponding to the graph state. :param name: The name of the environment. Everything related to this environment (parameters, topics, nodes, etc...) will be registered under namespace: "`/name`". :param file: A string giving the name (and the file if the file isn't in the current working directory). :param backend: The backend that will govern the communication for this environment. Per default, the :class:`~eagerx.backends.single_process.SingleProcess` backend is used. :param force_start: If there already exists an environment with the same name, the existing environment is first shutdown by calling the :func:`~eagerx.core.env.BaseEnv` method before initializing this environment. """ graph = Graph.load(file) graph = copy.deepcopy(graph) return cls(name=name, rate=None, graph=graph, engine=None, backend=backend, force_start=force_start)
[docs] def render(self) -> Optional[np.ndarray]: """A method to start rendering (i.e. open the render window). A bool message to topic address ":attr:`~eagerx.core.env.BaseEnv.name` */env/render/toggle*", which toggles the rendering on/off. :returns: Optionally, a rgb_array if env.mode=rgb_array. """ mode = self.render_mode assert not self.has_shutdown, "This environment has been shutdown." if self.render_node: if mode == "human": self.supervisor.start_render() elif mode == "rgb_array": self.supervisor.start_render() img = self.supervisor.get_last_image() return img elif mode is None: return else: raise ValueError('Render mode "%s" not recognized.' % mode) else: self.backend.logwarn_once("No render node active, so not rendering.") if mode == "rgb_array": return np.empty((0, 0, 3), dtype="uint8") else: return
[docs] def close(self): """A method to stop rendering (i.e. close the render window). A bool message to topic address ":attr:`~eagerx.core.env.BaseEnv.name` */env/render/toggle*", which toggles the rendering on/off. .. note:: Depending on the source node that is producing the images that are rendered, images may still be produced, even when the render window is not visible. This may add computational overhead and influence the run speed. Optionally, users may subscribe to topic address ":attr:`~eagerx.core.env.BaseEnv.name` */env/render/toggle*" in the node that is producing the images to stop the production and output empty images instead. """ assert not self.has_shutdown, "This environment has been shutdown." self.supervisor.stop_render()
[docs] def shutdown(self): """A method to shutdown the environment. - Clear the parameters on the ROS parameter under the namespace /:attr:`~eagerx.core.env.BaseEnv.name`. - Close nodes (i.e. release resources and perform :class:`~eagerx.core.entities.Node.close` procedure). - Unregister topics that supplied the I/O communication between nodes. """ self._shutdown()