Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,6 @@ exclude = ["temporalio/bridge/target/**/*"]
[tool.uv]
# Prevent uv commands from building the package by default
package = false

[tool.uv.sources]
nexus-rpc = { git = "https://github.com/nexus-rpc/sdk-python" }
28 changes: 15 additions & 13 deletions temporalio/bridge/proto/nexus/nexus_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 25 additions & 2 deletions temporalio/bridge/proto/nexus/nexus_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import typing
import google.protobuf.descriptor
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
import google.protobuf.timestamp_pb2

import temporalio.api.common.v1.message_pb2
import temporalio.api.failure.v1.message_pb2
Expand Down Expand Up @@ -227,6 +228,7 @@ class NexusTask(google.protobuf.message.Message):

TASK_FIELD_NUMBER: builtins.int
CANCEL_TASK_FIELD_NUMBER: builtins.int
REQUEST_DEADLINE_FIELD_NUMBER: builtins.int
@property
def task(
self,
Expand All @@ -246,23 +248,44 @@ class NexusTask(google.protobuf.message.Message):
EX: Core knows the nexus operation has timed out, and it does not make sense for the
user's operation handler to continue doing work.
"""
@property
def request_deadline(self) -> google.protobuf.timestamp_pb2.Timestamp:
"""The deadline for this request, parsed from the "Request-Timeout" header.
Only set when variant is `task` and the header was present with a valid value.
Represented as an absolute timestamp.
"""
def __init__(
self,
*,
task: temporalio.api.workflowservice.v1.request_response_pb2.PollNexusTaskQueueResponse
| None = ...,
cancel_task: global___CancelNexusTask | None = ...,
request_deadline: google.protobuf.timestamp_pb2.Timestamp | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"cancel_task", b"cancel_task", "task", b"task", "variant", b"variant"
"cancel_task",
b"cancel_task",
"request_deadline",
b"request_deadline",
"task",
b"task",
"variant",
b"variant",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"cancel_task", b"cancel_task", "task", b"task", "variant", b"variant"
"cancel_task",
b"cancel_task",
"request_deadline",
b"request_deadline",
"task",
b"task",
"variant",
b"variant",
],
) -> None: ...
def WhichOneof(
Expand Down
33 changes: 24 additions & 9 deletions temporalio/worker/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import threading
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import reduce
from typing import (
Any,
Expand Down Expand Up @@ -118,14 +119,22 @@ async def raise_from_exception_queue() -> NoReturn:

if nexus_task.HasField("task"):
task = nexus_task.task
request_deadline = (
nexus_task.request_deadline.ToDatetime().replace(
tzinfo=timezone.utc
)
if nexus_task.HasField("request_deadline")
else None
)
if task.request.HasField("start_operation"):
task_cancellation = _NexusTaskCancellation()
start_op_task = asyncio.create_task(
self._handle_start_operation_task(
task.task_token,
task.request.start_operation,
dict(task.request.header),
task_cancellation,
task_token=task.task_token,
start_request=task.request.start_operation,
headers=dict(task.request.header),
task_cancellation=task_cancellation,
request_deadline=request_deadline,
)
)
self._running_tasks[task.task_token] = _RunningNexusTask(
Expand All @@ -135,10 +144,11 @@ async def raise_from_exception_queue() -> NoReturn:
task_cancellation = _NexusTaskCancellation()
cancel_op_task = asyncio.create_task(
self._handle_cancel_operation_task(
task.task_token,
task.request.cancel_operation,
dict(task.request.header),
task_cancellation,
task_token=task.task_token,
request=task.request.cancel_operation,
headers=dict(task.request.header),
task_cancellation=task_cancellation,
request_deadline=request_deadline,
)
)
self._running_tasks[task.task_token] = _RunningNexusTask(
Expand Down Expand Up @@ -204,6 +214,7 @@ async def _handle_cancel_operation_task(
request: temporalio.api.nexus.v1.CancelOperationRequest,
headers: Mapping[str, str],
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
request_deadline: datetime | None,
) -> None:
"""Handle a cancel operation task.

Expand All @@ -216,6 +227,7 @@ async def _handle_cancel_operation_task(
operation=request.operation,
headers=headers,
task_cancellation=task_cancellation,
request_deadline=request_deadline,
)
temporalio.nexus._operation_context._TemporalCancelOperationContext(
info=lambda: Info(task_queue=self._task_queue),
Expand Down Expand Up @@ -264,6 +276,7 @@ async def _handle_start_operation_task(
start_request: temporalio.api.nexus.v1.StartOperationRequest,
headers: Mapping[str, str],
task_cancellation: nexusrpc.handler.OperationTaskCancellation,
request_deadline: datetime | None,
) -> None:
"""Handle a start operation task.

Expand All @@ -273,7 +286,7 @@ async def _handle_start_operation_task(
try:
try:
start_response = await self._start_operation(
start_request, headers, task_cancellation
start_request, headers, task_cancellation, request_deadline
)
except asyncio.CancelledError:
completion = temporalio.bridge.proto.nexus.NexusTaskCompletion(
Expand Down Expand Up @@ -315,6 +328,7 @@ async def _start_operation(
start_request: temporalio.api.nexus.v1.StartOperationRequest,
headers: Mapping[str, str],
cancellation: nexusrpc.handler.OperationTaskCancellation,
request_deadline: datetime | None,
) -> temporalio.api.nexus.v1.StartOperationResponse:
"""Invoke the Nexus handler's start_operation method and construct the StartOperationResponse.

Expand All @@ -334,6 +348,7 @@ async def _start_operation(
],
callback_headers=dict(start_request.callback_header),
task_cancellation=cancellation,
request_deadline=request_deadline,
)
temporalio.nexus._operation_context._TemporalStartOperationContext(
nexus_context=ctx,
Expand Down
2 changes: 2 additions & 0 deletions tests/helpers/nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,14 @@ async def cancel_operation(
self,
operation: str,
token: str,
headers: Mapping[str, str] = {},
) -> httpx.Response:
async with httpx.AsyncClient() as http_client:
return await http_client.post(
f"http://{self.server_address}/nexus/endpoints/{self.endpoint}/services/{self.service}/{operation}/cancel",
# Token can also be sent as "Nexus-Operation-Token" header
params={"token": token},
headers=headers,
)

@staticmethod
Expand Down
97 changes: 97 additions & 0 deletions tests/nexus/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from collections.abc import Callable, Mapping
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from types import MappingProxyType
from typing import Any

Expand Down Expand Up @@ -102,6 +103,7 @@ class MyService:
operation_error_failed: nexusrpc.Operation[Input, Output]
idempotency_check: nexusrpc.Operation[None, Output]
non_serializable_output: nexusrpc.Operation[Input, NonSerializableOutput]
check_request_deadline: nexusrpc.Operation[Input, Output]


@workflow.defn
Expand Down Expand Up @@ -277,6 +279,26 @@ async def non_serializable_output(
) -> NonSerializableOutput:
return NonSerializableOutput()

class OperationHandlerCheckingRequestDeadline(OperationHandler[Input, Output]):
async def start( # type: ignore[override]
self,
ctx: StartOperationContext,
input: Input,
) -> StartOperationResultSync[Output]:
assert ctx.request_deadline is not None, "request_deadline should be set"
# Return ISO format string so we can verify the value
return StartOperationResultSync(
Output(value=ctx.request_deadline.isoformat())
)

async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
assert ctx.request_deadline is not None, "request_deadline should be set"
return

@operation_handler
def check_request_deadline(self) -> OperationHandler[Input, Output]:
return MyServiceHandler.OperationHandlerCheckingRequestDeadline()


# Immutable dicts that can be used as dataclass field defaults

Expand Down Expand Up @@ -985,6 +1007,81 @@ async def test_request_id_is_received_by_sync_operation(
assert resp.json() == {"value": f"request_id: {request_id}"}


async def test_request_deadline_is_present_in_start_operation_context(
env: WorkflowEnvironment,
):
"""Test that request_deadline is populated from Request-Timeout header."""
if env.supports_time_skipping:
pytest.skip("Nexus tests don't work with time-skipping server")

task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
service=MyService.__name__,
)

decorator = service_handler(service=MyService)
user_service_handler = decorator(MyServiceHandler)()

async with Worker(
env.client,
task_queue=task_queue,
nexus_service_handlers=[user_service_handler],
nexus_task_executor=concurrent.futures.ThreadPoolExecutor(),
):
before = datetime.now(timezone.utc)
resp = await service_client.start_operation(
"check_request_deadline",
dataclass_as_dict(Input("test")),
headers={"Request-Timeout": "30s"},
)
after = datetime.now(timezone.utc)

assert resp.status_code == 200
deadline_str = resp.json()["value"]
deadline = datetime.fromisoformat(deadline_str)

# Deadline should be approximately 30s from request time
expected_min = before + timedelta(seconds=29)
expected_max = after + timedelta(seconds=31)
assert (
expected_min <= deadline <= expected_max
), f"Deadline {deadline} not in expected range [{expected_min}, {expected_max}]"


async def test_request_deadline_is_present_in_cancel_operation_context(
env: WorkflowEnvironment,
):
"""Test that request_deadline is populated from Request-Timeout header."""
if env.supports_time_skipping:
pytest.skip("Nexus tests don't work with time-skipping server")

task_queue = str(uuid.uuid4())
endpoint = (await create_nexus_endpoint(task_queue, env.client)).endpoint.id
service_client = ServiceClient(
server_address=ServiceClient.default_server_address(env),
endpoint=endpoint,
service=MyService.__name__,
)

decorator = service_handler(service=MyService)
user_service_handler = decorator(MyServiceHandler)()

async with Worker(
env.client,
task_queue=task_queue,
nexus_service_handlers=[user_service_handler],
nexus_task_executor=concurrent.futures.ThreadPoolExecutor(),
):
resp = await service_client.cancel_operation(
"check_request_deadline",
"test-token",
)
assert resp.status_code == 202


@workflow.defn
class EchoWorkflow:
@workflow.run
Expand Down
Loading
Loading