Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions src/nexusrpc/_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import dataclasses
import typing
import urllib.parse
from collections.abc import Mapping
from dataclasses import dataclass
from typing import (
Expand All @@ -29,6 +30,24 @@
)


def _validate_nexus_name(name: str) -> None:
"""Validate that a name is valid for use as a Nexus service or operation name.

Per the spec: "The service name and operation name MUST not be empty and
may contain any arbitrary character sequence as long as they're encoded
into the URL."

Raises:
ValueError: If the name is empty, whitespace-only, or cannot be URL-encoded.
"""
if not name or not name.strip():
raise ValueError("must not be empty")
try:
_ = urllib.parse.quote(name, safe="")
except (UnicodeEncodeError, KeyError) as e:
raise ValueError("contains characters that cannot be URL-encoded") from e


@dataclass
class Operation(Generic[InputT, OutputT]):
"""Defines a Nexus operation in a Nexus service definition.
Expand Down Expand Up @@ -67,10 +86,12 @@ class OperationDefinition(Generic[InputT, OutputT]):
def from_operation(
cls, operation: Operation[InputT, OutputT]
) -> OperationDefinition[InputT, OutputT]:
if not operation.name:
try:
_validate_nexus_name(operation.name)
except ValueError as e:
raise ValueError(
f"Operation has no name (method_name is '{operation.method_name}')"
)
f"Operation name {operation.name!r} {e} (method_name is '{operation.method_name}')"
) from e
if not operation.method_name:
raise ValueError(f"Operation '{operation.name}' has no method name")
if not operation.input_type:
Expand Down Expand Up @@ -126,9 +147,12 @@ class AnotherService:
"""

def decorator(cls: type[ServiceT]) -> type[ServiceT]:
if name is not None and not name:
raise ValueError("Service name must not be empty.")
defn = ServiceDefinition.from_class(cls, name or cls.__name__)
service_name = name if name is not None else cls.__name__
try:
_validate_nexus_name(service_name)
except ValueError as e:
raise ValueError(f"Service name {service_name!r} {e}.")
defn = ServiceDefinition.from_class(cls, service_name)
set_service_definition(cls, defn)

# In order for callers to refer to operation definitions at run-time, a decorated user
Expand Down Expand Up @@ -231,10 +255,18 @@ def from_class(user_class: type[ServiceT], name: str) -> ServiceDefinition:

def _validation_errors(self) -> list[str]:
errors = []
if not self.name:
errors.append("Service has no name")
seen_method_names = set()
# Validate service name
try:
_validate_nexus_name(self.name)
except ValueError as e:
errors.append(f"Service name {self.name!r} {e}")
# Validate operation names and check for duplicate method names
seen_method_names: set[str] = set()
for op_defn in self.operation_definitions.values():
try:
_validate_nexus_name(op_defn.name)
except ValueError as e:
errors.append(f"Operation name {op_defn.name!r} {e}")
if op_defn.method_name in seen_method_names:
errors.append(
f"Operation method name '{op_defn.method_name}' is not unique"
Expand Down
109 changes: 109 additions & 0 deletions tests/service_definition/test_service_decorator_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,112 @@ def test_operation_validation(
match=str(test_case.expected_error),
):
nexusrpc.service(test_case.Contract)


def test_empty_service_name_raises():
"""Empty string passed to @service(name='') should raise."""
with pytest.raises(ValueError, match=r"Service name '' must not be empty"):

@nexusrpc.service(name="")
class MyService: # pyright: ignore[reportUnusedClass]
op: nexusrpc.Operation[str, str]


def test_whitespace_only_service_name_raises():
"""Whitespace-only service name should raise."""
with pytest.raises(ValueError, match=r"Service name ' ' must not be empty"):

@nexusrpc.service(name=" ")
class MyService: # pyright: ignore[reportUnusedClass]
op: nexusrpc.Operation[str, str]


def test_non_url_encodable_service_name_raises():
"""Service name with non-URL-encodable characters should raise."""
with pytest.raises(
ValueError,
match=r"Service name .* contains characters that cannot be URL-encoded",
):

@nexusrpc.service(name="invalid\ud800surrogate")
class MyService: # pyright: ignore[reportUnusedClass]
op: nexusrpc.Operation[str, str]


def test_valid_service_name_with_special_chars_succeeds():
"""Service names with URL-encodable special characters should succeed."""

@nexusrpc.service(name="my service")
class ServiceWithSpace:
op: nexusrpc.Operation[str, str]

_ = ServiceWithSpace

@nexusrpc.service(name="service/with/slashes")
class ServiceWithSlashes:
op: nexusrpc.Operation[str, str]

_ = ServiceWithSlashes

@nexusrpc.service(name="日本語サービス")
class ServiceWithUnicode:
op: nexusrpc.Operation[str, str]

_ = ServiceWithUnicode

@nexusrpc.service(name="service?query=value")
class ServiceWithQueryChars:
op: nexusrpc.Operation[str, str]

_ = ServiceWithQueryChars


def test_empty_operation_name_raises():
"""Empty operation name should raise."""
with pytest.raises(ValueError, match=r"Operation name '' must not be empty"):

@nexusrpc.service
class MyService: # pyright: ignore[reportUnusedClass]
op: nexusrpc.Operation[str, str] = nexusrpc.Operation(name="")


def test_whitespace_only_operation_name_raises():
"""Whitespace-only operation name should raise."""
with pytest.raises(ValueError, match=r"Operation name ' ' must not be empty"):

@nexusrpc.service
class MyService: # pyright: ignore[reportUnusedClass]
op: nexusrpc.Operation[str, str] = nexusrpc.Operation(name=" ")


def test_non_url_encodable_operation_name_raises():
"""Operation name with non-URL-encodable characters should raise."""
with pytest.raises(
ValueError,
match=r"Operation name .* contains characters that cannot be URL-encoded",
):

@nexusrpc.service
class MyService: # pyright: ignore[reportUnusedClass]
op: nexusrpc.Operation[str, str] = nexusrpc.Operation(
name="invalid\ud800surrogate"
)


def test_valid_operation_name_with_special_chars_succeeds():
"""Operation names with URL-encodable special characters should succeed."""

@nexusrpc.service
class MyService:
op_with_space: nexusrpc.Operation[str, str] = nexusrpc.Operation(
name="my operation"
)
op_with_slash: nexusrpc.Operation[str, str] = nexusrpc.Operation(
name="op/with/slashes"
)
op_unicode: nexusrpc.Operation[str, str] = nexusrpc.Operation(name="日本語操作")
op_query: nexusrpc.Operation[str, str] = nexusrpc.Operation(
name="op?param=value"
)

_ = MyService
4 changes: 2 additions & 2 deletions tests/test_type_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _test_type_errors(

def _has_type_error_assertions(test_file: Path) -> bool:
"""Check if a file contains any type error assertions."""
with open(test_file) as f:
with open(test_file, encoding="utf-8") as f:
return any(
re.search(r"# assert-type-error-\w+:", line) for line in f.readlines()
)
Expand All @@ -93,7 +93,7 @@ def _get_expected_errors(test_file: Path, type_checker: str) -> dict[int, str]:
"""Parse expected type errors from comments in a file for the specified type checker."""
expected_errors = {}

with open(test_file) as f:
with open(test_file, encoding="utf-8") as f:
lines = zip(itertools.count(1), f)
for line_num, line in lines:
if match := re.search(
Expand Down