From d5fe17a64f6800d33802f8d245d9b4294f58021a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Wed, 6 Nov 2024 19:59:40 +0100 Subject: [PATCH 1/2] fix(collision guard): hard sim update and identity action - state of the real robot is setted to the sim in a hard way - support to set hard joint values in sim - if collision is detected the current robot state is set as action and not no action --- python/rcsss/_core/sim.pyi | 1 + python/rcsss/envs/sim.py | 17 ++++++----------- src/pybind/rcsss.cpp | 1 + src/sim/FR3.cpp | 8 ++++++++ src/sim/FR3.h | 1 + 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/python/rcsss/_core/sim.pyi b/python/rcsss/_core/sim.pyi index 4c758b10..1bab6b51 100644 --- a/python/rcsss/_core/sim.pyi +++ b/python/rcsss/_core/sim.pyi @@ -89,6 +89,7 @@ class FR3(rcsss._core.common.Robot): def __init__(self, sim: Sim, id: str, ik: rcsss._core.common.IK) -> None: ... def get_parameters(self) -> FR3Config: ... def get_state(self) -> FR3State: ... + def set_joints_hard(self, q: numpy.ndarray[typing.Literal[7], numpy.dtype[numpy.float64]]) -> None: ... def set_parameters(self, cfg: FR3Config) -> bool: ... class FR3Config(rcsss._core.common.RConfig): diff --git a/python/rcsss/envs/sim.py b/python/rcsss/envs/sim.py index 8293ed04..9dbebf6a 100644 --- a/python/rcsss/envs/sim.py +++ b/python/rcsss/envs/sim.py @@ -68,23 +68,18 @@ def __init__( self.sim.open_gui() def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], SupportsFloat, bool, bool, dict[str, Any]]: - # TODO: we should set the state of the sim to the state of the real robot + + self.collision_env.get_wrapper_attr("robot").set_joints_hard(self.unwrapped.robot.get_joint_position()) _, _, _, _, info = self.collision_env.step(action) + if self.to_joint_control: fr3_env = self.collision_env.unwrapped assert isinstance(fr3_env, FR3Env), "Collision env must be an FR3Env instance." action[self.unwrapped.joints_key] = fr3_env.robot.get_joint_position() - # modify action to be joint angles down stream - if info["collision"] or not info["ik_success"] or not info["is_sim_converged"]: - # return old obs, with truncated and print warning - self._logger.warning("Collision detected! Truncating episode: %s", info) - if self.last_obs is None: - msg = "Collisions detected and no old observation." - raise RuntimeError(msg) - old_obs, old_info = self.last_obs - old_info.update(info) - return old_obs, 0, False, True, old_info + if info["collision"]: + self._logger.warning("Collision detected! %s", info) + action[self.unwrapped.joints_key] = self.unwrapped.robot.get_joint_position() obs, reward, done, truncated, info = super().step(action) self.last_obs = obs, info diff --git a/src/pybind/rcsss.cpp b/src/pybind/rcsss.cpp index 2876995e..3753f738 100644 --- a/src/pybind/rcsss.cpp +++ b/src/pybind/rcsss.cpp @@ -470,6 +470,7 @@ PYBIND11_MODULE(_core, m) { py::arg("sim"), py::arg("id"), py::arg("ik")) .def("get_parameters", &rcs::sim::FR3::get_parameters) .def("set_parameters", &rcs::sim::FR3::set_parameters, py::arg("cfg")) + .def("set_joints_hard", &rcs::sim::FR3::set_joints_hard, py::arg("q")) .def("get_state", &rcs::sim::FR3::get_state); py::enum_(sim, "CameraType") .value("free", rcs::sim::CameraType::free) diff --git a/src/sim/FR3.cpp b/src/sim/FR3.cpp index 912a19aa..23524a84 100644 --- a/src/sim/FR3.cpp +++ b/src/sim/FR3.cpp @@ -208,6 +208,14 @@ void FR3::m_reset() { } } +void FR3::set_joints_hard(const common::Vector7d& q) { + for (size_t i = 0; i < std::size(this->ids.joints); ++i) { + size_t jnt_id = this->ids.joints[i]; + size_t jnt_qposadr = this->sim->m->jnt_qposadr[jnt_id]; + this->sim->d->qpos[jnt_qposadr] = q[i]; + } +} + common::Pose FR3::get_base_pose_in_world_coordinates() { auto id = mj_name2id(this->sim->m, mjOBJ_BODY, (std::string("base_") + this->id).c_str()); diff --git a/src/sim/FR3.h b/src/sim/FR3.h index 6d3a52d6..52acf9dc 100644 --- a/src/sim/FR3.h +++ b/src/sim/FR3.h @@ -50,6 +50,7 @@ class FR3 : public common::Robot { common::Pose get_base_pose_in_world_coordinates() override; std::optional> get_ik() override; void reset() override; + void set_joints_hard(const common::Vector7d &q); private: FR3Config cfg; From ebe0d0b183d840a88ed3a9e6bba2fe035d05a5dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20J=C3=BClg?= Date: Thu, 7 Nov 2024 08:33:22 +0100 Subject: [PATCH 2/2] fix(collision guard): back compatibility truncate episode For backward compatibility with tests there is a default true option which truncates the episode when a collision occurred --- python/rcsss/envs/sim.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/python/rcsss/envs/sim.py b/python/rcsss/envs/sim.py index 9dbebf6a..581dff25 100644 --- a/python/rcsss/envs/sim.py +++ b/python/rcsss/envs/sim.py @@ -49,6 +49,7 @@ def __init__( check_home_collision: bool = True, to_joint_control: bool = False, sim_gui: bool = True, + truncate_on_collision: bool = True, ): super().__init__(env) self.unwrapped: FR3Env @@ -58,6 +59,7 @@ def __init__( self._logger = logging.getLogger(__name__) self.check_home_collision = check_home_collision self.to_joint_control = to_joint_control + self.truncate_on_collision = truncate_on_collision if to_joint_control: assert ( self.unwrapped.get_unwrapped_control_mode(-2) == ControlMode.JOINTS @@ -80,6 +82,11 @@ def step(self, action: dict[str, Any]) -> tuple[dict[str, Any], SupportsFloat, b if info["collision"]: self._logger.warning("Collision detected! %s", info) action[self.unwrapped.joints_key] = self.unwrapped.robot.get_joint_position() + if self.truncate_on_collision: + if self.last_obs is None: + msg = "Collision detected in the first step!" + raise RuntimeError(msg) + return self.last_obs[0], 0, True, True, info obs, reward, done, truncated, info = super().step(action) self.last_obs = obs, info @@ -114,6 +121,7 @@ def env_from_xml_paths( tcp_offset: rcsss.common.Pose | None = None, control_mode: ControlMode | None = None, sim_gui: bool = True, + truncate_on_collision: bool = True, ) -> "CollisionGuard": assert isinstance(env.unwrapped, FR3Env) simulation = sim.Sim(mjmld) @@ -140,4 +148,12 @@ def env_from_xml_paths( gripper_cfg = sim.FHConfig() fh = sim.FrankaHand(simulation, id, gripper_cfg) c_env = GripperWrapper(c_env, fh) - return cls(env, simulation, c_env, check_home_collision, to_joint_control, sim_gui) + return cls( + env=env, + simulation=simulation, + collision_env=c_env, + check_home_collision=check_home_collision, + to_joint_control=to_joint_control, + sim_gui=sim_gui, + truncate_on_collision=truncate_on_collision, + )