Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions python/rcs/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,54 @@ def close(self):
super().close()


class MultiRobotWrapper(gym.Env):
"""Wraps a dictionary of environments to allow for multi robot control."""

def __init__(self, envs: dict[str, gym.Env] | dict[str, gym.Wrapper]):
self.envs = envs
self.unwrapped_multi = cast(dict[str, RobotEnv], {key: env.unwrapped for key, env in envs.items()})

def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
# follows gym env by combinding a dict of envs into a single env
obs = {}
reward = 0.0
terminated = False
truncated = False
info = {}
for key, env in self.envs.items():
obs[key], r, t, tr, info[key] = env.step(action[key])
reward += float(r)
terminated = terminated or t
truncated = truncated or tr
info[key]["terminated"] = t
info[key]["truncated"] = tr
return obs, reward, terminated, truncated, info

def reset(
self, seed: dict[str, int] | None = None, options: dict[str, dict[str, Any]] | None = None # type: ignore
) -> tuple[dict[str, Any], dict[str, Any]]:
obs = {}
info = {}

seed_ = seed if seed is not None else {key: None for key in self.envs} # type: ignore
options_ = options if options is not None else {key: None for key in self.envs} # type: ignore
for key, env in self.envs.items():
obs[key], info[key] = env.reset(seed=seed_[key], options=options_[key])
return obs, info

def get_wrapper_attr(self, name: str) -> Any:
"""Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object.
If lower environments have the same attribute, it returns a dictionary of the attribute values.
"""
if name in self.__dir__():
return getattr(self, name)
return {key: env.get_wrapper_attr(name) for key, env in self.envs.items()}

def close(self):
for env in self.envs.values():
env.close()


