Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ jobs:
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install -e .
- name: Test with pytest
env:
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
run: |
pytest brainpy/

Expand Down Expand Up @@ -77,6 +79,8 @@ jobs:
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install -e .
- name: Test with pytest
env:
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
run: |
pytest brainpy/

Expand Down Expand Up @@ -106,5 +110,7 @@ jobs:
python -m pip install -r requirements-dev.txt
pip install -e .
- name: Test with pytest
env:
MPLBACKEND: Agg # Use non-interactive backend for matplotlib
run: |
pytest brainpy/
2 changes: 2 additions & 0 deletions brainpy/math/object_transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
Details please see the following.
"""

from brainstate.transform import ProgressBar

from .autograd import *
from .base import *
from .collectors import *
Expand Down
142 changes: 122 additions & 20 deletions brainpy/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numbers
from typing import Union, Sequence, Any, Dict, Callable, Optional

import jax
import jax.numpy as jnp

import brainstate
Expand All @@ -31,6 +32,42 @@
]


def _convert_progress_bar_to_pbar(
progress_bar: Union[bool, brainstate.transform.ProgressBar, int, None]
) -> Optional[brainstate.transform.ProgressBar]:
"""Convert progress_bar parameter to brainstate pbar format.

Parameters
----------
progress_bar : bool, ProgressBar, int, None
The progress_bar parameter value.

Returns
-------
pbar : ProgressBar or None
The converted ProgressBar instance or None.

Raises
------
TypeError
If progress_bar is not a valid type.
"""
if progress_bar is False or progress_bar is None:
return None
elif progress_bar is True:
return brainstate.transform.ProgressBar()
elif isinstance(progress_bar, int):
# Support brainstate convention: int means freq parameter
return brainstate.transform.ProgressBar(freq=progress_bar)
elif isinstance(progress_bar, brainstate.transform.ProgressBar):
return progress_bar
else:
raise TypeError(
f"progress_bar must be bool, int, or ProgressBar instance, "
f"got {type(progress_bar).__name__}"
)


def cond(
pred: bool,
true_fun: Union[Callable, jnp.ndarray, Array, numbers.Number],
Expand Down Expand Up @@ -205,10 +242,8 @@ def for_loop(
operands: Any,
reverse: bool = False,
unroll: int = 1,
remat: bool = False,
jit: Optional[bool] = None,
progress_bar: bool = False,
unroll_kwargs: Optional[Dict] = None,
progress_bar: Union[bool, brainstate.transform.ProgressBar, int] = False,
):
"""``for-loop`` control flow with :py:class:`~.Variable`.

