From fb511e9894b4fa1419ab0f2b376348a44715e26e Mon Sep 17 00:00:00 2001 From: Aryan Rahar Date: Fri, 12 Sep 2025 09:30:17 +0000 Subject: [PATCH 1/2] libero: set use_quantile_norm=True in LIBERO presets (fix #627) --- src/openpi/training/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/openpi/training/config.py b/src/openpi/training/config.py index a59cd1cf6..a691332f1 100644 --- a/src/openpi/training/config.py +++ b/src/openpi/training/config.py @@ -702,7 +702,7 @@ def __post_init__(self) -> None: model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180), data=LeRobotLiberoDataConfig( repo_id="physical-intelligence/libero", - base_config=DataConfig(prompt_from_task=True), + base_config=DataConfig(prompt_from_task=True, use_quantile_norm=True), extra_delta_transform=True, ), # Note that we load the pi0-FAST base model checkpoint here. @@ -718,7 +718,7 @@ def __post_init__(self) -> None: ), data=LeRobotLiberoDataConfig( repo_id="physical-intelligence/libero", - base_config=DataConfig(prompt_from_task=True), + base_config=DataConfig(prompt_from_task=True, use_quantile_norm=True), extra_delta_transform=True, ), weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"), From baaa8f2c497ea09611951c9b28ae5d47073be756 Mon Sep 17 00:00:00 2001 From: Aryan Rahar Date: Thu, 2 Oct 2025 20:35:49 +0000 Subject: [PATCH 2/2] fix(models): pi0_fast off-by-one; add regression test --- examples/libero/main.py | 2 +- scripts/dev_check_positions.py | 10 ++++++++++ src/openpi/models/pi0_fast.py | 2 +- src/openpi/models/pi0_fast_positions_test.py | 16 ++++++++++++++++ 4 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 scripts/dev_check_positions.py create mode 100644 src/openpi/models/pi0_fast_positions_test.py diff --git a/examples/libero/main.py b/examples/libero/main.py index dc015a617..c3fbf2417 100644 --- a/examples/libero/main.py +++ b/examples/libero/main.py 
@@ -100,7 +100,7 @@ def eval_libero(args: Args) -> None: t = 0 replay_images = [] - logging.info(f"Starting episode {task_episodes+1}...") + logging.info(f"Starting episode {task_episodes + 1}...") while t < max_steps + args.num_steps_wait: try: # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects diff --git a/scripts/dev_check_positions.py b/scripts/dev_check_positions.py new file mode 100644 index 000000000..1d081a709 --- /dev/null +++ b/scripts/dev_check_positions.py @@ -0,0 +1,10 @@ +import jax.numpy as jnp + + +def next_positions(prefill_len, step): + prefill_len = jnp.asarray(prefill_len) # shape [B] + return prefill_len[:, None] + step + 1 # pre-fix formula (off by one); pi0_fast.py now uses `+ step` + + +# toy batch with different prefix lengths +print("current:", next_positions([3, 5], 0).tolist()) # prints [[4], [6]]; the correct positions are [[3], [5]] diff --git a/src/openpi/models/pi0_fast.py b/src/openpi/models/pi0_fast.py index e6b5bd15e..1e6d7ce93 100644 --- a/src/openpi/models/pi0_fast.py +++ b/src/openpi/models/pi0_fast.py @@ -290,7 +290,7 @@ def step(carry): # Decode one step token_embedding = self.PaliGemma.llm(token, embed_only=True) - positions = prefill_len[:, None] + step + 1 + positions = prefill_len[:, None] + step mask = jnp.logical_and( jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None], jnp.arange(prefill_size + max_decoding_steps)[None, None, :] diff --git a/src/openpi/models/pi0_fast_positions_test.py b/src/openpi/models/pi0_fast_positions_test.py new file mode 100644 index 000000000..25af5274a --- /dev/null +++ b/src/openpi/models/pi0_fast_positions_test.py @@ -0,0 +1,16 @@ +# src/openpi/models/pi0_fast_positions_test.py +import jax.numpy as jnp + + +def compute_positions(prefill_len, step): + # mirrors patched logic in pi0_fast.py: + prefill_len = jnp.asarray(prefill_len) # shape [B] + return (prefill_len[:, None] + step).tolist() # first decoded token sits at position L + + +def 
test_next_token_is_contiguous_zero_indexed(): + # If prefix tokens are 0..L-1, next must be L + assert compute_positions([3, 5], 0) == [[3], [5]] + assert compute_positions([0, 1], 0) == [[0], [1]] + # advancing decode step should increment positions + assert compute_positions([3], 2) == [[5]]