Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/libero/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def eval_libero(args: Args) -> None:
t = 0
replay_images = []

logging.info(f"Starting episode {task_episodes+1}...")
logging.info(f"Starting episode {task_episodes + 1}...")
while t < max_steps + args.num_steps_wait:
try:
# IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
Expand Down
10 changes: 10 additions & 0 deletions scripts/dev_check_positions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import jax.numpy as jnp


def next_positions(prefill_len, step):
prefill_len = jnp.asarray(prefill_len) # shape [B]
return prefill_len[:, None] + step + 1 # current behavior in repo


# toy batch with different prefix lengths
print("current:", next_positions([3, 5], 0).tolist()) # shows [[4],[6]] but should be [[3],[5]]
2 changes: 1 addition & 1 deletion src/openpi/models/pi0_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def step(carry):

# Decode one step
token_embedding = self.PaliGemma.llm(token, embed_only=True)
positions = prefill_len[:, None] + step + 1
positions = prefill_len[:, None] + step
mask = jnp.logical_and(
jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
Expand Down
16 changes: 16 additions & 0 deletions src/openpi/models/pi0_fast_positions_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# src/openpi/models/pi0_fast_positions_test.py
import jax.numpy as jnp


def compute_positions(prefill_len, step):
# mirrors patched logic in pi0_fast.py:
prefill_len = jnp.asarray(prefill_len) # shape [B]
return (prefill_len[:, None] + step).tolist() # first next token at L


def test_next_token_is_contiguous_zero_indexed():
# If prefix tokens are 0..L-1, next must be L
assert compute_positions([3, 5], 0) == [[3], [5]]
assert compute_positions([0, 1], 0) == [[0], [1]]
# advancing decode step should increment positions
assert compute_positions([3], 2) == [[5]]
4 changes: 2 additions & 2 deletions src/openpi/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def __post_init__(self) -> None:
model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(prompt_from_task=True),
base_config=DataConfig(prompt_from_task=True, use_quantile_norm=True),
extra_delta_transform=True,
),
# Note that we load the pi0-FAST base model checkpoint here.
Expand All @@ -718,7 +718,7 @@ def __post_init__(self) -> None:
),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(prompt_from_task=True),
base_config=DataConfig(prompt_from_task=True, use_quantile_norm=True),
extra_delta_transform=True,
),
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
Expand Down