Expand Down Expand Up @@ -266,10 +301,6 @@ def for_loop(
If body function `body_func` receives multiple arguments,
`operands` should be a tuple/list whose length is equal to the
number of arguments.
remat: bool
Make ``fun`` recompute internal linearization points when differentiated.
jit: bool
Whether to just-in-time compile the function.
reverse: bool
Optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse, equivalent to reversing the leading
Expand All @@ -278,10 +309,37 @@ def for_loop(
Optional positive int specifying, in the underlying operation of the
scan primitive, how many scan iterations to unroll within a single
iteration of a loop.
progress_bar: bool
Whether we use the progress bar to report the running progress.
jit: bool
Whether to just-in-time compile the function. Set to ``False`` to disable JIT compilation.
progress_bar: bool, ProgressBar, int
Whether and how to display a progress bar during execution:

- ``False`` (default): No progress bar
- ``True``: Display progress bar with default settings
- ``ProgressBar`` instance: Display progress bar with custom settings
- ``int``: Display progress bar updating every N iterations (treated as freq parameter)

For advanced customization, create a :py:class:`brainpy.math.ProgressBar` instance:

>>> import brainpy.math as bm
>>> # Custom update frequency
>>> pbar = bm.ProgressBar(freq=10)
>>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
>>>
>>> # Custom description
>>> pbar = bm.ProgressBar(desc="Processing data")
>>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
>>>
>>> # Update exactly 20 times during execution
>>> pbar = bm.ProgressBar(count=20)
>>> result = bm.for_loop(body_fun, operands, progress_bar=pbar)
>>>
>>> # Integer shorthand (equivalent to ProgressBar(freq=10))
>>> result = bm.for_loop(body_fun, operands, progress_bar=10)

.. versionadded:: 2.4.2
.. versionchanged:: 2.7.3
Now accepts ProgressBar instances and integers for advanced customization.
dyn_vars: Variable, sequence of Variable, dict
The instances of :py:class:`~.Variable`.

Expand All @@ -296,8 +354,6 @@ def for_loop(
.. deprecated:: 2.4.0
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
unroll_kwargs: dict
The keyword arguments without unrolling.

Returns::

Expand All @@ -306,11 +362,45 @@ def for_loop(
"""
if not isinstance(operands, (tuple, list)):
operands = (operands,)
return brainstate.transform.for_loop(
warp_to_no_state_input_output(body_fun),
*operands, reverse=reverse, unroll=unroll,
pbar=brainstate.transform.ProgressBar() if progress_bar else None,
)

# Convert progress_bar to pbar format
pbar = _convert_progress_bar_to_pbar(progress_bar)

# Handle jit parameter
# Note: JAX's scan doesn't support zero-length inputs in disable_jit mode.
# For zero-length inputs, we need to use JIT mode even when jit=False.
should_disable_jit = False
if jit is False:
# Check if any operand has zero length
first_operand = operands[0]
is_zero_length = False
if hasattr(first_operand, 'shape') and len(first_operand.shape) > 0:
is_zero_length = (first_operand.shape[0] == 0)

if is_zero_length:
# Use JIT mode for zero-length inputs to avoid JAX limitation
import warnings
warnings.warn(
"for_loop with jit=False and zero-length input detected. "
"Using JIT mode to avoid JAX's disable_jit limitation with zero-length scans.",
UserWarning
)
else:
should_disable_jit = True

if should_disable_jit:
with jax.disable_jit():
return brainstate.transform.for_loop(
warp_to_no_state_input_output(body_fun),
*operands, reverse=reverse, unroll=unroll,
pbar=pbar,
)
else:
return brainstate.transform.for_loop(
warp_to_no_state_input_output(body_fun),
*operands, reverse=reverse, unroll=unroll,
pbar=pbar,
)


def scan(
Expand All @@ -320,7 +410,7 @@ def scan(
reverse: bool = False,
unroll: int = 1,
remat: bool = False,
progress_bar: bool = False,
progress_bar: Union[bool, brainstate.transform.ProgressBar, int] = False,
):
"""``scan`` control flow with :py:class:`~.Variable`.

Expand Down Expand Up @@ -359,23 +449,35 @@ def scan(
Optional positive int specifying, in the underlying operation of the
scan primitive, how many scan iterations to unroll within a single
iteration of a loop.
progress_bar: bool
Whether we use the progress bar to report the running progress.
progress_bar: bool, ProgressBar, int
Whether and how to display a progress bar during execution:

- ``False`` (default): No progress bar
- ``True``: Display progress bar with default settings
- ``ProgressBar`` instance: Display progress bar with custom settings
- ``int``: Display progress bar updating every N iterations (treated as freq parameter)

See :py:func:`for_loop` for detailed examples of ProgressBar usage.

.. versionadded:: 2.4.2
.. versionchanged:: 2.7.3
Now accepts ProgressBar instances and integers for advanced customization.

Returns::

outs: Any
The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs.
"""
# Convert progress_bar to pbar format
pbar = _convert_progress_bar_to_pbar(progress_bar)

return brainstate.transform.scan(
warp_to_no_state_input_output(body_fun),
init=init,
xs=operands,
reverse=reverse,
unroll=unroll,
pbar=brainstate.transform.ProgressBar() if progress_bar else None,
pbar=pbar,
)


Expand Down
Loading
Loading