diff --git a/src/nexusrpc/_service.py b/src/nexusrpc/_service.py index f62eb67..f97ac59 100644 --- a/src/nexusrpc/_service.py +++ b/src/nexusrpc/_service.py @@ -8,6 +8,7 @@ import dataclasses import typing +import urllib.parse from collections.abc import Mapping from dataclasses import dataclass from typing import ( @@ -29,6 +30,24 @@ ) +def _validate_nexus_name(name: str) -> None: + """Validate that a name is valid for use as a Nexus service or operation name. + + Per the spec: "The service name and operation name MUST not be empty and + may contain any arbitrary character sequence as long as they're encoded + into the URL." + + Raises: + ValueError: If the name is empty, whitespace-only, or cannot be URL-encoded. + """ + if not name or not name.strip(): + raise ValueError("must not be empty") + try: + _ = urllib.parse.quote(name, safe="") + except (UnicodeEncodeError, KeyError) as e: + raise ValueError("contains characters that cannot be URL-encoded") from e + + @dataclass class Operation(Generic[InputT, OutputT]): """Defines a Nexus operation in a Nexus service definition. @@ -67,10 +86,12 @@ class OperationDefinition(Generic[InputT, OutputT]): def from_operation( cls, operation: Operation[InputT, OutputT] ) -> OperationDefinition[InputT, OutputT]: - if not operation.name: + try: + _validate_nexus_name(operation.name) + except ValueError as e: raise ValueError( - f"Operation has no name (method_name is '{operation.method_name}')" - ) + f"Operation name {operation.name!r} {e} (method_name is '{operation.method_name}')" + ) from e if not operation.method_name: raise ValueError(f"Operation '{operation.name}' has no method name") if not operation.input_type: @@ -126,9 +147,12 @@ class AnotherService: """ def decorator(cls: type[ServiceT]) -> type[ServiceT]: - if name is not None and not name: - raise ValueError("Service name must not be empty.") - defn = ServiceDefinition.from_class(cls, name or cls.__name__) + service_name = name if name is not None else cls.__name__ + try: + _validate_nexus_name(service_name) + except ValueError as e: + raise ValueError(f"Service name {service_name!r} {e}.") + defn = ServiceDefinition.from_class(cls, service_name) set_service_definition(cls, defn) # In order for callers to refer to operation definitions at run-time, a decorated user @@ -231,10 +255,18 @@ def from_class(user_class: type[ServiceT], name: str) -> ServiceDefinition: def _validation_errors(self) -> list[str]: errors = [] - if not self.name: - errors.append("Service has no name") - seen_method_names = set() + # Validate service name + try: + _validate_nexus_name(self.name) + except ValueError as e: + errors.append(f"Service name {self.name!r} {e}") + # Validate operation names and check for duplicate method names + seen_method_names: set[str] = set() for op_defn in self.operation_definitions.values(): + try: + _validate_nexus_name(op_defn.name) + except ValueError as e: + errors.append(f"Operation name {op_defn.name!r} {e}") if op_defn.method_name in seen_method_names: errors.append( f"Operation method name '{op_defn.method_name}' is not unique" diff --git a/tests/service_definition/test_service_decorator_validation.py b/tests/service_definition/test_service_decorator_validation.py index 05f76a4..9809a37 100644 --- a/tests/service_definition/test_service_decorator_validation.py +++ b/tests/service_definition/test_service_decorator_validation.py @@ -42,3 +42,112 @@ def test_operation_validation( match=str(test_case.expected_error), ): nexusrpc.service(test_case.Contract) + + +def test_empty_service_name_raises(): + """Empty string passed to @service(name='') should raise.""" + with pytest.raises(ValueError, match=r"Service name '' must not be empty"): + + @nexusrpc.service(name="") + class MyService: # pyright: ignore[reportUnusedClass] + op: nexusrpc.Operation[str, str] + + +def test_whitespace_only_service_name_raises(): + """Whitespace-only service name should raise.""" + with pytest.raises(ValueError, match=r"Service name ' ' must not be empty"): + + @nexusrpc.service(name=" ") + class MyService: # pyright: ignore[reportUnusedClass] + op: nexusrpc.Operation[str, str] + + +def test_non_url_encodable_service_name_raises(): + """Service name with non-URL-encodable characters should raise.""" + with pytest.raises( + ValueError, + match=r"Service name .* contains characters that cannot be URL-encoded", + ): + + @nexusrpc.service(name="invalid\ud800surrogate") + class MyService: # pyright: ignore[reportUnusedClass] + op: nexusrpc.Operation[str, str] + + +def test_valid_service_name_with_special_chars_succeeds(): + """Service names with URL-encodable special characters should succeed.""" + + @nexusrpc.service(name="my service") + class ServiceWithSpace: + op: nexusrpc.Operation[str, str] + + _ = ServiceWithSpace + + @nexusrpc.service(name="service/with/slashes") + class ServiceWithSlashes: + op: nexusrpc.Operation[str, str] + + _ = ServiceWithSlashes + + @nexusrpc.service(name="日本語サービス") + class ServiceWithUnicode: + op: nexusrpc.Operation[str, str] + + _ = ServiceWithUnicode + + @nexusrpc.service(name="service?query=value") + class ServiceWithQueryChars: + op: nexusrpc.Operation[str, str] + + _ = ServiceWithQueryChars + + +def test_empty_operation_name_raises(): + """Empty operation name should raise.""" + with pytest.raises(ValueError, match=r"Operation name '' must not be empty"): + + @nexusrpc.service + class MyService: # pyright: ignore[reportUnusedClass] + op: nexusrpc.Operation[str, str] = nexusrpc.Operation(name="") + + +def test_whitespace_only_operation_name_raises(): + """Whitespace-only operation name should raise.""" + with pytest.raises(ValueError, match=r"Operation name ' ' must not be empty"): + + @nexusrpc.service + class MyService: # pyright: ignore[reportUnusedClass] + op: nexusrpc.Operation[str, str] = nexusrpc.Operation(name=" ") + + +def test_non_url_encodable_operation_name_raises(): + """Operation name with non-URL-encodable characters should raise.""" + with pytest.raises( + ValueError, + match=r"Operation name .* contains characters that cannot be URL-encoded", + ): + + @nexusrpc.service + class MyService: # pyright: ignore[reportUnusedClass] + op: nexusrpc.Operation[str, str] = nexusrpc.Operation( + name="invalid\ud800surrogate" + ) + + +def test_valid_operation_name_with_special_chars_succeeds(): + """Operation names with URL-encodable special characters should succeed.""" + + @nexusrpc.service + class MyService: + op_with_space: nexusrpc.Operation[str, str] = nexusrpc.Operation( + name="my operation" + ) + op_with_slash: nexusrpc.Operation[str, str] = nexusrpc.Operation( + name="op/with/slashes" + ) + op_unicode: nexusrpc.Operation[str, str] = nexusrpc.Operation(name="日本語操作") + op_query: nexusrpc.Operation[str, str] = nexusrpc.Operation( + name="op?param=value" + ) + + _ = MyService diff --git a/tests/test_type_errors.py b/tests/test_type_errors.py index 98a09a1..76c0a0a 100644 --- a/tests/test_type_errors.py +++ b/tests/test_type_errors.py @@ -83,7 +83,7 @@ def _test_type_errors( def _has_type_error_assertions(test_file: Path) -> bool: """Check if a file contains any type error assertions.""" - with open(test_file) as f: + with open(test_file, encoding="utf-8") as f: return any( re.search(r"# assert-type-error-\w+:", line) for line in f.readlines() ) @@ -93,7 +93,7 @@ def _get_expected_errors(test_file: Path, type_checker: str) -> dict[int, str]: """Parse expected type errors from comments in a file for the specified type checker.""" expected_errors = {} - with open(test_file) as f: + with open(test_file, encoding="utf-8") as f: lines = zip(itertools.count(1), f) for line_num, line in lines: if match := re.search(