# OTHER
import time
from typing import List, NamedTuple, Any, Tuple
import importlib
import inspect
from functools import wraps
import copy
import json
from eagerx.core.constants import BackendException, Unspecified
def dict_null(items):
    """Build a dict from ``(key, value)`` pairs, storing ``"null"`` for ``None``.

    The signature matches what ``json.loads`` expects from an
    ``object_pairs_hook``; ``replace_None`` in this module uses it that way.

    :param items: Iterable of ``(key, value)`` pairs.
    :return: Dict holding every pair, with each ``None`` value replaced by the
             string ``"null"``. On duplicate keys the last pair wins.
    """
    return {key: ("null" if value is None else value) for key, value in items}
def dict_None(items):
    """Build a dict from ``(key, value)`` pairs, storing ``None`` for ``"null"``.

    This is the inverse mapping of ``dict_null`` and has the signature of a
    ``json.loads`` ``object_pairs_hook``.

    :param items: Iterable of ``(key, value)`` pairs.
    :return: Dict holding every pair, with each value that compares equal to
             the string ``"null"`` replaced by ``None``. On duplicate keys the
             last pair wins.
    """
    return {key: (None if value == "null" else value) for key, value in items}
def replace_None(d, to_null=True):
    """Return a JSON round-tripped copy of ``d`` with ``None``/``"null"`` swapped.

    ``d`` is serialised with ``json.dumps`` and parsed back with ``json.loads``
    using one of the module's pair hooks, so the result is a new object and
    ``d`` must be JSON serialisable.

    :param d: Object to convert.
    :param to_null: If truthy, decode with ``dict_null`` (``None`` values become
                    ``"null"``); otherwise decode with ``dict_None`` (``"null"``
                    values become ``None``).
    :return: The decoded copy.
    """
    # The hook is applied by ``json.loads`` to the key/value pairs of each
    # decoded JSON object.
    pairs_hook = dict_null if to_null else dict_None
    return json.loads(json.dumps(d), object_pairs_hook=pairs_hook)
def get_attribute_from_module(attribute, module=None):
    """Import a module and return one of its attributes.

    :param attribute: Attribute name, or a ``"module/attribute"`` string when
                      ``module`` is ``None``.
    :param module: Importable module name. If ``None``, the module name is
                   taken from the part of ``attribute`` before the ``"/"``.
    :return: The attribute looked up on the imported module.
    """
    if module is None:
        # Exactly one "/" is expected; the unpacking fails otherwise.
        module_name, attribute_name = attribute.split("/")
    else:
        module_name, attribute_name = module, attribute
    return getattr(importlib.import_module(module_name), attribute_name)
def load(mod_attr: str):
    """Resolve a ``"module/attribute"`` string to the object it names.

    :param mod_attr: String passed on to ``get_attribute_from_module`` without
                     an explicit module.
    :return: Whatever ``get_attribute_from_module`` returns for ``mod_attr``.
    """
    resolved = get_attribute_from_module(mod_attr)
    return resolved
def initialize_processor(spec):
    """Instantiate and initialise the processor described by ``spec``.

    :param spec: Object whose ``params["processor_type"]`` entry is a string
                 that ``load`` can resolve to a callable taking no arguments.
    :return: The created instance, after ``initialize(spec)`` was called on it.
    """
    processor_cls = load(spec.params["processor_type"])
    instance = processor_cls()
    instance.initialize(spec)
    return instance
def is_compatible(dtype_source, dtype_target):
    """Verify that a source dtype equals a target dtype.

    The check is raised explicitly instead of using an ``assert`` statement so
    that it is still enforced when Python runs with ``-O``.

    :param dtype_source: Dtype of the source.
    :param dtype_target: Dtype of the target.
    :return: ``None`` when both dtypes compare equal.
    :raises AssertionError: If ``dtype_target == dtype_source`` is falsy.
    """
    msg = f"Dtype of source ({dtype_source}) does not match with the dtype of target ({dtype_target})."
    if not (dtype_target == dtype_source):
        raise AssertionError(msg)
class Header(NamedTuple):
    """Immutable record with the meta data of a sent message."""

    #: Sequence number of the received message, counted since the last reset.
    seq: int
    #: Simulated-clock timestamp in seconds. This time is scaled by the real-time factor if > 0.
    sc: float
    #: Wall-clock timestamp in seconds.
    wc: float
class Stamp(NamedTuple):
    """Immutable record used to timestamp a received message."""

    #: Sequence number of received message.
    seq: int
    #: Timestamp according to the simulated clock (seconds). This time is scaled by the real-time factor if > 0.
    sc: float
    #: Timestamp according to the wall clock (seconds).
    wc: float
class Info(NamedTuple):
    """Immutable record with info about the received messages in :attr:`~eagerx.utils.utils.Msg.msgs`."""

    #: Name of the registered input.
    name: str
    #: Number of times :func:`~eagerx.core.entities.Node.callback` has been called since the last reset.
    node_tick: int
    #: Rate (Hz) of the input.
    rate_in: float
    #: Simulated timestamp that states during which cycle the message was received since the last reset according
    #: to :attr:`~eagerx.core.entities.Node.rate` and :attr:`~eagerx.utils.utils.Info.node_tick`.
    t_node: List[Stamp]
    #: Simulated timestamp that states at what time the message was received
    #: according to :attr:`~eagerx.utils.utils.Info.rate_in` and :attr:`~eagerx.utils.utils.Stamp.seq`.
    t_in: List[Stamp]
    #: Only concerns states. False if a state has not yet been reset.
    done: bool
