diff --git a/pyproject.toml b/pyproject.toml index 86c03811..bcbdca46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,8 +60,8 @@ docs = [ "sphinx-design>=0.3", "sphinx-toolbox>=4.0.0", "sphinxext-opengraph>=0.10.0", - "sphinx-autobuild>=2024.10.3", ] +docs-live = ["sphinx-autobuild>=2024.10.3"] plugin-list = ["httpx>=0.27.0", "tabulate[widechars]>=0.9.0", "tqdm>=4.66.3"] test = [ "cloudpickle>=3.0.0", @@ -74,7 +74,7 @@ test = [ "pytest-cov>=5.0.0", "pytest-xdist>=3.6.1", "syrupy>=4.5.0", - "aiohttp>=3.11.0", # For HTTPPath tests. + "aiohttp>=3.11.0", # For HTTPPath tests. "coiled>=1.42.0", "pygraphviz>=1.12;platform_system=='Linux'", ] @@ -173,6 +173,16 @@ filterwarnings = [ [tool.ty.rules] unused-ignore-comment = "error" +[[tool.ty.overrides]] +include = [ + "src/_pytask/_version.py", + "src/_pytask/click.py", + "tests/test_dag_command.py", +] + +[tool.ty.overrides.rules] +unused-ignore-comment = "ignore" + [tool.ty.src] exclude = ["src/_pytask/_hashlib.py"] diff --git a/src/_pytask/_inspect.py b/src/_pytask/_inspect.py index 98a94702..bc469e6f 100644 --- a/src/_pytask/_inspect.py +++ b/src/_pytask/_inspect.py @@ -1,6 +1,140 @@ from __future__ import annotations +import ast +import inspect +import sys +from inspect import get_annotations as _get_annotations_from_inspect +from typing import TYPE_CHECKING +from typing import Any +from typing import cast + +if TYPE_CHECKING: + from collections.abc import Callable + __all__ = ["get_annotations"] -from inspect import get_annotations +def get_annotations( + obj: Callable[..., Any], + *, + globals: dict[str, Any] | None = None, # noqa: A002 + locals: dict[str, Any] | None = None, # noqa: A002 + eval_str: bool = False, +) -> dict[str, Any]: + """Return evaluated annotations with better support for deferred evaluation. + + Context + ------- + * PEP 649 introduces deferred annotations which are only evaluated when explicitly + requested. 
See https://peps.python.org/pep-0649/ for background and why locals can + disappear between definition and evaluation time. + * Python 3.14 ships :mod:`annotationlib` which exposes the raw annotation source and + provides the building blocks we reuse here. The module doc explains the available + formats: https://docs.python.org/3/library/annotationlib.html + * Other projects run into the same constraints. Pydantic tracks their work in + https://github.com/pydantic/pydantic/issues/12080; we might copy improvements from + there once they settle on a stable strategy. + + Rationale + --------- + When annotations refer to loop variables inside task generators, the locals that + existed during decoration have vanished by the time pytask evaluates annotations + while collecting tasks. Using :func:`inspect.get_annotations` would therefore yield + the same product path for every repeated task. By asking :mod:`annotationlib` for + string representations and re-evaluating them with reconstructed locals (globals, + default arguments, and the frame locals captured via ``@task`` at decoration time) + we recover the correct per-task values. The frame locals capture is essential for + cases where loop variables are only referenced in annotations (not in the function + body or closure). If any of these ingredients are missing—for example on Python + versions without :mod:`annotationlib` - we fall back to the stdlib implementation, + so behaviour on 3.10-3.13 remains unchanged. 
+ """ + if not eval_str or not hasattr(obj, "__globals__"): + return _get_annotations_from_inspect( + obj, globals=globals, locals=locals, eval_str=eval_str + ) + + if sys.version_info < (3, 14): + raw_annotations = _get_annotations_from_inspect( + obj, globals=globals, locals=locals, eval_str=False + ) + evaluation_globals = cast( + "dict[str, Any]", obj.__globals__ if globals is None else globals + ) + evaluation_locals = evaluation_globals if locals is None else locals + evaluated_annotations = {} + for name, expression in raw_annotations.items(): + evaluated_annotations[name] = _evaluate_annotation_expression( + expression, evaluation_globals, evaluation_locals + ) + return evaluated_annotations + + import annotationlib # noqa: PLC0415 + + raw_annotations = annotationlib.get_annotations( + obj, globals=globals, locals=locals, format=annotationlib.Format.STRING + ) + + evaluation_globals = obj.__globals__ if globals is None else globals + evaluation_locals = _build_evaluation_locals(obj, locals) + + evaluated_annotations = {} + for name, expression in raw_annotations.items(): + evaluated_annotations[name] = _evaluate_annotation_expression( + expression, evaluation_globals, evaluation_locals + ) + + return evaluated_annotations + + +def _build_evaluation_locals( + obj: Callable[..., Any], provided_locals: dict[str, Any] | None +) -> dict[str, Any]: + # Order matters: later updates override earlier ones. + # Default arguments are lowest priority (fallbacks), then provided_locals, + # then snapshot_locals (captured loop variables) have highest priority. 
+ evaluation_locals: dict[str, Any] = {} + evaluation_locals.update(_get_default_argument_locals(obj)) + if provided_locals: + evaluation_locals.update(provided_locals) + evaluation_locals.update(_get_snapshot_locals(obj)) + return evaluation_locals + + +def _get_snapshot_locals(obj: Callable[..., Any]) -> dict[str, Any]: + metadata = getattr(obj, "pytask_meta", None) + snapshot = getattr(metadata, "annotation_locals", None) + return dict(snapshot) if snapshot else {} + + +def _get_default_argument_locals(obj: Callable[..., Any]) -> dict[str, Any]: + try: + parameters = inspect.signature(obj).parameters.values() + except (TypeError, ValueError): + return {} + + defaults = {} + for parameter in parameters: + if parameter.default is not inspect.Parameter.empty: + defaults[parameter.name] = parameter.default + return defaults + + +def _evaluate_annotation_expression( + expression: Any, globals_: dict[str, Any] | None, locals_: dict[str, Any] +) -> Any: + if not isinstance(expression, str): + return expression + evaluation_globals = globals_ if globals_ is not None else {} + evaluated = eval(expression, evaluation_globals, locals_) # noqa: S307 + if isinstance(evaluated, str): + try: + literal = ast.literal_eval(expression) + except (SyntaxError, ValueError): + return evaluated + if isinstance(literal, str): + try: + return eval(literal, evaluation_globals, locals_) # noqa: S307 + except Exception: # noqa: BLE001 + return evaluated + return evaluated diff --git a/src/_pytask/click.py b/src/_pytask/click.py index 36f2fa04..741ce650 100644 --- a/src/_pytask/click.py +++ b/src/_pytask/click.py @@ -35,9 +35,11 @@ if importlib.metadata.version("click") < "8.2": - from click.parser import split_opt + from click.parser import split_opt as _split_opt else: - from click.parser import _split_opt as split_opt # ty: ignore[unresolved-import] + from click.parser import _split_opt # ty: ignore[unresolved-import] + +split_opt = _split_opt class EnumChoice(Choice): diff --git 
a/src/_pytask/models.py b/src/_pytask/models.py index 7511f3e9..3b12d442 100644 --- a/src/_pytask/models.py +++ b/src/_pytask/models.py @@ -38,6 +38,9 @@ class CollectionMetadata: kwargs A dictionary containing keyword arguments which are passed to the task when it is executed. + annotation_locals + A snapshot of local variables captured during decoration which helps evaluate + deferred annotations later on. markers A list of markers that are attached to the task. name @@ -51,6 +54,7 @@ class CollectionMetadata: after: str | list[Callable[..., Any]] = field(factory=list) attributes: dict[str, Any] = field(factory=dict) + annotation_locals: dict[str, Any] | None = None is_generator: bool = False id_: str | None = None kwargs: dict[str, Any] = field(factory=dict) diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index 28c5b10c..dd92b521 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -1,14 +1,17 @@ """Contains utilities related to the :func:`@task `.""" from __future__ import annotations +import __future__ import functools import inspect +import sys from collections import defaultdict from types import BuiltinFunctionType from typing import TYPE_CHECKING from typing import Any from typing import TypeVar +from typing import cast import attrs @@ -79,30 +82,18 @@ def task( # noqa: PLR0913 information. is_generator An indicator whether this task is a task generator. - id - An id for the task if it is part of a parametrization. Otherwise, an automatic - id will be generated. See - :doc:`this tutorial <../tutorials/repeating_tasks_with_different_inputs>` for - more information. - kwargs - A dictionary containing keyword arguments which are passed to the task when it - is executed. - produces - Definition of products to parse the function returns and store them. See - :doc:`this how-to guide <../how_to_guides/using_task_returns>` for more id An id for the task if it is part of a repetition. 
Otherwise, an automatic id will be generated. See :ref:`how-to-repeat-a-task-with-different-inputs-the-id` for more information. kwargs - Use a dictionary to pass any keyword arguments to the task function which can be - dependencies or products of the task. Read :ref:`task-kwargs` for more - information. - produces - Use this argument if you want to parse the return of the task function as a - product, but you cannot annotate the return of the function. See :doc:`this - how-to guide <../how_to_guides/using_task_returns>` or :ref:`task-produces` for + A dictionary containing keyword arguments which are passed to the task function. + These can be dependencies or products of the task. Read :ref:`task-kwargs` for more information. + produces + Use this argument to parse the return of the task function as a product. See + :doc:`this how-to guide <../how_to_guides/using_task_returns>` or + :ref:`task-produces` for more information. Examples -------- @@ -117,12 +108,23 @@ def create_text_file() -> Annotated[str, Path("file.txt")]: return "Hello, World!" """ + # Capture the caller's frame locals for deferred annotation evaluation in Python + # 3.14+. If ``from __future__ import annotations`` is active, keep the pre-3.14 + # behavior by evaluating annotations against current globals instead of snapshots. + caller_frame = sys._getframe(1) + has_future_annotations = bool( + caller_frame.f_code.co_flags & __future__.annotations.compiler_flag + ) + caller_locals = None if has_future_annotations else caller_frame.f_locals.copy() def wrapper(func: T) -> TaskDecorated[T]: # Omits frame when a builtin function is wrapped. _rich_traceback_omit = True - for arg, arg_name in ((name, "name"), (id, "id")): + # When @task is used without parentheses, name is the function, not a string. 
+ effective_name = None if is_task_function(name) else name + + for arg, arg_name in ((effective_name, "name"), (id, "id")): if not (isinstance(arg, str) or arg is None): msg = ( f"Argument {arg_name!r} of @task must be a str, but it is {arg!r}." @@ -149,7 +151,7 @@ def wrapper(func: T) -> TaskDecorated[T]: path = get_file(unwrapped) parsed_kwargs = {} if kwargs is None else kwargs - parsed_name = _parse_name(unwrapped, name) + parsed_name = _parse_name(unwrapped, effective_name) parsed_after = _parse_after(after) if isinstance(unwrapped, TaskFunction): @@ -160,10 +162,11 @@ def wrapper(func: T) -> TaskDecorated[T]: unwrapped.pytask_meta.markers.append(Mark("task", (), {})) unwrapped.pytask_meta.name = parsed_name unwrapped.pytask_meta.produces = produces - unwrapped.pytask_meta.after = parsed_after + unwrapped.pytask_meta.annotation_locals = caller_locals else: unwrapped.pytask_meta = CollectionMetadata( # type: ignore[attr-defined] after=parsed_after, + annotation_locals=caller_locals, is_generator=is_generator, id_=id, kwargs=parsed_kwargs, @@ -181,10 +184,9 @@ def wrapper(func: T) -> TaskDecorated[T]: return unwrapped - # In case the decorator is used without parentheses, wrap the function which is - # passed as the first argument with the default arguments. + # When decorator is used without parentheses, call wrapper directly. 
def test_lazy_annotations_capture_loop_locals(tmp_path):
    """Each task created in a loop writes to the path of its own iteration."""
    module_source = """
    from pathlib import Path
    from typing import Annotated
    from pytask import task

    for i in range(2):
        path = Path(f"out-{i}.txt")

        @task
        def task_example() -> Annotated[str, path]:
            return "Hello"
    """
    task_file = tmp_path / "task_module.py"
    task_file.write_text(textwrap.dedent(module_source))

    result = build(paths=tmp_path)

    assert result.exit_code == ExitCode.OK
    assert (tmp_path / "out-0.txt").exists()
    assert (tmp_path / "out-1.txt").exists()


def test_lazy_annotations_use_current_globals(tmp_path):
    """With postponed annotations, the value of a global at collection time wins."""
    module_source = """
    from __future__ import annotations

    from pathlib import Path
    from typing import Annotated

    OUTPUT = Path("first.txt")

    def task_example() -> Annotated[str, OUTPUT]:
        return "Hello"

    OUTPUT = Path("second.txt")
    """
    task_file = tmp_path / "task_module.py"
    task_file.write_text(textwrap.dedent(module_source))

    result = build(paths=tmp_path)

    assert result.exit_code == ExitCode.OK
    assert (tmp_path / "second.txt").exists()
    assert not (tmp_path / "first.txt").exists()


def test_string_literal_annotations_are_resolved(tmp_path):
    """An explicitly quoted return annotation is still evaluated to a product."""
    module_source = """
    from __future__ import annotations

    from pathlib import Path
    from typing import Annotated

    OUTPUT = Path("out.txt")

    def task_example() -> 'Annotated[str, OUTPUT]':
        return "Hello"
    """
    task_file = tmp_path / "task_module.py"
    task_file.write_text(textwrap.dedent(module_source))

    result = build(paths=tmp_path)

    assert result.exit_code == ExitCode.OK
    assert (tmp_path / "out.txt").exists()
"httpx" }, { name = "tabulate", extra = ["widechars"] }, @@ -2735,13 +2737,13 @@ docs = [ { name = "myst-nb", specifier = ">=1.2.0" }, { name = "myst-parser", specifier = ">=3.0.0" }, { name = "sphinx", specifier = ">=7.0.0" }, - { name = "sphinx-autobuild", specifier = ">=2024.10.3" }, { name = "sphinx-click", specifier = ">=6.0.0" }, { name = "sphinx-copybutton", specifier = ">=0.5.2" }, { name = "sphinx-design", specifier = ">=0.3" }, { name = "sphinx-toolbox", specifier = ">=4.0.0" }, { name = "sphinxext-opengraph", specifier = ">=0.10.0" }, ] +docs-live = [{ name = "sphinx-autobuild", specifier = ">=2024.10.3" }] plugin-list = [ { name = "httpx", specifier = ">=0.27.0" }, { name = "tabulate", extras = ["widechars"], specifier = ">=0.9.0" },