diff --git a/pyproject.toml b/pyproject.toml index ccaa9d2..ccb3fb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "scyjava" version = "1.12.2.dev0" description = "Supercharged Java access from Python" license = "Unlicense" -authors = [{name = "SciJava developers", email = "ctrueden@wisc.edu"}] +authors = [{ name = "SciJava developers", email = "ctrueden@wisc.edu" }] readme = "README.md" keywords = ["java", "maven", "cross-language"] classifiers = [ @@ -35,6 +35,7 @@ dependencies = [ "jpype1 >= 1.3.0", "jgo", "cjdk", + "stubgenj", ] [dependency-groups] @@ -50,6 +51,9 @@ dev = [ "validate-pyproject[all]", ] +[project.scripts] +scyjava-stubgen = "scyjava._stubs._cli:main" + [project.urls] homepage = "https://github.com/scijava/scyjava" documentation = "https://github.com/scijava/scyjava/blob/main/README.md" @@ -58,7 +62,7 @@ download = "https://pypi.org/project/scyjava/" tracker = "https://github.com/scijava/scyjava/issues" [tool.setuptools] -package-dir = {"" = "src"} +package-dir = { "" = "src" } include-package-data = false [tool.setuptools.packages.find] diff --git a/src/scyjava/_jvm.py b/src/scyjava/_jvm.py index 224ac61..2f0e4a8 100644 --- a/src/scyjava/_jvm.py +++ b/src/scyjava/_jvm.py @@ -363,7 +363,7 @@ def is_awt_initialized() -> bool: return False Thread = scyjava.jimport("java.lang.Thread") threads = Thread.getAllStackTraces().keySet() - return any(t.getName().startsWith("AWT-") for t in threads) + return any(str(t.getName()).startswith("AWT-") for t in threads) def when_jvm_starts(f) -> None: diff --git a/src/scyjava/_stubs/__init__.py b/src/scyjava/_stubs/__init__.py new file mode 100644 index 0000000..d6a5e7c --- /dev/null +++ b/src/scyjava/_stubs/__init__.py @@ -0,0 +1,4 @@ +from ._dynamic_import import setup_java_imports +from ._genstubs import generate_stubs + +__all__ = ["setup_java_imports", "generate_stubs"] diff --git a/src/scyjava/_stubs/_cli.py b/src/scyjava/_stubs/_cli.py new file mode 100644 index 0000000..cd2670b --- /dev/null +++ b/src/scyjava/_stubs/_cli.py @@ -0,0 +1,181 @@ +"""The scyjava-stubs executable. + +Provides cli access to the `scyjava._stubs.generate_stubs` function. + +The only interesting additional things going on here is the choice of *where* the stubs +go by default. When using the CLI, they land in `scyjava.types` by default; see the +`_get_output_dir` helper function for details on how the output directory is resolved +from the CLI arguments. +""" + +from __future__ import annotations + +import argparse +import importlib +import importlib.util +import logging +import sys +from pathlib import Path + +from ._genstubs import generate_stubs + + +def main() -> None: + """The main entry point for the scyjava-stubs executable.""" + logging.basicConfig(level="INFO") + parser = argparse.ArgumentParser( + description="Generate Python Type Stubs for Java classes." + ) + parser.add_argument( + "endpoints", + type=str, + nargs="+", + help="Maven endpoints to install and use (e.g. org.myproject:myproject:1.0.0)", + ) + parser.add_argument( + "--prefix", + type=str, + help="package prefixes to generate stubs for (e.g. org.myproject), " + "may be used multiple times. If not specified, prefixes are gleaned from the " + "downloaded artifacts.", + action="append", + default=[], + metavar="PREFIX", + dest="prefix", + ) + path_group = parser.add_mutually_exclusive_group() + path_group.add_argument( + "--output-dir", + type=str, + default=None, + help="Filesystem path to write stubs to.", + ) + path_group.add_argument( + "--output-python-path", + type=str, + default=None, + help="Python path to write stubs to (e.g. 'scyjava.types').", + ) + parser.add_argument( + "--convert-strings", + dest="convert_strings", + action="store_true", + default=False, + help="convert java.lang.String to python str in return types. " + "consult the JPype documentation on the convertStrings flag for details", + ) + parser.add_argument( + "--no-javadoc", + dest="with_javadoc", + action="store_false", + default=True, + help="do not generate docstrings from JavaDoc where available", + ) + + rt_group = parser.add_mutually_exclusive_group() + rt_group.add_argument( + "--runtime-imports", + dest="runtime_imports", + action="store_true", + default=True, + help="Add runtime imports to the generated stubs. ", + ) + rt_group.add_argument( + "--no-runtime-imports", dest="runtime_imports", action="store_false" + ) + + parser.add_argument( + "--remove-namespace-only-stubs", + dest="remove_namespace_only_stubs", + action="store_true", + default=False, + help="Remove stubs that export no names beyond a single __module_protocol__. " + "This leaves some folders as PEP420 implicit namespace folders.", + ) + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + + args = parser.parse_args() + output_dir = _get_output_dir(args.output_dir, args.output_python_path) + if not output_dir.exists(): + output_dir.mkdir(parents=True, exist_ok=True) + + # Determine the Python package prefix for import rewriting + python_package_prefix = args.output_python_path or _derive_python_prefix( + args.output_dir + ) + + generate_stubs( + endpoints=args.endpoints, + prefixes=args.prefix, + output_dir=output_dir, + convert_strings=args.convert_strings, + include_javadoc=args.with_javadoc, + add_runtime_imports=args.runtime_imports, + remove_namespace_only_stubs=args.remove_namespace_only_stubs, + python_package_prefix=python_package_prefix, + ) + + +def _derive_python_prefix(output_dir: str | None) -> str: + """Derive the Python package prefix from the output directory. + + If output_dir is None, defaults to 'scyjava.types'. + """ + if output_dir: + # For a filesystem path, we can't reliably derive the Python prefix + # Return empty string to skip import rewriting + return "" + # Default case: stubs go to scyjava.types + return "scyjava.types" + + +def _get_output_dir(output_dir: str | None, python_path: str | None) -> Path: + if out_dir := output_dir: + return Path(out_dir) + if pp := python_path: + return _glean_path(pp) + try: + import scyjava + + return Path(scyjava.__file__).parent / "types" + except ImportError: + return Path("stubs") + + +def _glean_path(pp: str) -> Path: + try: + importlib.import_module(pp.split(".")[0]) + except ModuleNotFoundError: + # the top level module doesn't exist: + raise ValueError(f"Module {pp} does not exist. Cannot install stubs there.") + + try: + spec = importlib.util.find_spec(pp) + except ModuleNotFoundError as e: + # at least one of the middle levels doesn't exist: + raise NotImplementedError(f"Cannot install stubs to {pp}: {e}") + + new_ns = None + if not spec: + # if we get here, it means everything but the last level exists: + parent, new_ns = pp.rsplit(".", 1) + spec = importlib.util.find_spec(parent) + + if not spec: + # if we get here, it means the last level doesn't exist: + raise ValueError(f"Module {pp} does not exist. Cannot install stubs there.") + + search_locations = spec.submodule_search_locations + if not spec.loader and search_locations: + # namespace package with submodules + return Path(search_locations[0]) + if spec.origin: + return Path(spec.origin).parent + if new_ns and search_locations: + # namespace package with submodules + return Path(search_locations[0]) / new_ns + + raise ValueError(f"Error finding module {pp}. Cannot install stubs there.") diff --git a/src/scyjava/_stubs/_dynamic_import.py b/src/scyjava/_stubs/_dynamic_import.py new file mode 100644 index 0000000..16e27d4 --- /dev/null +++ b/src/scyjava/_stubs/_dynamic_import.py @@ -0,0 +1,131 @@ +"""Logic for using generated type stubs as runtime importable, with lazy JVM startup. + +Most often, the functionality here will be used as follows: + +``` +from scyjava._stubs import setup_java_imports + +__all__, __getattr__ = setup_java_imports( + __name__, + __file__, + endpoints=["org.scijava:parsington:3.1.0"], + base_prefix="org" +) +``` + +...and that little snippet is written into the generated stubs modules by the +`scyjava._stubs.generate_stubs` function. + +See docstring of `setup_java_imports` for details on how it works. +""" + +import ast +from logging import warning +from pathlib import Path +from typing import Any, Callable, Sequence + + +def setup_java_imports( + module_name: str, + module_file: str, + endpoints: Sequence[str] = (), + base_prefix: str = "", +) -> tuple[list[str], Callable[[str], Any]]: + """Setup a module to dynamically import Java class names. + + This function creates a `__getattr__` function that, when called, will dynamically + import the requested class from the Java namespace corresponding to the calling + module. + + :param module_name: The dotted name/identifier of the module that is calling this + function (usually `__name__` in the calling module). + :param module_file: The path to the module file (usually `__file__` in the calling + module). + :param endpoints: A list of Java endpoints to add to the scyjava configuration. + (Note that `scyjava._stubs.generate_stubs` will automatically add the necessary + endpoints for the generated stubs.) + :param base_prefix: The base prefix for the Java package name. This is used when + determining the Java class path for the requested class. The java class path + will be truncated to only the part including the base_prefix and after. This + makes it possible to embed a module in a subpackage (like `scyjava.types`) and + still have the correct Java class path. + + :return: A 2-tuple containing: + - A list of all classes in the module (as defined in the stub file), to be + assigned to `__all__`. + - A callable that takes a class name and returns a proxy for the Java class. + This callable should be assigned to `__getattr__` in the calling module. + The proxy object, when called, will start the JVM, import the Java class, + and return an instance of the class. The JVM will *only* be started when + the object is called. + + Example: + If the module calling this function is named `scyjava.types.org.scijava.parsington`, + then it should invoke this function as: + + .. code-block:: python + + from scyjava._stubs import setup_java_imports + + __all__, __getattr__ = setup_java_imports( + __name__, + __file__, + endpoints=["org.scijava:parsington:3.1.0"], + base_prefix="org" + ) + """ + import scyjava + import scyjava.config + + for ep in endpoints: + if ep not in scyjava.config.endpoints: + scyjava.config.endpoints.append(ep) + + # list intended to be assigned to `__all__` in the generated module. + module_all = [] + try: + my_stub = Path(module_file).with_suffix(".pyi") + stub_ast = ast.parse(my_stub.read_text()) + module_all = sorted( + { + node.name + for node in stub_ast.body + if isinstance(node, ast.ClassDef) and not node.name.startswith("__") + } + ) + except (OSError, SyntaxError): + warning( + f"Failed to read stub file {my_stub!r}. Falling back to empty __all__.", + stacklevel=3, + ) + + def module_getattr(name: str, mod_name: str = module_name) -> Any: + """Function intended to be assigned to __getattr__ in the generate module.""" + if module_all and name not in module_all: + raise AttributeError(f"module {module_name!r} has no attribute {name!r}") + + # cut the mod_name to only the part including the base_prefix and after + if base_prefix in mod_name: + mod_name = mod_name[mod_name.index(base_prefix) :] + + class_path = f"{mod_name}.{name}" + + # Generate a proxy type (with a nice repr) that + # delays the call to `jimport` until the last moment when type.__new__ is called + + class ProxyMeta(type): + def __repr__(self) -> str: + return f"" + + class Proxy(metaclass=ProxyMeta): + def __new__(_cls_, *args: Any, **kwargs: Any) -> Any: + cls = scyjava.jimport(class_path) + return cls(*args, **kwargs) + + Proxy.__name__ = name + Proxy.__qualname__ = name + Proxy.__module__ = module_name + Proxy.__doc__ = f"Proxy for {class_path}" + return Proxy + + return module_all, module_getattr diff --git a/src/scyjava/_stubs/_genstubs.py b/src/scyjava/_stubs/_genstubs.py new file mode 100644 index 0000000..c61d04f --- /dev/null +++ b/src/scyjava/_stubs/_genstubs.py @@ -0,0 +1,348 @@ +"""Type stub generation utilities using stubgen. + +This module provides utilities for generating type stubs for Java classes +using the stubgenj library. `stubgenj` must be installed for this to work +(it, in turn, only depends on JPype). + +See `generate_stubs` for most functionality. For the command-line tool, +see `scyjava._stubs.cli`, which provides a CLI interface for the `generate_stubs` +function. +""" + +from __future__ import annotations + +import ast +import logging +import os +import shutil +import subprocess +import sys +from importlib import import_module +from itertools import chain +from pathlib import Path, PurePath +from typing import TYPE_CHECKING, Any +from unittest.mock import patch +from zipfile import ZipFile + +import scyjava +import scyjava.config + +if TYPE_CHECKING: + from collections.abc import Sequence + +logger = logging.getLogger(__name__) + + +def generate_stubs( + endpoints: Sequence[str], + prefixes: Sequence[str] = (), + output_dir: str | Path = "stubs", + convert_strings: bool = True, + include_javadoc: bool = True, + add_runtime_imports: bool = True, + remove_namespace_only_stubs: bool = False, + python_package_prefix: str = "", +) -> None: + """Generate stubs for the given maven endpoints. + + Parameters + ---------- + endpoints : Sequence[str] + The maven endpoints to generate stubs for. This should be a list of GAV + coordinates, e.g. ["org.apache.commons:commons-lang3:3.12.0"]. + prefixes : Sequence[str], optional + The prefixes to generate stubs for. This should be a list of Java class + prefixes that you expect to find in the endpoints. For example, + ["org.apache.commons"]. If not provided, the prefixes will be + automatically determined from the jar files provided by endpoints (see the + `_list_top_level_packages` helper function). + output_dir : str | Path, optional + The directory to write the generated stubs to. Defaults to "stubs" in the + current working directory. + convert_strings : bool, optional + Whether to cast Java strings to Python strings in the stubs. Defaults to True. + NOTE: This leads to type stubs that may not be strictly accurate at runtime. + The actual runtime type of strings is determined by whether jpype.startJVM is + called with the `convertStrings` argument set to True or False. By setting + this `convert_strings` argument to true, the type stubs will be generated as if + `convertStrings` is set to True: that is, all string types will be listed as + `str` rather than `java.lang.String | str`. This is a safer default (as `str`) + is a subtype of `java.lang.String`), but may lead to type errors in some cases. + include_javadoc : bool, optional + Whether to include Javadoc in the generated stubs. Defaults to True. + add_runtime_imports : bool, optional + Whether to add runtime imports to the generated stubs. Defaults to True. + This is useful if you want to actually import the stubs as a runtime package + with type safety. The runtime import "magic" depends on the + `scyjava._stubs.setup_java_imports` function. See its documentation for + more details. + remove_namespace_only_stubs : bool, optional + Whether to remove stubs that export no names beyond a single + `__module_protocol__`. This leaves some folders as PEP420 implicit namespace + folders. Defaults to False. Setting this to `True` is useful if you want to + merge the generated stubs with other stubs in the same namespace. Without this, + the `__init__.pyi` for any given module will be whatever whatever the *last* + stub generator wrote to it (and therefore inaccurate). + python_package_prefix : str, optional + The Python package prefix under which stubs are being installed. For example, + if stubs are being installed to `scyjava.types.org.scijava...`, this should be + "scyjava.types". This is used to rewrite imports in the stub files so that + type checkers can properly resolve cross-references. Defaults to "". + """ + try: + import stubgenj + except ImportError as e: + raise ImportError( + "stubgenj is not installed, but is required to generate java stubs. " + "Please install it with `pip/conda install stubgenj`." + ) from e + + import jpype + + startJVM = jpype.startJVM + + scyjava.config.endpoints.extend(endpoints) + + def _patched_start(*args: Any, **kwargs: Any) -> None: + kwargs.setdefault("convertStrings", convert_strings) + startJVM(*args, **kwargs) + + with patch.object(jpype, "startJVM", new=_patched_start): + scyjava.start_jvm() + + _prefixes = set(prefixes) + if not _prefixes: + cp = jpype.getClassPath(env=False) + ep_artifacts = tuple(ep.split(":")[1] for ep in endpoints) + for j in cp.split(os.pathsep): + if Path(j).name.startswith(ep_artifacts): + _prefixes.update(_list_top_level_packages(j)) + + prefixes = sorted(_prefixes) + logger.info(f"Using endpoints: {scyjava.config.endpoints!r}") + logger.info(f"Generating stubs for: {prefixes}") + logger.info(f"Writing stubs to: {output_dir}") + + metapath = sys.meta_path + try: + import jpype.imports + + jmodules = [import_module(prefix) for prefix in prefixes] + finally: + # remove the jpype.imports magic from the import system + # if it wasn't there to begin with + sys.meta_path = metapath + + stubgenj.generateJavaStubs( + jmodules, + useStubsSuffix=False, + outputDir=str(output_dir), + jpypeJPackageStubs=False, + includeJavadoc=include_javadoc, + ) + + output_dir = Path(output_dir) + if python_package_prefix: + logger.info( + "Rewriting stub imports with Python package prefix: %s", + python_package_prefix, + ) + + if add_runtime_imports: + logger.info("Adding runtime imports to generated stubs") + + for stub in output_dir.rglob("*.pyi"): + # Rewrite imports if a Python package prefix was specified + if python_package_prefix: + _rewrite_stub_imports(stub, python_package_prefix) + + stub_ast = ast.parse(stub.read_text()) + members = {node.name for node in stub_ast.body if hasattr(node, "name")} + if members == {"__module_protocol__"}: + # this is simply a module stub... no exports + if remove_namespace_only_stubs: + logger.info("Removing namespace only stub %s", stub) + stub.unlink() + continue + if add_runtime_imports: + real_import = stub.with_suffix(".py") + base_prefix = stub.relative_to(output_dir).parts[0] + real_import.write_text( + INIT_TEMPLATE.format( + endpoints=repr(endpoints), + base_prefix=repr(base_prefix), + ) + ) + + ruff_check(output_dir.absolute()) + + +# the "real" init file that goes into the stub package +INIT_TEMPLATE = """\ +# this file was autogenerated by scyjava-stubgen +# it creates a __getattr__ function that will dynamically import +# the requested class from the Java namespace corresponding to this module. +# see scyjava._stubs for implementation details. +from scyjava._stubs import setup_java_imports + +__all__, __getattr__ = setup_java_imports( + __name__, + __file__, + endpoints={endpoints}, + base_prefix={base_prefix}, +) +""" + + +def _rewrite_stub_imports(stub_path: Path, python_package_prefix: str) -> None: + """Rewrite imports in a stub file to use the full Python package path. + + When stubs are generated into a subdirectory like scyjava/types, they need to have + their imports rewritten so that type checkers can resolve cross-references. This + function transforms imports like: + + import org.scijava.object + + into: + + import scyjava.types.org.scijava.object + + and transforms type references like: + + org.scijava.object.ObjectIndex + + into: + + scyjava.types.org.scijava.object.ObjectIndex + """ + import re + + content = stub_path.read_text() + + # Split into lines for import processing + lines = content.split("\n") + new_lines = [] + import_patterns = [] # Patterns to replace in annotations + + i = 0 + while i < len(lines): + line = lines[i] + stripped = line.strip() + + # Handle import statements + if stripped.startswith("import ") and ( + "org." in stripped or "java." in stripped + ): + # Parse "import org.scijava.service" + match = stripped.split() + if len(match) >= 2: + module_name = match[1] + if ( + not module_name.startswith(".") + and "scyjava.types" not in module_name + ): + # Only rewrite org.* imports (not java.*) + if module_name.startswith("org."): + new_module = f"{python_package_prefix}.{module_name}" + # Preserve indentation + indent = line[: len(line) - len(line.lstrip())] + new_lines.append(f"{indent}import {new_module}") + # Record this pattern for later annotation rewriting + import_patterns.append((module_name, new_module)) + i += 1 + continue + + # Handle "from X import Y" statements + elif stripped.startswith("from ") and (" import " in stripped): + if "org." in stripped or "java." in stripped: + parts = stripped.split(" import ") + if len(parts) == 2: + module_part = parts[0].replace("from ", "").strip() + imports_part = parts[1].strip() + + if ( + not module_part.startswith(".") + and "scyjava.types" not in module_part + ): + if module_part.startswith("org."): + new_module = f"{python_package_prefix}.{module_part}" + indent = line[: len(line) - len(line.lstrip())] + new_lines.append( + f"{indent}from {new_module} import {imports_part}" + ) + import_patterns.append((module_part, new_module)) + i += 1 + continue + + new_lines.append(line) + i += 1 + + # Reconstruct content with rewritten imports + new_content = "\n".join(new_lines) + + # Now rewrite type annotations that reference org.* packages + # Only replace in type hints, not in already-rewritten import statements + # We do this by replacing the old module names with new ones, but being careful + # to not double-replace + for old_prefix, new_prefix in import_patterns: + # Replace qualified names like "org.scijava.service.ServiceIndex" + # but NOT names that already have the prefix + # Pattern: old_prefix followed by a dot and word characters + # Use negative lookbehind to avoid replacing if already prefixed + pattern = ( + r"(? None: + """Run ruff check and format on the generated stubs.""" + if not shutil.which("ruff"): + return + + py_files = [str(x) for x in chain(output.rglob("*.py"), output.rglob("*.pyi"))] + logger.info( + "Running ruff check on %d generated stubs in %s", + len(py_files), + str(output), + ) + subprocess.run( + [ + "ruff", + "check", + *py_files, + "--quiet", + "--fix-only", + "--unsafe-fixes", + f"--select={select}", + ] + ) + logger.info("Running ruff format") + subprocess.run(["ruff", "format", *py_files, "--quiet"]) + + +def _list_top_level_packages(jar_path: str) -> set[str]: + """Inspect a JAR file and return the set of top-level Java package names.""" + packages: set[str] = set() + with ZipFile(jar_path, "r") as jar: + # find all classes + class_dirs = { + entry.parent + for x in jar.namelist() + if (entry := PurePath(x)).suffix == ".class" + } + + roots: set[PurePath] = set() + for p in sorted(class_dirs, key=lambda p: len(p.parts)): + # If none of the already accepted roots is a parent of p, keep p + if not any(root in p.parents for root in roots): + roots.add(p) + packages.update({str(p).replace(os.sep, ".") for p in roots}) + + return packages diff --git a/src/scyjava/types/.gitignore b/src/scyjava/types/.gitignore new file mode 100644 index 0000000..5e7d273 --- /dev/null +++ b/src/scyjava/types/.gitignore @@ -0,0 +1,4 @@ +# Ignore everything in this directory +* +# Except this file +!.gitignore diff --git a/tests/test_stubgen.py b/tests/test_stubgen.py new file mode 100644 index 0000000..3d005b1 --- /dev/null +++ b/tests/test_stubgen.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import ast +import sys +from pathlib import Path +from unittest.mock import patch + +import jpype +import pytest + +import scyjava +from scyjava._stubs import _cli + + +@pytest.mark.skipif( + scyjava.config.mode != scyjava.config.Mode.JPYPE, + reason="Stubgen not supported in JEP", +) +def test_stubgen(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + # run the stubgen command as if it was run from the command line + monkeypatch.setattr( + sys, + "argv", + [ + "scyjava-stubgen", + "org.scijava:parsington:3.1.0", + "--output-dir", + str(tmp_path), + ], + ) + _cli.main() + + # remove the `jpype.imports` magic from the import system if present + mp = [x for x in sys.meta_path if not isinstance(x, jpype.imports._JImportLoader)] + monkeypatch.setattr(sys, "meta_path", mp) + + # add tmp_path to the import path + monkeypatch.setattr(sys, "path", [str(tmp_path)]) + + # first cleanup to make sure we are not importing from the cache + sys.modules.pop("org", None) + sys.modules.pop("org.scijava", None) + sys.modules.pop("org.scijava.parsington", None) + # make sure the stubgen command works and that we can now impmort stuff + + with patch.object(scyjava._jvm, "start_jvm") as mock_start_jvm: + from org.scijava.parsington import Function + + assert Function is not None + # ensure that no calls to start_jvm were made + mock_start_jvm.assert_not_called() + + # only after instantiating the class should we have a call to start_jvm + func = Function(1) + mock_start_jvm.assert_called_once() + assert isinstance(func, jpype.JObject) + + +@pytest.mark.skipif( + scyjava.config.mode != scyjava.config.Mode.JPYPE, + reason="Stubgen not supported in JEP", +) +def test_stubgen_type_references( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """Test that generated stubs have properly qualified type references. + + This validates that when stubs are generated with a Python package prefix, + all type references are properly rewritten so type checkers can resolve them. + """ + import tempfile + + # Generate stubs with --output-python-path so a prefix is used + # (rather than --output-dir which doesn't imply a Python module path) + stubs_module = "test_stubs" + + # Create a temporary directory and add it to sys.path + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + original_path = sys.path.copy() + + try: + sys.path.insert(0, str(tmpdir)) + + # Create the parent module package + stubs_pkg = tmpdir / stubs_module + stubs_pkg.mkdir() + (stubs_pkg / "__init__.py").touch() + + monkeypatch.setattr( + sys, + "argv", + [ + "scyjava-stubgen", + "org.scijava:parsington:3.1.0", + "--output-python-path", + stubs_module, + ], + ) + _cli.main() + + # Check that import statements were rewritten with the prefix + init_stub = stubs_pkg / "org" / "scijava" / "parsington" / "__init__.pyi" + assert init_stub.exists(), f"Expected stub file {init_stub} not found" + + content = init_stub.read_text() + stub_ast = ast.parse(content) + + # Find all Import and ImportFrom nodes + imports = [ + node + for node in ast.walk(stub_ast) + if isinstance(node, (ast.Import, ast.ImportFrom)) + ] + + # Collect imported module names + imported_modules = set() + for imp in imports: + if isinstance(imp, ast.Import): + for alias in imp.names: + imported_modules.add(alias.name) + elif isinstance(imp, ast.ImportFrom) and imp.module: + imported_modules.add(imp.module) + + # Verify that bare org.scijava.* imports don't exist (they should be prefixed) + org_imports = { + m for m in imported_modules if m and m.startswith("org.scijava.") + } + assert not org_imports, ( + f"Found unrewritten org.scijava imports in {init_stub}: {org_imports}. " + f"These should have been prefixed with '{stubs_module}.'" + ) + finally: + sys.path = original_path