diff --git a/python/rcs/envs/base.py b/python/rcs/envs/base.py index 85558a91..4b9bd8bb 100644 --- a/python/rcs/envs/base.py +++ b/python/rcs/envs/base.py @@ -280,6 +280,54 @@ def close(self): super().close() +class MultiRobotWrapper(gym.Env): + """Wraps a dictionary of environments to allow for multi robot control.""" + + def __init__(self, envs: dict[str, gym.Env] | dict[str, gym.Wrapper]): + self.envs = envs + self.unwrapped_multi = cast(dict[str, RobotEnv], {key: env.unwrapped for key, env in envs.items()}) + + def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + # follows gym env by combinding a dict of envs into a single env + obs = {} + reward = 0.0 + terminated = False + truncated = False + info = {} + for key, env in self.envs.items(): + obs[key], r, t, tr, info[key] = env.step(action[key]) + reward += float(r) + terminated = terminated or t + truncated = truncated or tr + info[key]["terminated"] = t + info[key]["truncated"] = tr + return obs, reward, terminated, truncated, info + + def reset( + self, seed: dict[str, int] | None = None, options: dict[str, dict[str, Any]] | None = None # type: ignore + ) -> tuple[dict[str, Any], dict[str, Any]]: + obs = {} + info = {} + + seed_ = seed if seed is not None else {key: None for key in self.envs} # type: ignore + options_ = options if options is not None else {key: None for key in self.envs} # type: ignore + for key, env in self.envs.items(): + obs[key], info[key] = env.reset(seed=seed_[key], options=options_[key]) + return obs, info + + def get_wrapper_attr(self, name: str) -> Any: + """Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object. + If lower environments have the same attribute, it returns a dictionary of the attribute values. + """ + if name in self.__dir__(): + return getattr(self, name) + return {key: env.get_wrapper_attr(name) for key, env in self.envs.items()} + + def close(self): + for env in self.envs.values(): + env.close() + + class RelativeTo(Enum): LAST_STEP = auto() CONFIGURED_ORIGIN = auto() diff --git a/python/rcs/envs/creators.py b/python/rcs/envs/creators.py index 38aa2784..9fba6439 100644 --- a/python/rcs/envs/creators.py +++ b/python/rcs/envs/creators.py @@ -16,6 +16,7 @@ ControlMode, GripperWrapper, HandWrapper, + MultiRobotWrapper, RelativeActionSpace, RelativeTo, RobotEnv, @@ -24,10 +25,10 @@ from rcs.envs.sim import ( CamRobot, CollisionGuard, - FR3Sim, GripperWrapperSim, PickCubeSuccessWrapper, RandomCubePos, + RobotSimWrapper, SimWrapper, ) from rcs.envs.space_utils import VecType @@ -124,6 +125,46 @@ def __call__( # type: ignore return env +class RCSFR3MultiEnvCreator(RCSHardwareEnvCreator): + def __call__( # type: ignore + ips: list[str], + control_mode: ControlMode, + robot_cfg: rcs.hw.FR3Config, + gripper_cfg: rcs.hw.FHConfig | None = None, + camera_set: BaseHardwareCameraSet | None = None, + max_relative_movement: float | tuple[float, float] | None = None, + relative_to: RelativeTo = RelativeTo.LAST_STEP, + urdf_path: str | PathLike | None = None, + ) -> gym.Env: + + urdf_path = get_urdf_path(urdf_path, allow_none_if_not_found=False) + ik = rcs.common.IK(str(urdf_path)) if urdf_path is not None else None + robots: dict[str, rcs.hw.FR3] = {} + for ip in ips: + robots[ip] = rcs.hw.FR3(ip, ik) + robots[ip].set_parameters(robot_cfg) + + envs = {} + for ip in ips: + env: gym.Env = RobotEnv(robots[ip], control_mode) + env = FR3HW(env) + if gripper_cfg is not None: + gripper = rcs.hw.FrankaHand(ip, gripper_cfg) + env = GripperWrapper(env, gripper, binary=True) + + if max_relative_movement is not None: + env = RelativeActionSpace(env, max_mov=max_relative_movement, relative_to=relative_to) + envs[ip] = env + + env = MultiRobotWrapper(envs) + if camera_set is not None: + camera_set.start() + camera_set.wait_for_frames() + logger.info("CameraSet started") + env = CameraSetWrapper(env, camera_set) + return env + + class RCSFR3DefaultEnvCreator(RCSHardwareEnvCreator): def __call__( # type: ignore self, @@ -192,7 +233,7 @@ def __call__( # type: ignore ik = rcs.common.IK(urdf_path) robot = rcs.sim.SimRobot(simulation, ik, robot_cfg) env: gym.Env = RobotEnv(robot, control_mode) - env = FR3Sim(env, simulation, sim_wrapper) + env = RobotSimWrapper(env, simulation, sim_wrapper) if camera_set_cfg is not None: camera_set = SimCameraSet(simulation, camera_set_cfg) diff --git a/python/rcs/envs/sim.py b/python/rcs/envs/sim.py index 932076b5..84413506 100644 --- a/python/rcs/envs/sim.py +++ b/python/rcs/envs/sim.py @@ -5,7 +5,7 @@ import numpy as np import rcs from rcs import sim -from rcs.envs.base import ControlMode, GripperWrapper, RobotEnv +from rcs.envs.base import ControlMode, GripperWrapper, MultiRobotWrapper, RobotEnv from rcs.envs.space_utils import ActObsInfoWrapper, VecType from rcs.envs.utils import default_fr3_sim_robot_cfg @@ -25,7 +25,7 @@ def __init__(self, env: gym.Env, simulation: sim.Sim): self.sim = simulation -class FR3Sim(gym.Wrapper): +class RobotSimWrapper(gym.Wrapper): def __init__(self, env, simulation: sim.Sim, sim_wrapper: Type[SimWrapper] | None = None): self.sim_wrapper = sim_wrapper if sim_wrapper is not None: @@ -58,6 +58,47 @@ def reset( return obs, info +class MultiSimRobotWrapper(gym.Wrapper): + """Wraps a dictionary of environments to allow for multi robot control.""" + + def __init__(self, env: MultiRobotWrapper, simulation: sim.Sim): + super().__init__(env) + self.env: MultiRobotWrapper + self.sim = simulation + self.sim_robots = cast(dict[str, sim.SimRobot], {key: e.robot for key, e in self.env.unwrapped_multi.items()}) + + def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], float, bool, bool, dict]: + _, _, _, _, info = super().step(action) + + self.sim.step_until_convergence() + info["is_sim_converged"] = self.sim.is_converged() + for key in self.envs.envs.items(): + state = self.sim_robots[key].get_state() + info[key]["collision"] = state.collision + info[key]["ik_success"] = state.ik_success + + obs = {key: env.get_obs() for key, env in self.env.unwrapped_multi.items()} + truncated = np.all([info[key]["collision"] or info[key]["ik_success"] for key in info]) + return obs, 0.0, False, bool(truncated), info + + def reset( + self, seed: dict[str, int | None] | None = None, options: dict[str, Any] | None = None # type: ignore + ) -> tuple[dict[str, Any], dict[str, Any]]: + if seed is None: + seed = {key: None for key in self.env.envs} + if options is None: + options = {key: {} for key in self.env.envs} + obs = {} + info = {} + self.sim.reset() + for key, env in self.env.envs.items(): + _, info[key] = env.reset(seed=seed[key], options=options[key]) + self.sim.step(1) + for key, env in self.env.unwrapped_multi.items(): + obs[key] = cast(dict, env.get_obs()) + return obs, info + + class GripperWrapperSim(ActObsInfoWrapper): def __init__(self, env, gripper: sim.SimGripper): super().__init__(env) @@ -178,7 +219,7 @@ def env_from_xml_paths( else: control_mode = env.unwrapped.get_control_mode() c_env: gym.Env = RobotEnv(robot, control_mode) - c_env = FR3Sim(c_env, simulation) + c_env = RobotSimWrapper(c_env, simulation) if gripper: gripper_cfg = sim.SimGripperConfig() gripper_cfg.add_id(id)