Source code for eagerx.core.register

import functools
import copy

import eagerx.core.log as log
from eagerx.utils.utils import deepcopy
from eagerx.core.space import Space
from typing import TYPE_CHECKING, Callable, Any
import os


if TYPE_CHECKING:
    from eagerx.core.graph_engine import EngineGraph  # noqa: F401
    from eagerx.core.Entities import Entity, Engine  # noqa: F401

# Global registry with registered entities (engines, objects, nodes, converters, simnodes, etc..)
REGISTRY = dict()
# Global registry with registered I/O types of (.callback, .reset/.pre_reset, .initialize, .add_object)
TYPE_REGISTER = dict()


class ReverseRegisterLookup:
    def __init__(self, d):
        self._dict = d

    def __getitem__(self, spec_lookup):
        for _entity_cls, entity_id in self._dict.items():
            for entity_idd, entry in entity_id.items():
                if entry["spec"] == spec_lookup:
                    return entity_idd


class LookupType:
    def __init__(self, d):
        self._dict = d

    @deepcopy
    def __getitem__(self, func_lookup):
        name_split = func_lookup.__qualname__.split(".")
        cls_name = name_split[0]
        entity_id = func_lookup.__module__ + "/" + cls_name
        return self._dict[entity_id]


# Global (reversed) registry of REGISTER:
REVERSE_REGISTRY = ReverseRegisterLookup(REGISTRY)
LOOKUP_TYPES = LookupType(TYPE_REGISTER)


# TYPES


def _register_types(TYPE_REGISTER, component, cnames, func, space_only=True):
    name_split = func.__qualname__.split(".")
    cls_name = name_split[0]
    fn_name = name_split[1]
    entity_id = func.__module__ + "/" + cls_name
    entry = func.__module__ + "/" + func.__qualname__
    if space_only:
        for key, space in cnames.items():
            if space is None:
                continue
            flag = isinstance(space, Space)
            assert (
                flag
            ), f'TYPE REGISTRATION ERROR: [{cls_name}][{fn_name}][{component}]: "{space}" is an invalid space. Please provide a valid space for "{key}"instead.'
    log.logdebug(f"[{cls_name}][{fn_name}]: {component}={cnames}, entry={entry}")

    @functools.wraps(func)
    def registered_fn(*args, **kwargs):
        """Call the registered function"""
        return func(*args, **kwargs)

    if entity_id not in TYPE_REGISTER:
        """Add class if this is the first registration of class kind"""
        TYPE_REGISTER[entity_id] = dict()

    if component in TYPE_REGISTER[entity_id]:
        """Check if already registered component of duplicate function matches."""
        log.logdebug(f"[{entity_id}][{component}]: {component}={cnames}, entry={entry}")
        flag = cnames == TYPE_REGISTER[entity_id][component] or bool(eval(os.environ.get("EAGERX_RELOAD", "0")))
        assert (
            flag
        ), f'There is already a [{entity_id}][{component}] registered with cnames "{TYPE_REGISTER[entity_id][component]}", and they do not match the cnames of this function: "{cnames}".'
    TYPE_REGISTER[entity_id][component] = cnames
    return registered_fn