class Msg(NamedTuple):
    """Immutable record for a (windowed) input that is passed to :func:`~eagerx.core.entities.Node.callback`."""

    #: Info on the received messages in :attr:`~eagerx.utils.utils.Msg.msgs`.
    info: Info
    #: The received messages with indexing `msgs[-1]` being the most recent message and `msgs[0]` the oldest.
    msgs: List[Any]
# Give every field of these records a default of None, so each of them can be
# constructed with any subset of its fields.
for _record_cls in (Header, Stamp, Info):
    _record_cls.__new__.__defaults__ = (None,) * len(_record_cls._fields)
# Do not leave the loop variable behind in the module namespace.
del _record_cls
def deepcopy(func):
    """Decorator that hands out a deep copy of whatever ``func`` returns.

    :param func: Callable to wrap.
    :return: Wrapper that forwards all arguments to ``func`` and passes the
             result through ``copy.deepcopy`` before returning it.
    """

    @wraps(func)
    def _copied_result(*args, **kwargs):
        result = func(*args, **kwargs)
        return copy.deepcopy(result)

    return _copied_result
def is_supported_type(param: Any, types: Tuple, none_support):
    """Recursively verify that ``param`` only consists of supported types.

    Dictionaries are checked value by value and must have ``str`` keys. Any
    other non-string object that has ``__iter__`` is checked element by
    element.

    The key check is raised explicitly instead of using an ``assert``
    statement so that it is still enforced when Python runs with ``-O``.

    :param param: The (possibly nested) value to check.
    :param types: Tuple of accepted types, as accepted by ``isinstance``.
    :param none_support: If truthy, ``None`` is accepted as well.
    :return: ``None`` when every (nested) value is supported.
    :raises TypeError: If ``param`` or a nested value has an unsupported type.
    :raises AssertionError: If a dictionary key is not a ``str``.
    """
    supported = isinstance(param, types) or (param is None and none_support)
    if not supported:
        raise TypeError(
            f'Type "{type(param)}" of a specified (nested) param "{param}" is not supported. Only types {types} are supported.'
        )

    if isinstance(param, dict):
        for key, value in param.items():
            if not isinstance(key, str):
                raise AssertionError(f'Invalid key "{key}". Only type "str" is supported as dictionary key.')
            is_supported_type(value, types, none_support)
    elif not isinstance(param, str) and hasattr(param, "__iter__"):
        # Strings are iterable too, but are treated as leaf values.
        for value in param:
            is_supported_type(value, types, none_support)
def supported_types(*types: Tuple, is_classmethod=True):
    """Create a decorator that type-checks the arguments of the wrapped callable.

    Every checked argument is passed to ``is_supported_type`` before the
    wrapped callable is invoked.

    :param types: Accepted types. Passing ``None`` among them means that
                  ``None`` is accepted as an argument value.
    :param is_classmethod: If truthy, the first positional argument is skipped
                           when checking.
    :return: Decorator that wraps a callable with the argument check.
    """
    # None is not a type itself: record that it is allowed, then drop it.
    none_support = any(t is None for t in types)
    types = tuple(t for t in types if t is not None)

    def _check(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            positional = args[1:] if is_classmethod else args
            # Keyword names are irrelevant here; only the values are checked.
            for param in (*positional, *kwargs.values()):
                is_supported_type(param, types, none_support)
            return func(*args, **kwargs)

        return wrapper

    return _check
def get_default_params(func):
    """Map the positional-or-keyword parameters of ``func`` to their defaults.

    :param func: Callable to inspect with ``inspect.getfullargspec``.
    :return: Dict with the parameters that have a default first (mapped to
             that default), followed by the parameters without a default
             (mapped to ``None``). A parameter named ``"self"`` that has no
             default is left out.
    """
    spec = inspect.getfullargspec(func)
    default_values = spec.defaults or ()
    # Parameters with a default are always the trailing ones in ``spec.args``.
    first_defaulted = len(spec.args) - len(default_values)

    params = dict(zip(spec.args[first_defaulted:], default_values))
    params.update({name: None for name in spec.args[:first_defaulted] if name != "self"})
    return params
# Module-level sentinel instance of ``Unspecified``. It is the default value of
# the ``default`` parameter of ``get_param_with_blocking``, which detects "no
# default was supplied" with ``isinstance(default, Unspecified)``.
_unspecified = Unspecified()
def get_param_with_blocking(name, backend, default=_unspecified, timeout=2.0):
    """Poll ``backend`` for a parameter until it is available or ``timeout`` expires.

    The timeout is measured with ``time.monotonic`` so that adjustments of the
    system clock can neither shorten nor extend the wait.

    :param name: Name of the parameter, passed to ``backend.get_param``.
    :param backend: Object that provides ``get_param(name, default=...)``.
    :param default: Value to return when the lookup raises and a default was
                    supplied. If left unspecified, the lookup is retried.
    :param timeout: Maximum time in seconds to keep retrying.
    :return: ``default`` as-is if the lookup raised and a default was supplied,
             otherwise the retrieved value passed through
             ``replace_None(..., to_null=False)``.
    :raises KeyError: If the parameter is still unavailable after ``timeout``
                      seconds.
    """
    retry_pause = 0.01  # Seconds to wait between two lookups.
    start = time.monotonic()
    while True:
        if time.monotonic() - start > timeout:
            raise KeyError(f"Timeout. Parameter '{name}' not available on parameter server.")
        try:
            params = backend.get_param(name, default=default)
        except (BackendException, KeyError):
            if not isinstance(default, Unspecified):
                return default
        else:
            # An ``Unspecified`` result means there is no value yet: retry.
            if not isinstance(params, Unspecified):
                return replace_None(params, to_null=False)
        # Only pause when another attempt follows, so that a successful lookup
        # is never delayed and a pending one does not busy-wait.
        time.sleep(retry_pause)