class RelativeTo(Enum):
LAST_STEP = auto()
CONFIGURED_ORIGIN = auto()
Expand Down
45 changes: 43 additions & 2 deletions python/rcs/envs/creators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ControlMode,
GripperWrapper,
HandWrapper,
MultiRobotWrapper,
RelativeActionSpace,
RelativeTo,
RobotEnv,
Expand All @@ -24,10 +25,10 @@
from rcs.envs.sim import (
CamRobot,
CollisionGuard,
FR3Sim,
GripperWrapperSim,
PickCubeSuccessWrapper,
RandomCubePos,
RobotSimWrapper,
SimWrapper,
)
from rcs.envs.space_utils import VecType
Expand Down Expand Up @@ -124,6 +125,46 @@ def __call__( # type: ignore
return env


class RCSFR3MultiEnvCreator(RCSHardwareEnvCreator):
def __call__( # type: ignore
ips: list[str],
control_mode: ControlMode,
robot_cfg: rcs.hw.FR3Config,
gripper_cfg: rcs.hw.FHConfig | None = None,
camera_set: BaseHardwareCameraSet | None = None,
max_relative_movement: float | tuple[float, float] | None = None,
relative_to: RelativeTo = RelativeTo.LAST_STEP,
urdf_path: str | PathLike | None = None,
) -> gym.Env:

urdf_path = get_urdf_path(urdf_path, allow_none_if_not_found=False)
ik = rcs.common.IK(str(urdf_path)) if urdf_path is not None else None
robots: dict[str, rcs.hw.FR3] = {}
for ip in ips:
robots[ip] = rcs.hw.FR3(ip, ik)
robots[ip].set_parameters(robot_cfg)

envs = {}
for ip in ips:
env: gym.Env = RobotEnv(robots[ip], control_mode)
env = FR3HW(env)
if gripper_cfg is not None:
gripper = rcs.hw.FrankaHand(ip, gripper_cfg)
env = GripperWrapper(env, gripper, binary=True)

if max_relative_movement is not None:
env = RelativeActionSpace(env, max_mov=max_relative_movement, relative_to=relative_to)
envs[ip] = env

env = MultiRobotWrapper(envs)
if camera_set is not None:
camera_set.start()
camera_set.wait_for_frames()
logger.info("CameraSet started")
env = CameraSetWrapper(env, camera_set)
return env


class RCSFR3DefaultEnvCreator(RCSHardwareEnvCreator):
def __call__( # type: ignore
self,
Expand Down Expand Up @@ -192,7 +233,7 @@ def __call__( # type: ignore
ik = rcs.common.IK(urdf_path)
robot = rcs.sim.SimRobot(simulation, ik, robot_cfg)
env: gym.Env = RobotEnv(robot, control_mode)
env = FR3Sim(env, simulation, sim_wrapper)
env = RobotSimWrapper(env, simulation, sim_wrapper)

if camera_set_cfg is not None:
camera_set = SimCameraSet(simulation, camera_set_cfg)
Expand Down
47 changes: 44 additions & 3 deletions python/rcs/envs/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import rcs
from rcs import sim
from rcs.envs.base import ControlMode, GripperWrapper, RobotEnv
from rcs.envs.base import ControlMode, GripperWrapper, MultiRobotWrapper, RobotEnv
from rcs.envs.space_utils import ActObsInfoWrapper, VecType
from rcs.envs.utils import default_fr3_sim_robot_cfg

Expand All @@ -25,7 +25,7 @@ def __init__(self, env: gym.Env, simulation: sim.Sim):
self.sim = simulation


class FR3Sim(gym.Wrapper):
class RobotSimWrapper(gym.Wrapper):
def __init__(self, env, simulation: sim.Sim, sim_wrapper: Type[SimWrapper] | None = None):
self.sim_wrapper = sim_wrapper
if sim_wrapper is not None:
Expand Down Expand Up @@ -58,6 +58,47 @@ def reset(
return obs, info


class MultiSimRobotWrapper(gym.Wrapper):
"""Wraps a dictionary of environments to allow for multi robot control."""

def __init__(self, env: MultiRobotWrapper, simulation: sim.Sim):
super().__init__(env)
self.env: MultiRobotWrapper
self.sim = simulation
self.sim_robots = cast(dict[str, sim.SimRobot], {key: e.robot for key, e in self.env.unwrapped_multi.items()})

def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict]:
_, _, _, _, info = super().step(action)

self.sim.step_until_convergence()
info["is_sim_converged"] = self.sim.is_converged()
for key in self.envs.envs.items():
state = self.sim_robots[key].get_state()
info[key]["collision"] = state.collision
info[key]["ik_success"] = state.ik_success

obs = {key: env.get_obs() for key, env in self.env.unwrapped_multi.items()}
truncated = np.all([info[key]["collision"] or info[key]["ik_success"] for key in info])
return obs, 0.0, False, bool(truncated), info

def reset(
self, seed: dict[str, int | None] | None = None, options: dict[str, Any] | None = None # type: ignore
) -> tuple[dict[str, Any], dict[str, Any]]:
if seed is None:
seed = {key: None for key in self.env.envs}
if options is None:
options = {key: {} for key in self.env.envs}
obs = {}
info = {}
self.sim.reset()
for key, env in self.env.envs.items():
_, info[key] = env.reset(seed=seed[key], options=options[key])
self.sim.step(1)
for key, env in self.env.unwrapped_multi.items():
obs[key] = cast(dict, env.get_obs())
return obs, info


class GripperWrapperSim(ActObsInfoWrapper):
def __init__(self, env, gripper: sim.SimGripper):
super().__init__(env)
Expand Down Expand Up @@ -178,7 +219,7 @@ def env_from_xml_paths(
else:
control_mode = env.unwrapped.get_control_mode()
c_env: gym.Env = RobotEnv(robot, control_mode)
c_env = FR3Sim(c_env, simulation)
c_env = RobotSimWrapper(c_env, simulation)
if gripper:
gripper_cfg = sim.SimGripperConfig()
gripper_cfg.add_id(id)
Expand Down