[docs]def inputs(**inputs: Any) -> Callable: """A decorator to register the inputs to a :func:`~eagerx.core.entities.Node.callback`. The :func:`~eagerx.core.entities.Node.callback` method should be decorated. :param inputs: The input's msg_type class. """ return functools.partial(_register_types, TYPE_REGISTER, "inputs", inputs)
[docs]def outputs(**outputs) -> Callable: """A decorator to register the outputs of a :func:`~eagerx.core.entities.Node.callback`. The :func:`~eagerx.core.entities.Node.callback` method should be decorated. :param outputs: The output's msg_type class. """ return functools.partial(_register_types, TYPE_REGISTER, "outputs", outputs)
[docs]def states(**states) -> Callable: """A decorator to register the states for a :func:`~eagerx.core.entities.Node.reset`. The :func:`~eagerx.core.entities.Node.reset` method should be decorated. :param outputs: The state's msg_type class. """ return functools.partial(_register_types, TYPE_REGISTER, "states", states)
[docs]def targets(**targets) -> Callable: """A decorator to register the targets of a :func:`~eagerx.core.entities.ResetNode.callback`. The :func:`~eagerx.core.entities.ResetNode.callback` method should be decorated. :param targets: The target's msg_type class. """ return functools.partial(_register_types, TYPE_REGISTER, "targets", targets)
[docs]def sensors(**sensors) -> Callable: """A decorator to register the sensors of an :class:`~eagerx.core.entities.Object`. The :func:`~eagerx.core.entities.Object.agnostic` method should be decorated. :param sensors: The sensor's msg_type class. """ return functools.partial(_register_types, TYPE_REGISTER, "sensors", sensors)
[docs]def actuators(**actuators) -> Callable: """A decorator to register the actuators of an :class:`~eagerx.core.entities.Object`. The :func:`~eagerx.core.entities.Object.agnostic` method should be decorated. :param actuators: The actuator's msg_type class. """ return functools.partial(_register_types, TYPE_REGISTER, "actuators", actuators)
[docs]def engine_states(**engine_states) -> Callable: """A decorator to register the engine states of an :class:`~eagerx.core.entities.Object`. The :func:`~eagerx.core.entities.Object.agnostic` method should be decorated. :param engine_states: The engine state's msg_type class. """ return functools.partial(_register_types, TYPE_REGISTER, "engine_states", engine_states)
# ENGINES
[docs]def engine(engine_cls: "Engine", entity=None) -> Callable: """A decorator to register an engine implementation of an :class:`~eagerx.core.entities.Object`. .. note:: In our running example, the :func:`~eagerx.core.entities.Object.example_engine` method would be decorated. :param engine_cls: The Engine's subclass (not the baseclass :class:`~eagerx.core.entities.Engine`). :param entity: The entity that corresponds to the engine implementation. If left unspecified, the engine is registered to the class that owns the method. """ from eagerx.utils.utils import get_default_params engine_config = get_default_params(engine_cls.add_object) assert "name" in engine_config, "The object `name` must be an argument to engine.add_object()" engine_id = engine_cls.__module__ + "/" + engine_cls.__qualname__ def _register(func, engine_id=engine_id): name_split = func.__qualname__.split(".") if entity is not None: cls_name = entity.__qualname__ fn_name = name_split[0] entity_id = entity.__module__ + "/" + cls_name elif len(name_split) > 1: cls_name = name_split[0] fn_name = name_split[1] entity_id = func.__module__ + "/" + cls_name else: cls_name = "N/A" fn_name = name_split[0] entity_id = func.__module__ + "/" + cls_name entry = func.__module__ + "/" + func.__qualname__ log.logdebug(f"[{cls_name}][{fn_name}]: entry={entry}") @functools.wraps(func) def _engine(spec, engine) -> "EngineGraph": """First, initialize spec with object_info, then call the engine function""" # Add default engine_config parameters engine._initialize_engine_config(spec, copy.deepcopy(engine_config)) # Initialize engine graph graph = spec._initialize_object_graph() # Modify engine_config with user-defined engine implementation func(spec, graph) # Add engine-specific parameters to engine.objects engine_params = {key: value for key, value in spec.engine.items() if key != "states"} engine.add_object(**engine_params) # Add engine-specific states to engine.objects engine._add_engine_states(engine_params["name"], spec) return graph msg = f"Cannot register engine '{entry}' for object '{entity_id}'. " # Check if spec of duplicate entity_id corresponds to same spec function if entity_id not in REGISTRY: REGISTRY[entity_id] = {} flag = engine_id in REGISTRY[entity_id] and _engine == REGISTRY[entity_id][engine_id] flag = not flag or bool(eval(os.environ.get("EAGERX_RELOAD", "0"))) assert flag, msg + "There is already an engine implementation of this type registered." # Register engine implementation REGISTRY[entity_id][engine_id] = _engine return _engine return _register
def add_engine(spec, engine): """Add engine based on registered entity_id""" entity_id = spec.config.entity_id engine_id = engine.config.entity_id msg = f"Cannot add object '{entity_id}' to engine '{engine_id}'. " assert entity_id in REGISTRY, msg + f"The Object '{entity_id}' has not been registered yet." msg_2 = ( "If launching the environment as a subprocess, " "make sure to reload (i.e. import) the engine implementations in each subprocess." ) assert engine_id in REGISTRY[entity_id], msg + "No engine implementation was registered. " + msg_2 graph = REGISTRY[entity_id][engine_id](spec, engine) return graph