Code Examples
Below you can find a code example of environment creation and training using Stable-Baselines3. To run this code, you should install eagerx_tutorials, which can be done by running:
pip3 install eagerx_tutorials
A detailed explanation of the code can be found in this Colab tutorial.
import eagerx
from eagerx.backends.single_process import SingleProcess
from eagerx.wrappers import Flatten
from eagerx_tutorials.pendulum.objects import Pendulum
from eagerx_ode.engine import OdeEngine
import stable_baselines3 as sb3
import numpy as np
from typing import Dict
class PendulumEnv(eagerx.BaseEnv):
    """Pendulum environment that rewards keeping the angle and effort near zero.

    Each step yields the negated quadratic cost
    ``th**2 + 0.1 * thdot**2 + 0.01 * u**2`` as reward, where ``th`` is the
    wrapped angle, ``thdot`` the angular velocity and ``u`` the applied
    voltage. Episodes never terminate on their own; they are truncated once
    the step counter exceeds ``max_steps``.
    """

    def __init__(self, name: str, rate: float, graph: eagerx.Graph, engine: eagerx.specs.EngineSpec,
                 backend: eagerx.specs.BackendSpec):
        """Set up the episode bookkeeping and initialise the base environment.

        :param name: Name forwarded to ``eagerx.BaseEnv``.
        :param rate: Rate forwarded to ``eagerx.BaseEnv``.
        :param graph: Graph forwarded to ``eagerx.BaseEnv``.
        :param engine: Engine specification forwarded to ``eagerx.BaseEnv``.
        :param backend: Backend specification forwarded to ``eagerx.BaseEnv``.
        """
        # Episode length limit; `steps` stays None until the first reset().
        self.max_steps = 100
        self.steps = None
        super().__init__(name, rate, graph, engine, backend, force_start=True)

    def step(self, action: Dict):
        """Apply ``action``, advance the environment and compute the reward.

        :param action: Dict of actions; must contain the key ``"voltage"``.
        :return: Tuple ``(observation, reward, terminated, truncated, info)``.
                 ``terminated`` is always ``False`` and ``info`` is an empty dict.
        """
        observation = self._step(action)
        self.steps += 1

        # Calculate reward and check if the episode is terminated
        th = observation["angle"][0]
        thdot = observation["angular_velocity"][0]
        u = float(action["voltage"])
        # Wrap the angle into [-pi, pi) so full rotations are not penalised.
        th -= 2 * np.pi * np.floor((th + np.pi) / (2 * np.pi))
        cost = th ** 2 + 0.1 * thdot ** 2 + 0.01 * u ** 2
        truncated = self.steps > self.max_steps
        terminated = False

        # Render
        if self.render_mode == "human":
            self.render()
        # Return unconditionally: rendering is optional, the step result is not.
        return observation, -cost, terminated, truncated, {}

    def reset(self, seed=None, options=None) -> tuple:
        """Reset the environment to a freshly sampled state.

        :param seed: Accepted but not used by this implementation.
        :param options: Accepted but not used by this implementation.
        :return: Tuple ``(observation, info)`` where ``info`` is an empty dict.
        """
        # Draw the initial states from the environment's state space.
        states = self.state_space.sample()
        observation = self._reset(states)
        self.steps = 0

        # Render
        if self.render_mode == "human":
            self.render()
        # Return unconditionally so callers always receive the first observation.
        return observation, {}
if __name__ == "__main__":
    # Single rate value shared by the environment and the engine.
    rate = 30.0

    # Build the pendulum object with one actuator, two sensors and one state.
    pendulum = Pendulum.make(
        "pendulum",
        actuators=["u"],
        sensors=["theta", "theta_dot"],
        states=["model_state"],
    )

    # Wire the pendulum into a graph: one action in, two observations out.
    graph = eagerx.Graph.create()
    graph.connect(action="voltage", target=pendulum.actuators.u)
    graph.connect(source=pendulum.sensors.theta, observation="angle")
    graph.connect(source=pendulum.sensors.theta_dot, observation="angular_velocity")

    # Engine and backend specifications used to run the graph.
    engine = OdeEngine.make(rate=rate)
    backend = SingleProcess.make()

    # Create the environment and flatten it before handing it to the learner.
    env = Flatten(
        PendulumEnv(
            name="PendulumEnv",
            rate=rate,
            graph=graph,
            engine=engine,
            backend=backend,
        )
    )

    # Train a SAC agent for 150 * rate environment steps.
    total_timesteps = int(150 * rate)
    model = sb3.SAC("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=total_timesteps)