diff --git a/docs/source/changes.md b/docs/source/changes.md index ec245536..94cdc796 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -5,6 +5,13 @@ chronological order. Releases follow [semantic versioning](https://semver.org/) releases are available on [PyPI](https://pypi.org/project/pytask) and [Anaconda.org](https://anaconda.org/conda-forge/pytask). +## 0.5.3 - 2025-xx-xx + +- {pull}`650` allows to identify from which data catalog a node is coming from. Thanks + to {user}`felixschmitz` for the report! The feature is enabled by adding an + `attributes` field on `PNode` and `PProvisionalNode` that will be mandatory on custom + nodes in v0.6.0. + ## 0.5.2 - 2024-12-19 - {pull}`633` adds support for Python 3.13 and drops support for 3.8. diff --git a/docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py b/docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py index ff01f495..e4d00b2e 100644 --- a/docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py +++ b/docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py @@ -15,12 +15,20 @@ class PickleNode: Name of the node which makes it identifiable in the DAG. path The path to the file. + attributes + Additional attributes that are stored in the node. """ - def __init__(self, name: str = "", path: Path | None = None) -> None: + def __init__( + self, + name: str = "", + path: Path | None = None, + attributes: dict[Any, Any] | None = None, + ) -> None: self.name = name self.path = path + self.attributes = attributes or {} @property def signature(self) -> str: diff --git a/docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py b/docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py index 98281374..d6499a64 100644 --- a/docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py +++ b/docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py @@ -16,12 +16,20 @@ class PickleNode: Name of the node which makes it identifiable in the DAG. path The path to the file. + attributes + Additional attributes that are stored in the node. """ - def __init__(self, name: str = "", path: Optional[Path] = None) -> None: + def __init__( + self, + name: str = "", + path: Optional[Path] = None, + attributes: Optional[dict[Any, Any]] = None, + ) -> None: self.name = name self.path = path + self.attributes = attributes or {} @property def signature(self) -> str: diff --git a/pyproject.toml b/pyproject.toml index 3d5a4638..de9f3d05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,12 +84,6 @@ Tracker = "https://github.com/pytask-dev/pytask/issues" [project.scripts] pytask = "pytask:cli" -[tool.uv.sources] -pytask-parallel = { workspace = true } - -[tool.uv.workspace] -members = ["packages/*"] - [tool.uv] dev-dependencies = [ "tox-uv>=1.7.0", "pygraphviz;platform_system=='Linux'", diff --git a/src/_pytask/click.py b/src/_pytask/click.py index 6daab734..9b19eb93 100644 --- a/src/_pytask/click.py +++ b/src/_pytask/click.py @@ -24,6 +24,7 @@ from _pytask import __version__ as version from _pytask.console import console +from _pytask.console import create_panel_title if TYPE_CHECKING: from collections.abc import Sequence @@ -109,7 +110,7 @@ def format_help( console.print( Panel( commands_table, - title="[bold #f2f2f2]Commands[/]", + title=create_panel_title("Commands"), title_align="left", border_style="grey37", ) @@ -244,7 +245,7 @@ def _print_options(group_or_command: Command | DefaultGroup, ctx: Context) -> No console.print( Panel( options_table, - title="[bold #f2f2f2]Options[/]", + title=create_panel_title("Options"), title_align="left", border_style="grey37", ) diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index f9bc70fd..0df50442 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -33,6 +33,7 @@ from _pytask.node_protocols import PPathNode from _pytask.node_protocols import PProvisionalNode from _pytask.node_protocols import PTask +from _pytask.node_protocols import warn_about_upcoming_attributes_field_on_nodes from _pytask.nodes import DirectoryNode from _pytask.nodes import PathNode from _pytask.nodes import PythonNode @@ -385,6 +386,9 @@ def pytask_collect_node( # noqa: C901, PLR0912 """ node = node_info.value + if isinstance(node, (PNode, PProvisionalNode)) and not hasattr(node, "attributes"): + warn_about_upcoming_attributes_field_on_nodes() + if isinstance(node, DirectoryNode): if node.root_dir is None: node.root_dir = path diff --git a/src/_pytask/console.py b/src/_pytask/console.py index e46838a4..ed932d3f 100644 --- a/src/_pytask/console.py +++ b/src/_pytask/console.py @@ -24,6 +24,7 @@ from rich.theme import Theme from rich.tree import Tree +from _pytask.data_catalog_utils import DATA_CATALOG_NAME_FIELD from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode from _pytask.node_protocols import PProvisionalNode @@ -42,6 +43,7 @@ __all__ = [ "console", + "create_panel_title", "create_summary_panel", "create_url_style_for_path", "create_url_style_for_task", @@ -146,6 +148,11 @@ def format_node_name( """Format the name of a node.""" if isinstance(node, PPathNode): if node.name != node.path.as_posix(): + # For example, any node added to a data catalog has its name set to the key. + if data_catalog_name := getattr(node, "attributes", {}).get( + DATA_CATALOG_NAME_FIELD + ): + return Text(f"{data_catalog_name}::{node.name}") return Text(node.name) name = shorten_path(node.path, paths) return Text(name) @@ -156,6 +163,11 @@ def format_node_name( reduced_name = shorten_path(Path(path), paths) return Text(f"{reduced_name}::{rest}") + # Python or other custom nodes that are not PathNodes. + if data_catalog_name := getattr(node, "attributes", {}).get( + DATA_CATALOG_NAME_FIELD + ): + return Text(f"{data_catalog_name}::{node.name}") return Text(node.name) @@ -293,10 +305,15 @@ def create_summary_panel( return Panel( grid, - title="[bold #f2f2f2]Summary[/]", + title=create_panel_title("Summary"), expand=False, style="none", border_style=outcome_enum.FAIL.style if counts[outcome_enum.FAIL] else outcome_enum.SUCCESS.style, ) + + +def create_panel_title(title: str) -> Text: + """Create a title for a panel.""" + return Text(title, style="bold #f2f2f2") diff --git a/src/_pytask/data_catalog.py b/src/_pytask/data_catalog.py index 8a9a08cd..ed50a5b1 100644 --- a/src/_pytask/data_catalog.py +++ b/src/_pytask/data_catalog.py @@ -17,11 +17,13 @@ from attrs import field from _pytask.config_utils import find_project_root_and_config +from _pytask.data_catalog_utils import DATA_CATALOG_NAME_FIELD from _pytask.exceptions import NodeNotCollectedError from _pytask.models import NodeInfo from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode from _pytask.node_protocols import PProvisionalNode +from _pytask.node_protocols import warn_about_upcoming_attributes_field_on_nodes from _pytask.nodes import PickleNode from _pytask.pluginmanager import storage from _pytask.session import Session @@ -92,6 +94,10 @@ def __attrs_post_init__(self) -> None: # Initialize the data catalog with persisted nodes from previous runs. for path in self.path.glob("*-node.pkl"): node = pickle.loads(path.read_bytes()) # noqa: S301 + if not hasattr(node, "attributes"): + warn_about_upcoming_attributes_field_on_nodes() + else: + node.attributes = {DATA_CATALOG_NAME_FIELD: self.name} self._entries[node.name] = node def __getitem__(self, name: str) -> PNode | PProvisionalNode: @@ -133,3 +139,9 @@ def add(self, name: str, node: PNode | PProvisionalNode | Any = None) -> None: msg = f"{node!r} cannot be parsed." raise NodeNotCollectedError(msg) self._entries[name] = collected_node + + node = self._entries[name] + if hasattr(node, "attributes"): + node.attributes[DATA_CATALOG_NAME_FIELD] = self.name + else: + warn_about_upcoming_attributes_field_on_nodes() diff --git a/src/_pytask/data_catalog_utils.py b/src/_pytask/data_catalog_utils.py new file mode 100644 index 00000000..1dabc26d --- /dev/null +++ b/src/_pytask/data_catalog_utils.py @@ -0,0 +1,6 @@ +"""Contains utilities for the data catalog.""" + +__all__ = ["DATA_CATALOG_NAME_FIELD"] + + +DATA_CATALOG_NAME_FIELD = "catalog_name" diff --git a/src/_pytask/node_protocols.py b/src/_pytask/node_protocols.py index 6a1d8fc0..56b3ab8f 100644 --- a/src/_pytask/node_protocols.py +++ b/src/_pytask/node_protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -138,3 +139,14 @@ def load(self, is_product: bool = False) -> Any: # pragma: no cover def collect(self) -> list[Any]: """Collect the objects that are defined by the provisional nodes.""" + + +def warn_about_upcoming_attributes_field_on_nodes() -> None: + warnings.warn( + "PNode and PProvisionalNode will require an 'attributes' field starting " + "with pytask v0.6.0. It is a dictionary with any type of key and values " + "similar to PTask. See https://tinyurl.com/pytask-custom-nodes for more " + "information about adjusting your custom nodes.", + stacklevel=1, + category=FutureWarning, + ) diff --git a/src/_pytask/nodes.py b/src/_pytask/nodes.py index 3c230a34..75b9a0ee 100644 --- a/src/_pytask/nodes.py +++ b/src/_pytask/nodes.py @@ -162,11 +162,14 @@ class PathNode(PPathNode): Name of the node which makes it identifiable in the DAG. path The path to the file. + attributes: dict[Any, Any] + A dictionary to store additional information of the task. """ path: Path name: str = "" + attributes: dict[Any, Any] = field(factory=dict) @property def signature(self) -> str: @@ -219,6 +222,8 @@ class PythonNode(PNode): objects. The function should return either an integer or a string. node_info The infos acquired while collecting the node. + attributes: dict[Any, Any] + A dictionary to store additional information of the task. Examples -------- @@ -237,6 +242,7 @@ class PythonNode(PNode): value: Any | NoDefault = no_default hash: bool | Callable[[Any], int | str] = False node_info: NodeInfo | None = None + attributes: dict[Any, Any] = field(factory=dict) @property def signature(self) -> str: @@ -302,11 +308,14 @@ class PickleNode(PPathNode): Name of the node which makes it identifiable in the DAG. path The path to the file. + attributes: dict[Any, Any] + A dictionary to store additional information of the task. """ path: Path name: str = "" + attributes: dict[Any, Any] = field(factory=dict) @property def signature(self) -> str: @@ -350,12 +359,15 @@ class DirectoryNode(PProvisionalNode): root_dir The pattern is interpreted relative to the path given by ``root_dir``. If ``root_dir = None``, it is the directory where the path is defined. + attributes: dict[Any, Any] + A dictionary to store additional information of the task. """ name: str = "" pattern: str = "*" root_dir: Path | None = None + attributes: dict[Any, Any] = field(factory=dict) @property def signature(self) -> str: diff --git a/src/_pytask/warnings.py b/src/_pytask/warnings.py index 6701daee..3b2325e8 100644 --- a/src/_pytask/warnings.py +++ b/src/_pytask/warnings.py @@ -12,6 +12,7 @@ from rich.panel import Panel from _pytask.console import console +from _pytask.console import create_panel_title from _pytask.pluginmanager import hookimpl from _pytask.warnings_utils import WarningReport from _pytask.warnings_utils import catch_warnings_for_item @@ -82,7 +83,9 @@ def pytask_log_session_footer(session: Session) -> None: """Log warnings at the end of a session.""" if session.warnings: renderable = _WarningsRenderable(session.warnings) - panel = Panel(renderable, title="Warnings", style="warning") + panel = Panel( + renderable, title=create_panel_title("Warnings"), style="warning" + ) console.print(panel) diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index 0183b1be..1f6a14b9 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -517,6 +517,7 @@ def task_example( def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path): source = """ from typing import Annotated + from typing import Any from pytask import Product from pathlib import Path from attrs import define @@ -527,6 +528,7 @@ class PickleFile: name: str path: Path signature: str = "id" + attributes: dict[Any, Any] = {} def state(self): return str(self.path.stat().st_mtime) diff --git a/tests/test_node_protocols.py b/tests/test_node_protocols.py index dbf627aa..d35fbb3e 100644 --- a/tests/test_node_protocols.py +++ b/tests/test_node_protocols.py @@ -13,6 +13,7 @@ def test_node_protocol_for_custom_nodes(runner, tmp_path): source = """ from typing import Annotated + from typing import Any from pytask import Product from attrs import define from pathlib import Path @@ -22,6 +23,7 @@ class CustomNode: name: str value: str signature: str = "id" + attributes: dict[Any, Any] = {} def state(self): return self.value @@ -43,12 +45,14 @@ def task_example( result = runner.invoke(cli, [tmp_path.as_posix()]) assert result.exit_code == ExitCode.OK assert tmp_path.joinpath("out.txt").read_text() == "text" + assert "FutureWarning" not in result.output @pytest.mark.end_to_end def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path): source = """ from typing import Annotated + from typing import Any from pytask import Product from pathlib import Path from attrs import define @@ -60,6 +64,7 @@ class PickleFile: path: Path value: Path signature: str = "id" + attributes: dict[Any, Any] = {} def state(self): return str(self.path.stat().st_mtime) @@ -87,3 +92,40 @@ def task_example( result = runner.invoke(cli, [tmp_path.as_posix()]) assert result.exit_code == ExitCode.OK assert tmp_path.joinpath("out.txt").read_text() == "text" + + +@pytest.mark.end_to_end +def test_node_protocol_for_custom_nodes_adding_attributes(runner, tmp_path): + source = """ + from typing import Annotated + from pytask import Product + from attrs import define + from pathlib import Path + + @define + class CustomNode: + name: str + value: str + signature: str = "id" + + def state(self): + return self.value + + def load(self, is_product): + return self.value + + def save(self, value): + self.value = value + + def task_example( + data = CustomNode("custom", "text"), + out: Annotated[Path, Product] = Path("out.txt"), + ) -> None: + out.write_text(data) + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").read_text() == "text" + assert "FutureWarning" in result.output