diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index eeefbff8c..c2a0ed780 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -13,12 +13,19 @@ StatelessMCPServerProvider, ) from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters +from temporalio.contrib.openai_agents._otel_tracing import ( + setup_tracing, + workflow_span, +) from temporalio.contrib.openai_agents._temporal_openai_agents import ( OpenAIAgentsPlugin, OpenAIPayloadConverter, ) from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError +# Re-export OtelTracingPlugin from its new location for backward compatibility +from temporalio.contrib.opentelemetry import OtelTracingPlugin + from . import testing, workflow __all__ = [ @@ -26,8 +33,11 @@ "ModelActivityParameters", "OpenAIAgentsPlugin", "OpenAIPayloadConverter", + "OtelTracingPlugin", + "setup_tracing", "StatelessMCPServerProvider", "StatefulMCPServerProvider", "testing", "workflow", + "workflow_span", ] diff --git a/temporalio/contrib/openai_agents/_otel_tracing.py b/temporalio/contrib/openai_agents/_otel_tracing.py new file mode 100644 index 000000000..f97ba5bc4 --- /dev/null +++ b/temporalio/contrib/openai_agents/_otel_tracing.py @@ -0,0 +1,73 @@ +"""OpenTelemetry tracing support for OpenAI Agents in Temporal workflows.""" + +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Iterator + +from opentelemetry import trace as otel_trace +from opentelemetry.trace import Span + +from temporalio import workflow + +if TYPE_CHECKING: + from opentelemetry.sdk.trace import TracerProvider + + +def setup_tracing(tracer_provider: TracerProvider) -> None: + """Set up OpenAI Agents OTEL tracing with OpenInference instrumentation. + + This instruments the OpenAI Agents SDK with OpenInference, which converts + agent spans to OTEL spans. Combined with opentelemetry passthrough in the + sandbox, this enables proper span parenting inside Temporal's workflows. + + Args: + tracer_provider: The TracerProvider to use for creating spans. + """ + from openinference.instrumentation.openai_agents import OpenAIAgentsInstrumentor + + otel_trace.set_tracer_provider(tracer_provider) + OpenAIAgentsInstrumentor().instrument(tracer_provider=tracer_provider) + + +@contextmanager +def workflow_span(name: str, **attributes: str) -> Iterator[Span | None]: + """Create an OTEL span in workflow code that is replay-safe. + + This context manager creates a span only on the first execution of the + workflow task, not during replay. This prevents span duplication when + workflow code is re-executed during replay (e.g., when max_cached_workflows=0). + + .. warning:: + This API is experimental and may change in future versions. + Consider using ReplayFilteringSpanProcessor instead for automatic + filtering of all spans during replay. + + Args: + name: The name of the span. + **attributes: Optional attributes to set on the span. + + Yields: + The span on first execution, None during replay. + + Example: + >>> @workflow.defn + ... class MyWorkflow: + ... @workflow.run + ... async def run(self) -> str: + ... with workflow_span("my_operation", query="test") as span: + ... result = await workflow.execute_activity(...) + ... return result + + Note: + Spans created in activities do not need this wrapper since activities + are not replayed - they only execute once and their results are cached. 
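        A sketch of the activity-side counterpart (names are illustrative and
        assume ``from temporalio import activity`` and
        ``from opentelemetry import trace``); no replay guard is needed there
        because activities run exactly once:

        >>> @activity.defn
        ... async def my_activity(query: str) -> str:
        ...     tracer = trace.get_tracer(__name__)
        ...     with tracer.start_as_current_span("activity_operation") as span:
        ...         span.set_attribute("query", query)
        ...         return "done"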
+ """ + if workflow.unsafe.is_replaying(): + yield None + else: + tracer = otel_trace.get_tracer(__name__) + with tracer.start_as_current_span(name) as span: + for key, value in attributes.items(): + span.set_attribute(key, value) + yield span diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index c1ace7a55..5fcda58c1 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -181,6 +181,7 @@ def __init__( "StatelessMCPServerProvider | StatefulMCPServerProvider" ] = (), register_activities: bool = True, + create_spans: bool = True, ) -> None: """Initialize the OpenAI agents plugin. @@ -196,7 +197,11 @@ def __init__( register_activities: Whether to register activities during the worker execution. This can be disabled on some workers to allow a separation of workflows and activities but should not be disabled on all workers, or agents will not be able to progress. + create_spans: Whether to create ``temporal:*`` spans for workflow and activity + operations. If False, trace context is still propagated but no spans are + created. Defaults to True. """ + self._create_spans = create_spans if model_params is None: model_params = ModelActivityParameters() @@ -252,7 +257,7 @@ async def run_context() -> AsyncIterator[None]: super().__init__( name="OpenAIAgentsPlugin", data_converter=_data_converter, - worker_interceptors=[OpenAIAgentsTracingInterceptor()], + worker_interceptors=[OpenAIAgentsTracingInterceptor(create_spans=self._create_spans)], activities=add_activities, workflow_runner=workflow_runner, workflow_failure_exception_types=[AgentsWorkflowError], diff --git a/temporalio/contrib/openai_agents/_trace_interceptor.py b/temporalio/contrib/openai_agents/_trace_interceptor.py index d099ae09b..6c139f7b4 100644 --- a/temporalio/contrib/openai_agents/_trace_interceptor.py +++ b/temporalio/contrib/openai_agents/_trace_interceptor.py @@ -56,8 +56,16 @@ def context_from_header( span_name: str, input: _InputWithHeaders, payload_converter: temporalio.converter.PayloadConverter, + create_spans: bool = True, ): - """Extracts and initializes trace information the input header.""" + """Extracts and initializes trace information the input header. + + Args: + span_name: Name for the span to create (if create_spans is True). + input: The input containing headers with trace context. + payload_converter: Converter for deserializing the header payload. + create_spans: Whether to create a span. If False, only sets up context. + """ payload = input.headers.get(HEADER_KEY) span_info = payload_converter.from_payload(payload) if payload else None if span_info is None: @@ -100,7 +108,10 @@ def context_from_header( ) Scope.set_current_span(current_span) - with custom_span(name=span_name, parent=current_span, data=data): + if create_spans: + with custom_span(name=span_name, parent=current_span, data=data): + yield + else: yield @@ -131,15 +142,20 @@ class OpenAIAgentsTracingInterceptor( def __init__( self, payload_converter: temporalio.converter.PayloadConverter = temporalio.converter.default().payload_converter, + create_spans: bool = True, ) -> None: """Initialize the interceptor with a payload converter. Args: payload_converter: The payload converter to use for serializing/deserializing trace context. Defaults to the default Temporal payload converter. + create_spans: Whether to create spans for Temporal operations. 
If False, + trace context is still propagated but no ``temporal:*`` spans are created. + Defaults to True. """ super().__init__() self._payload_converter = payload_converter + self._create_spans = create_spans def intercept_client( self, next: temporalio.client.OutboundInterceptor @@ -153,7 +169,7 @@ def intercept_client( An interceptor that propagates trace context for client operations. """ return _ContextPropagationClientOutboundInterceptor( - next, self._payload_converter + next, self._payload_converter, self._create_spans ) def intercept_activity( @@ -167,7 +183,7 @@ def intercept_activity( Returns: An interceptor that propagates trace context for activity operations. """ - return _ContextPropagationActivityInboundInterceptor(next) + return _ContextPropagationActivityInboundInterceptor(next, self._create_spans) def workflow_interceptor_class( self, input: temporalio.worker.WorkflowInterceptorClassInput @@ -180,7 +196,13 @@ def workflow_interceptor_class( Returns: The class of the workflow interceptor that propagates trace context. """ - return _ContextPropagationWorkflowInboundInterceptor + # Capture create_spans in closure for use by the workflow interceptor class + create_spans = self._create_spans + + class _BoundWorkflowInboundInterceptor(_ContextPropagationWorkflowInboundInterceptor): + _create_spans = create_spans + + return _BoundWorkflowInboundInterceptor class _ContextPropagationClientOutboundInterceptor( @@ -190,13 +212,19 @@ def __init__( self, next: temporalio.client.OutboundInterceptor, payload_converter: temporalio.converter.PayloadConverter, + create_spans: bool = True, ) -> None: super().__init__(next) self._payload_converter = payload_converter + self._create_spans = create_spans async def start_workflow( self, input: temporalio.client.StartWorkflowInput ) -> temporalio.client.WorkflowHandle[Any, Any]: + if not self._create_spans: + set_header_from_context(input, self._payload_converter) + return await super().start_workflow(input) + metadata = { "temporal:workflowType": input.workflow, **({"temporal:workflowId": input.id} if input.id else {}), @@ -216,6 +244,10 @@ async def start_workflow( return await super().start_workflow(input) async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> Any: + if not self._create_spans: + set_header_from_context(input, self._payload_converter) + return await super().query_workflow(input) + metadata = { "temporal:queryWorkflow": input.query, **({"temporal:workflowId": input.id} if input.id else {}), @@ -235,6 +267,11 @@ async def query_workflow(self, input: temporalio.client.QueryWorkflowInput) -> A async def signal_workflow( self, input: temporalio.client.SignalWorkflowInput ) -> None: + if not self._create_spans: + set_header_from_context(input, self._payload_converter) + await super().signal_workflow(input) + return + metadata = { "temporal:signalWorkflow": input.signal, **({"temporal:workflowId": input.id} if input.id else {}), @@ -254,6 +291,10 @@ async def signal_workflow( async def start_workflow_update( self, input: temporalio.client.StartWorkflowUpdateInput ) -> temporalio.client.WorkflowUpdateHandle[Any]: + if not self._create_spans: + set_header_from_context(input, self._payload_converter) + return await self.next.start_workflow_update(input) + metadata = { "temporal:updateWorkflow": input.update, **({"temporal:workflowId": input.id} if input.id else {}), @@ -277,11 +318,22 @@ async def start_workflow_update( class _ContextPropagationActivityInboundInterceptor( 
temporalio.worker.ActivityInboundInterceptor ): + def __init__( + self, + next: temporalio.worker.ActivityInboundInterceptor, + create_spans: bool = True, + ) -> None: + super().__init__(next) + self._create_spans = create_spans + async def execute_activity( self, input: temporalio.worker.ExecuteActivityInput ) -> Any: with context_from_header( - "temporal:executeActivity", input, temporalio.activity.payload_converter() + "temporal:executeActivity", + input, + temporalio.activity.payload_converter(), + self._create_spans, ): return await self.next.execute_activity(input) @@ -318,29 +370,43 @@ def _ensure_tracing_random() -> None: class _ContextPropagationWorkflowInboundInterceptor( temporalio.worker.WorkflowInboundInterceptor ): + # Set by factory in workflow_interceptor_class() + _create_spans: bool = True + def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None: - self.next.init(_ContextPropagationWorkflowOutboundInterceptor(outbound)) + self.next.init( + _ContextPropagationWorkflowOutboundInterceptor(outbound, self._create_spans) + ) async def execute_workflow( self, input: temporalio.worker.ExecuteWorkflowInput ) -> Any: _ensure_tracing_random() with context_from_header( - "temporal:executeWorkflow", input, temporalio.workflow.payload_converter() + "temporal:executeWorkflow", + input, + temporalio.workflow.payload_converter(), + self._create_spans, ): return await self.next.execute_workflow(input) async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None: _ensure_tracing_random() with context_from_header( - "temporal:handleSignal", input, temporalio.workflow.payload_converter() + "temporal:handleSignal", + input, + temporalio.workflow.payload_converter(), + self._create_spans, ): return await self.next.handle_signal(input) async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any: _ensure_tracing_random() with context_from_header( - "temporal:handleQuery", input, temporalio.workflow.payload_converter() + "temporal:handleQuery", + input, + temporalio.workflow.payload_converter(), + self._create_spans, ): return await self.next.handle_query(input) @@ -351,6 +417,7 @@ def handle_update_validator( "temporal:handleUpdateValidator", input, temporalio.workflow.payload_converter(), + self._create_spans, ): self.next.handle_update_validator(input) @@ -362,6 +429,7 @@ async def handle_update_handler( "temporal:handleUpdateHandler", input, temporalio.workflow.payload_converter(), + self._create_spans, ): return await self.next.handle_update_handler(input) @@ -369,11 +437,19 @@ async def handle_update_handler( class _ContextPropagationWorkflowOutboundInterceptor( temporalio.worker.WorkflowOutboundInterceptor ): + def __init__( + self, + next: temporalio.worker.WorkflowOutboundInterceptor, + create_spans: bool = True, + ) -> None: + super().__init__(next) + self._create_spans = create_spans + async def signal_child_workflow( self, input: temporalio.worker.SignalChildWorkflowInput ) -> None: trace = get_trace_provider().get_current_trace() - if trace: + if trace and self._create_spans: with custom_span( name="temporal:signalChildWorkflow", data={"workflowId": input.child_workflow_id}, @@ -388,7 +464,7 @@ async def signal_external_workflow( self, input: temporalio.worker.SignalExternalWorkflowInput ) -> None: trace = get_trace_provider().get_current_trace() - if trace: + if trace and self._create_spans: with custom_span( name="temporal:signalExternalWorkflow", data={"workflowId": input.workflow_id}, @@ -404,7 +480,7 @@ def 
start_activity( ) -> temporalio.workflow.ActivityHandle: trace = get_trace_provider().get_current_trace() span: Span | None = None - if trace: + if trace and self._create_spans: span = custom_span( name="temporal:startActivity", data={"activity": input.activity} ) @@ -421,7 +497,7 @@ async def start_child_workflow( ) -> temporalio.workflow.ChildWorkflowHandle: trace = get_trace_provider().get_current_trace() span: Span | None = None - if trace: + if trace and self._create_spans: span = custom_span( name="temporal:startChildWorkflow", data={"workflow": input.workflow} ) @@ -437,7 +513,7 @@ def start_local_activity( ) -> temporalio.workflow.ActivityHandle: trace = get_trace_provider().get_current_trace() span: Span | None = None - if trace: + if trace and self._create_spans: span = custom_span( name="temporal:startLocalActivity", data={"activity": input.activity} ) diff --git a/temporalio/contrib/opentelemetry/__init__.py b/temporalio/contrib/opentelemetry/__init__.py new file mode 100644 index 000000000..33b7028b0 --- /dev/null +++ b/temporalio/contrib/opentelemetry/__init__.py @@ -0,0 +1,86 @@ +"""OpenTelemetry integration for Temporal. + +This module provides OpenTelemetry tracing support for Temporal workflows +and activities. It includes: + +- :py:class:`TracingInterceptor`: Interceptor for creating and propagating + OpenTelemetry spans across client, worker, and workflow boundaries. + +- :py:class:`OtelTracingPlugin`: A plugin that configures TracingInterceptor + with context-only propagation (no spans) and automatically configures + sandbox passthrough for opentelemetry. + +Basic usage with TracingInterceptor (creates Temporal spans): + + from temporalio.contrib.opentelemetry import TracingInterceptor + + client = await Client.connect( + "localhost:7233", + interceptors=[TracingInterceptor()], + ) + +Usage with OtelTracingPlugin (context propagation only, for use with other +instrumentation like OpenInference): + + from temporalio.contrib.opentelemetry import OtelTracingPlugin + + plugin = OtelTracingPlugin(tracer_provider=tracer_provider) + + client = await Client.connect( + "localhost:7233", + plugins=[plugin], + ) + + # Sandbox passthrough is configured automatically by the plugin + worker = Worker( + client, + task_queue="my-queue", + workflows=[MyWorkflow], + ) + +The plugin automatically: +- Configures opentelemetry as a sandbox passthrough module +- Wraps tracer provider span processors with replay filtering + +Design Notes - Cross-Process Trace Propagation: + + OpenTelemetry spans are process-local by design. The OTEL specification + states: "Spans are not meant to be used to propagate information within + a process." Only SpanContext (trace_id, span_id, trace_flags) crosses + process boundaries via propagators. + + This is intentional - Span objects contain mutable state, thread locks, + and processor references that cannot be serialized. SpanContext is an + immutable tuple designed for cross-process propagation. + + For Temporal workflows that may execute across multiple workers: + + - TracingInterceptor serializes SpanContext (not Span) into headers + - Remote workers deserialize SpanContext and wrap it in NonRecordingSpan + - Child spans created in remote workers link to the parent via span_id + - No span is ever "opened" in one process and "closed" in another + + With create_spans=True, workflow spans are created and immediately ended + (same start/end timestamp) to avoid cross-process lifecycle issues. 
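    A minimal sketch of the re-hydration step described above (hand-rolled for
    illustration only; the interceptor performs this internally via the W3C
    TraceContext propagator rather than by hand):

        from opentelemetry import trace
        from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags

        def parent_context_from_ids(trace_id: int, span_id: int):
            # Rebuild the immutable SpanContext received from another process
            # and wrap it in a NonRecordingSpan so spans created locally are
            # parented to the remote span without reopening it.
            remote = SpanContext(
                trace_id=trace_id,
                span_id=span_id,
                is_remote=True,
                trace_flags=TraceFlags(TraceFlags.SAMPLED),
            )
            return trace.set_span_in_context(NonRecordingSpan(remote))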
+ + With create_spans=False (OtelTracingPlugin default), no Temporal spans + are created - only context is propagated for other instrumentation. +""" + +from ._id_generator import TemporalIdGenerator +from ._otel_tracing_plugin import OtelTracingPlugin +from ._tracing_interceptor import ( + TracingInterceptor, + TracingWorkflowInboundInterceptor, + default_text_map_propagator, + workflow, +) + +__all__ = [ + "OtelTracingPlugin", + "TemporalIdGenerator", + "TracingInterceptor", + "TracingWorkflowInboundInterceptor", + "default_text_map_propagator", + "workflow", +] diff --git a/temporalio/contrib/opentelemetry/_id_generator.py b/temporalio/contrib/opentelemetry/_id_generator.py new file mode 100644 index 000000000..f748a4b60 --- /dev/null +++ b/temporalio/contrib/opentelemetry/_id_generator.py @@ -0,0 +1,80 @@ +"""Deterministic ID generator for Temporal workflows.""" + +from __future__ import annotations + +from opentelemetry.sdk.trace.id_generator import IdGenerator, RandomIdGenerator +from opentelemetry.trace import INVALID_SPAN_ID, INVALID_TRACE_ID + + +class TemporalIdGenerator(IdGenerator): + """IdGenerator that produces deterministic IDs in workflow context. + + Uses workflow.random() which is seeded deterministically by Temporal and + replays identically. Outside workflow context (activities, client code), + falls back to standard random generation. + + This enables real-duration spans in workflows: + - First execution: Span with ID X created, exported with real duration + - Replay: Same ID X generated (deterministic), span filtered by + ReplayFilteringSpanProcessor + - Parent-child relationships remain stable across executions + + Usage: + from opentelemetry.sdk.trace import TracerProvider + from temporalio.contrib.opentelemetry import TemporalIdGenerator + + tracer_provider = TracerProvider(id_generator=TemporalIdGenerator()) + + Or via OtelTracingPlugin: + plugin = OtelTracingPlugin( + tracer_provider=tracer_provider, + deterministic_ids=True, + ) + """ + + def __init__(self) -> None: + """Initialize the ID generator with a fallback for non-workflow contexts.""" + self._fallback = RandomIdGenerator() + + def generate_span_id(self) -> int: + """Generate a span ID. + + In workflow context, uses deterministic RNG. Otherwise, uses random. + """ + if self._in_workflow_context(): + from temporalio import workflow + + span_id = workflow.random().getrandbits(64) + while span_id == INVALID_SPAN_ID: + span_id = workflow.random().getrandbits(64) + return span_id + return self._fallback.generate_span_id() + + def generate_trace_id(self) -> int: + """Generate a trace ID. + + In workflow context, uses deterministic RNG. Otherwise, uses random. 
+ """ + if self._in_workflow_context(): + from temporalio import workflow + + trace_id = workflow.random().getrandbits(128) + while trace_id == INVALID_TRACE_ID: + trace_id = workflow.random().getrandbits(128) + return trace_id + return self._fallback.generate_trace_id() + + def _in_workflow_context(self) -> bool: + """Check if we're in workflow context where random() is available.""" + from temporalio import workflow + + try: + # workflow.in_workflow() returns True if in workflow context + if not workflow.in_workflow(): + return False + # Also check we're not in a read-only context (query handler) + # by actually calling random() - it raises ReadOnlyContextError if not allowed + workflow.random() + return True + except Exception: + return False diff --git a/temporalio/contrib/opentelemetry/_otel_tracing_plugin.py b/temporalio/contrib/opentelemetry/_otel_tracing_plugin.py new file mode 100644 index 000000000..751f891eb --- /dev/null +++ b/temporalio/contrib/opentelemetry/_otel_tracing_plugin.py @@ -0,0 +1,189 @@ +"""OpenTelemetry tracing plugin for Temporal workflows.""" + +from __future__ import annotations + +import dataclasses +import logging +from typing import TYPE_CHECKING + +from temporalio.plugin import SimplePlugin +from temporalio.worker import WorkflowRunner + +from ._replay_filtering_processor import ReplayFilteringSpanProcessor +from ._tracing_interceptor import TracingInterceptor + +if TYPE_CHECKING: + from opentelemetry.sdk.trace import TracerProvider + + from temporalio.worker.workflow_sandbox import SandboxRestrictions + +_logger = logging.getLogger(__name__) + + +class OtelTracingPlugin(SimplePlugin): + """Plugin for clean OTEL tracing in Temporal workflows. + + This plugin provides: + 1. Context propagation via TracingInterceptor (with create_spans=False) + 2. Automatic sandbox passthrough configuration for opentelemetry module + 3. Optional replay filtering for span processors + + The plugin uses TracingInterceptor with create_spans=False, which means + it propagates trace context through Temporal headers without creating + its own spans. This allows you to use your own instrumentation + (like OpenInference) while still getting proper context propagation. + + Why create_spans=False? + + OpenTelemetry spans cannot cross process boundaries - only SpanContext + can be propagated. Temporal workflows may execute across multiple workers + (different processes/machines), so we propagate context only and let + your instrumentation (e.g., OpenInference) create spans locally. 
+ + Trace continuity is maintained via parent-child relationships: + + - Client creates a span, its SpanContext is propagated via headers + - Worker receives SpanContext, wraps it in NonRecordingSpan + - Your instrumentation creates child spans with the correct parent + - Backend correlates spans by trace_id and parent span_id + + Usage: + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from temporalio.contrib.opentelemetry import OtelTracingPlugin + + tracer_provider = TracerProvider() + tracer_provider.add_span_processor( + SimpleSpanProcessor(OTLPSpanExporter()) + ) + + plugin = OtelTracingPlugin(tracer_provider=tracer_provider) + + client = await Client.connect( + "localhost:7233", + plugins=[plugin], + ) + + # Sandbox passthrough is configured automatically by the plugin + worker = Worker( + client, + task_queue="my-queue", + workflows=[MyWorkflow], + ) + """ + + def __init__( + self, + tracer_provider: TracerProvider | None = None, + filter_replay_spans: bool = True, + deterministic_ids: bool = False, + ) -> None: + """Initialize the OTEL tracing plugin. + + Args: + tracer_provider: Optional tracer provider to configure. If provided, + replay filtering and/or deterministic IDs will be configured + based on the other parameters. + filter_replay_spans: If True and tracer_provider is provided, + wrap span processors to filter out spans created during replay. + Defaults to True. + deterministic_ids: If True and tracer_provider is provided, + configure the tracer provider to use deterministic span ID + generation in workflow context. This enables real-duration + spans in workflows by ensuring the same span IDs are generated + on replay (which are then filtered by ReplayFilteringSpanProcessor). + Defaults to False. + """ + if tracer_provider: + if deterministic_ids: + self._configure_deterministic_ids(tracer_provider) + if filter_replay_spans: + self._wrap_with_replay_filter(tracer_provider) + + interceptor = TracingInterceptor(create_spans=False) + + def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner | None: + """Auto-configure sandbox to passthrough opentelemetry.""" + from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + + if runner is None: + return None + if isinstance(runner, SandboxedWorkflowRunner): + return dataclasses.replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + "opentelemetry" + ), + ) + return runner + + super().__init__( + name="OtelTracingPlugin", + worker_interceptors=[interceptor], + client_interceptors=[interceptor], + workflow_runner=workflow_runner, + ) + + @property + def sandbox_restrictions(self) -> SandboxRestrictions: + """Return sandbox restrictions with opentelemetry passthrough. + + This property returns a SandboxRestrictions object that has opentelemetry + added to the passthrough modules. This is necessary for OTEL context + propagation to work correctly inside workflow sandboxes. + + Without this, the opentelemetry module is re-imported inside the sandbox, + creating a separate ContextVar instance that cannot see context attached + by the TracingInterceptor. 
+ + Usage: + plugin = OtelTracingPlugin() + worker = Worker( + client, + workflows=[...], + workflow_runner=SandboxedWorkflowRunner( + restrictions=plugin.sandbox_restrictions + ), + ) + """ + from temporalio.worker.workflow_sandbox import SandboxRestrictions + + return SandboxRestrictions.default.with_passthrough_modules("opentelemetry") + + def _wrap_with_replay_filter(self, tracer_provider: TracerProvider) -> None: + """Wrap tracer provider's span processors with replay filtering. + + This modifies the tracer provider in place to wrap each span processor + with ReplayFilteringSpanProcessor. + """ + # Access the internal span processors + # Note: This uses internal APIs which may change + if hasattr(tracer_provider, "_active_span_processor"): + processor = tracer_provider._active_span_processor + # The multi span processor has a list of processors + if hasattr(processor, "_span_processors"): + wrapped = [] + for p in processor._span_processors: + wrapped.append(ReplayFilteringSpanProcessor(p)) + processor._span_processors = tuple(wrapped) + + def _configure_deterministic_ids(self, tracer_provider: TracerProvider) -> None: + """Configure tracer provider for deterministic span ID generation. + + This modifies the tracer provider in place to use TemporalIdGenerator, + which produces deterministic span/trace IDs when running in workflow + context using workflow.random(). + + Args: + tracer_provider: The tracer provider to configure. + """ + from ._id_generator import TemporalIdGenerator + + if hasattr(tracer_provider, "id_generator"): + tracer_provider.id_generator = TemporalIdGenerator() + else: + _logger.warning( + "Could not configure deterministic span IDs: " + "TracerProvider does not have id_generator attribute. " + "Span IDs will be random, which may cause issues during replay." + ) diff --git a/temporalio/contrib/opentelemetry/_replay_filtering_processor.py b/temporalio/contrib/opentelemetry/_replay_filtering_processor.py new file mode 100644 index 000000000..c18668d4f --- /dev/null +++ b/temporalio/contrib/opentelemetry/_replay_filtering_processor.py @@ -0,0 +1,73 @@ +"""Span processor that filters out spans created during workflow replay.""" + +from __future__ import annotations + +from opentelemetry.context import Context +from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor + + +class ReplayFilteringSpanProcessor(SpanProcessor): + """Wraps a SpanProcessor to filter out spans created during workflow replay. + + During Temporal workflow replay, workflow code re-executes but activities + return cached results. Without filtering, this causes duplicate spans. + This processor marks spans created during replay and drops them on export. + + Usage: + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from temporalio.contrib.opentelemetry import ReplayFilteringSpanProcessor + + tracer_provider = TracerProvider() + tracer_provider.add_span_processor( + ReplayFilteringSpanProcessor( + SimpleSpanProcessor(OTLPSpanExporter()) + ) + ) + """ + + # Attribute name used to mark spans created during replay + _REPLAY_MARKER = "_temporal_replay" + + def __init__(self, delegate: SpanProcessor) -> None: + """Initialize the replay filtering processor. + + Args: + delegate: The underlying span processor to delegate to for + non-replay spans. + """ + self._delegate = delegate + + def on_start(self, span: Span, parent_context: Context | None = None) -> None: + """Called when a span is started. 
+ + Checks if we're currently replaying and marks the span if so. + """ + try: + from temporalio import workflow + + if workflow.unsafe.is_replaying(): + # Mark this span as created during replay + setattr(span, self._REPLAY_MARKER, True) + except Exception: + # Not in workflow context, or workflow module not available + # This is fine - just means we're not in a workflow + pass + + self._delegate.on_start(span, parent_context) + + def on_end(self, span: ReadableSpan) -> None: + """Called when a span is ended. + + Drops spans that were marked as created during replay. + """ + if not getattr(span, self._REPLAY_MARKER, False): + self._delegate.on_end(span) + + def shutdown(self) -> None: + """Shuts down the processor.""" + self._delegate.shutdown() + + def force_flush(self, timeout_millis: int = 30000) -> bool: + """Forces a flush of all pending spans.""" + return self._delegate.force_flush(timeout_millis) diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry/_tracing_interceptor.py similarity index 93% rename from temporalio/contrib/opentelemetry.py rename to temporalio/contrib/opentelemetry/_tracing_interceptor.py index ef1e52bb2..fb3618992 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry/_tracing_interceptor.py @@ -80,6 +80,7 @@ def __init__( # type: ignore[reportMissingSuperCall] tracer: opentelemetry.trace.Tracer | None = None, *, always_create_workflow_spans: bool = False, + create_spans: bool = True, ) -> None: """Initialize a OpenTelemetry tracing interceptor. @@ -94,6 +95,12 @@ def __init__( # type: ignore[reportMissingSuperCall] create spans in workflows no matter what, but there is a risk of them being orphans since they may not have a parent span after replaying. + create_spans: When true, the default, spans are created for Temporal + operations (StartWorkflow, RunActivity, etc.). When false, only + context propagation is performed without creating any spans. + This is useful when you want to use Temporal's robust W3C + TraceContext propagation but have another instrumentation + library (like OpenInference) create the actual spans. """ self.tracer = tracer or opentelemetry.trace.get_tracer(__name__) # To customize any of this, users must subclass. We intentionally don't @@ -105,6 +112,7 @@ def __init__( # type: ignore[reportMissingSuperCall] # TODO(cretz): Should I be using the configured one at the client and activity level? 
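        # Context-only wiring sketch for the ``create_spans`` flag documented
        # above (hypothetical example for illustration):
        #
        #     interceptor = TracingInterceptor(create_spans=False)
        #     client = await Client.connect(
        #         "localhost:7233", interceptors=[interceptor]
        #     )
        #
        # Workers created from this client reuse client interceptors that also
        # implement the worker interceptor interface, so trace headers are read
        # on the worker side without any ``temporal:*`` spans being created.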
self.payload_converter = temporalio.converter.PayloadConverter.default self._always_create_workflow_spans = always_create_workflow_spans + self._create_spans = create_spans def intercept_client( self, next: temporalio.client.OutboundInterceptor @@ -182,6 +190,31 @@ def _start_as_current_span( kind: opentelemetry.trace.SpanKind, context: Context | None = None, ) -> Iterator[None]: + # If create_spans is False, only propagate context without creating spans + if not self._create_spans: + # Attach incoming context if provided (for activities/inbound) + token = opentelemetry.context.attach(context) if context else None + try: + # Still propagate context via headers (for outbound) + if input_with_headers: + input_with_headers.headers = self._context_to_headers( + input_with_headers.headers + ) + if input_with_ctx: + carrier: _CarrierDict = {} + self.text_map_propagator.inject(carrier) + input_with_ctx.ctx = dataclasses.replace( + input_with_ctx.ctx, + headers=_carrier_to_nexus_headers( + carrier, input_with_ctx.ctx.headers + ), + ) + yield None + finally: + if token and context is opentelemetry.context.get_current(): + opentelemetry.context.detach(token) + return + token = opentelemetry.context.attach(context) if context else None try: with self.tracer.start_as_current_span( @@ -228,6 +261,10 @@ def _completed_workflow_span( # Carrier to context, start span, set span as current on context, # context back to carrier + # If create_spans is False, just return the existing context without creating spans + if not self._create_spans: + return params.context + # If the parent is missing and user hasn't said to always create, do not # create if params.parent_missing and not self._always_create_workflow_spans: @@ -244,9 +281,17 @@ def _completed_workflow_span( if link_span is not opentelemetry.trace.INVALID_SPAN: links = [opentelemetry.trace.Link(link_span.get_span_context())] - # We start and end the span immediately because it is not replay-safe to - # keep an unended long-running span. We set the end time the same as the - # start time to make it clear it has no duration. + # OpenTelemetry Design: Spans are process-local, only SpanContext crosses + # process boundaries. Temporal workflows may execute across multiple workers, + # so we cannot keep a long-running span open. + # + # Solution: Create and immediately end workflow spans with the same timestamp. + # This provides: + # 1. A span_id for child operations to reference as parent + # 2. Attributes (workflow type, ID) recorded in the trace + # 3. Replay safety - no state survives across workflow tasks + # + # The span appears as a zero-duration marker with children beneath it. span = self.tracer.start_span( params.name, context, diff --git a/tests/contrib/conftest.py b/tests/contrib/conftest.py new file mode 100644 index 000000000..52fd71e1d --- /dev/null +++ b/tests/contrib/conftest.py @@ -0,0 +1,49 @@ +"""Shared fixtures for contrib tests.""" + +import pytest +import opentelemetry.trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + + +# Global provider - only initialized once per process to avoid conflicts +_provider: TracerProvider | None = None +_exporter: InMemorySpanExporter | None = None + + +def _setup_tracing() -> tuple[TracerProvider, InMemorySpanExporter]: + """Setup shared OTEL tracing. 
Only initializes once per process.""" + global _provider, _exporter + + if _provider is None: + _provider = TracerProvider() + _exporter = InMemorySpanExporter() + _provider.add_span_processor(SimpleSpanProcessor(_exporter)) + opentelemetry.trace.set_tracer_provider(_provider) + + return _provider, _exporter + + +@pytest.fixture +def tracing() -> InMemorySpanExporter: + """Provide an in-memory span exporter for OTEL tracing tests. + + Clears spans before and after each test to ensure isolation. + """ + _, exporter = _setup_tracing() + exporter.clear() + yield exporter + exporter.clear() + + +@pytest.fixture +def tracer_provider_and_exporter() -> tuple[TracerProvider, InMemorySpanExporter]: + """Provide both tracer provider and exporter. + + Clears spans before and after each test. + """ + provider, exporter = _setup_tracing() + exporter.clear() + yield provider, exporter + exporter.clear() diff --git a/tests/contrib/test_deterministic_span_ids.py b/tests/contrib/test_deterministic_span_ids.py new file mode 100644 index 000000000..9b788c337 --- /dev/null +++ b/tests/contrib/test_deterministic_span_ids.py @@ -0,0 +1,219 @@ +"""Tests for deterministic span ID generation in Temporal workflows.""" + +import uuid +from datetime import timedelta +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +import opentelemetry.trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from temporalio import activity, workflow +from temporalio.client import Client +from temporalio.contrib.opentelemetry import OtelTracingPlugin, TemporalIdGenerator +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + + +class TestTemporalIdGeneratorUnit: + """Unit tests for TemporalIdGenerator.""" + + def test_fallback_to_random_outside_workflow(self): + """Verify TemporalIdGenerator uses random IDs outside workflow context.""" + id_gen = TemporalIdGenerator() + + # Outside workflow context, should use fallback random generator + span_id_1 = id_gen.generate_span_id() + span_id_2 = id_gen.generate_span_id() + + # IDs should be valid (non-zero) + assert span_id_1 != 0 + assert span_id_2 != 0 + # IDs should be different (random) + assert span_id_1 != span_id_2 + + def test_fallback_trace_id_outside_workflow(self): + """Verify TemporalIdGenerator uses random trace IDs outside workflow context.""" + id_gen = TemporalIdGenerator() + + trace_id_1 = id_gen.generate_trace_id() + trace_id_2 = id_gen.generate_trace_id() + + # IDs should be valid (non-zero) + assert trace_id_1 != 0 + assert trace_id_2 != 0 + # IDs should be different (random) + assert trace_id_1 != trace_id_2 + + def test_deterministic_span_id_in_workflow_context(self): + """Verify TemporalIdGenerator uses workflow.random() in workflow context.""" + id_gen = TemporalIdGenerator() + + # Mock workflow.in_workflow() and workflow.random() + mock_random = MagicMock() + mock_random.getrandbits.side_effect = [12345, 67890] + + with patch("temporalio.workflow.in_workflow", return_value=True): + with patch("temporalio.workflow.random", return_value=mock_random): + span_id_1 = id_gen.generate_span_id() + span_id_2 = id_gen.generate_span_id() + + assert span_id_1 == 12345 + assert span_id_2 == 67890 + + def test_deterministic_trace_id_in_workflow_context(self): + """Verify 
TemporalIdGenerator uses workflow.random() for trace IDs in workflow.""" + id_gen = TemporalIdGenerator() + + mock_random = MagicMock() + mock_random.getrandbits.side_effect = [ + 123456789012345678901234567890, + 987654321098765432109876543210, + ] + + with patch("temporalio.workflow.in_workflow", return_value=True): + with patch("temporalio.workflow.random", return_value=mock_random): + trace_id_1 = id_gen.generate_trace_id() + trace_id_2 = id_gen.generate_trace_id() + + assert trace_id_1 == 123456789012345678901234567890 + assert trace_id_2 == 987654321098765432109876543210 + + +class TestOtelTracingPluginDeterministicIds: + """Tests for OtelTracingPlugin with deterministic_ids parameter.""" + + def test_deterministic_ids_false_by_default(self): + """Verify deterministic_ids is False by default (backward compat).""" + provider = TracerProvider() + original_id_gen = provider.id_generator + + # Create plugin without deterministic_ids + OtelTracingPlugin(tracer_provider=provider) + + # id_generator should be unchanged + assert provider.id_generator is original_id_gen + + def test_deterministic_ids_true_configures_generator(self): + """Verify deterministic_ids=True configures TemporalIdGenerator.""" + provider = TracerProvider() + + OtelTracingPlugin(tracer_provider=provider, deterministic_ids=True) + + assert isinstance(provider.id_generator, TemporalIdGenerator) + + def test_deterministic_ids_without_tracer_provider(self): + """Verify deterministic_ids has no effect without tracer_provider.""" + # Should not raise - just creates plugin without configuring anything + plugin = OtelTracingPlugin(deterministic_ids=True) + assert plugin is not None + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +@activity.defn +async def record_span_id_activity() -> dict[str, Any]: + """Activity that creates a span and records its ID.""" + tracer = opentelemetry.trace.get_tracer(__name__) + with tracer.start_as_current_span("activity_span") as span: + span_context = span.get_span_context() + return { + "trace_id": span_context.trace_id, + "span_id": span_context.span_id, + } + + +@workflow.defn +class SpanIdTestWorkflow: + """Workflow that creates spans and records their IDs.""" + + @workflow.run + async def run(self) -> dict[str, Any]: + tracer = opentelemetry.trace.get_tracer(__name__) + + # Create a workflow span + with tracer.start_as_current_span("workflow_span") as span: + span_context = span.get_span_context() + workflow_span_id = span_context.span_id + workflow_trace_id = span_context.trace_id + + # Call activity to get activity span ID + activity_result = await workflow.execute_activity( + record_span_id_activity, + start_to_close_timeout=timedelta(seconds=30), + ) + + return { + "workflow_span_id": workflow_span_id, + "workflow_trace_id": workflow_trace_id, + "activity_span_id": activity_result["span_id"], + "activity_trace_id": activity_result["trace_id"], + } + + +@pytest.fixture +def tracer_provider_with_exporter(): + """Create a TracerProvider with InMemorySpanExporter.""" + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + return provider, exporter + + +@pytest.mark.asyncio +async def test_deterministic_span_ids_in_workflow(tracer_provider_with_exporter): + """Integration test: verify span IDs are deterministic in workflow context.""" + provider, exporter = 
tracer_provider_with_exporter + + # Set as global tracer provider + opentelemetry.trace.set_tracer_provider(provider) + + plugin = OtelTracingPlugin( + tracer_provider=provider, + deterministic_ids=True, + filter_replay_spans=True, + ) + + async with await WorkflowEnvironment.start_local() as env: + task_queue = f"test-queue-{uuid.uuid4()}" + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + plugins=[plugin], + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[SpanIdTestWorkflow], + activities=[record_span_id_activity], + workflow_runner=SandboxedWorkflowRunner( + restrictions=plugin.sandbox_restrictions + ), + ): + result = await client.execute_workflow( + SpanIdTestWorkflow.run, + id=f"test-workflow-{uuid.uuid4()}", + task_queue=task_queue, + ) + + # Verify we got span IDs + assert result["workflow_span_id"] != 0, "Workflow span ID should be non-zero" + assert result["activity_span_id"] != 0, "Activity span ID should be non-zero" + + # Check exported spans + spans = exporter.get_finished_spans() + span_names = [s.name for s in spans] + assert "workflow_span" in span_names, "Workflow span should be exported" + assert "activity_span" in span_names, "Activity span should be exported" + + diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index 7e21c8935..5a206811a 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -989,3 +989,94 @@ def otel_context_error(record: logging.LogRecord) -> bool: assert ( capturer.find(otel_context_error) is None ), "Detach from context message should not be logged" + + +async def test_opentelemetry_create_spans_false_no_spans(client: Client): + """Test that create_spans=False creates no spans but propagates context.""" + # Create a tracer that has an in-memory exporter + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = get_tracer(__name__, tracer_provider=provider) + # Create new client with tracer interceptor with create_spans=False + client_config = client.config() + client_config["interceptors"] = [TracingInterceptor(tracer, create_spans=False)] + client = Client(**client_config) + + task_queue = f"task_queue_{uuid.uuid4()}" + async with Worker( + client, + task_queue=task_queue, + workflows=[ReadBaggageTestWorkflow], + activities=[read_baggage_activity], + ): + # Test with baggage values to verify context propagation + with baggage_values({"user.id": "test-user-no-spans", "tenant.id": "test-corp"}): + result = await client.execute_workflow( + ReadBaggageTestWorkflow.run, + id=f"workflow_{uuid.uuid4()}", + task_queue=task_queue, + ) + + # Verify no spans were created + spans = exporter.get_finished_spans() + assert len(spans) == 0, f"Expected no spans but got: {[s.name for s in spans]}" + + # Verify context (baggage) was propagated to activity + assert result["user_id"] == "test-user-no-spans" + assert result["tenant_id"] == "test-corp" + + +async def test_opentelemetry_create_spans_false_with_child_workflow( + client: Client, env: WorkflowEnvironment +): + """Test that create_spans=False propagates context through child workflows.""" + if env.supports_time_skipping: + pytest.skip( + "Java test server: https://github.com/temporalio/sdk-java/issues/1424" + ) + # Create a tracer that has an in-memory exporter + exporter = InMemorySpanExporter() + provider = TracerProvider() + 
provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = get_tracer(__name__, tracer_provider=provider) + # Create new client with tracer interceptor with create_spans=False + client_config = client.config() + client_config["interceptors"] = [TracingInterceptor(tracer, create_spans=False)] + client = Client(**client_config) + + task_queue = f"task_queue_{uuid.uuid4()}" + async with Worker( + client, + task_queue=task_queue, + workflows=[TracingWorkflow], + activities=[tracing_activity], + # Needed for reliable test execution + workflow_runner=UnsandboxedWorkflowRunner(), + ): + # Run workflow with child workflow and activity + workflow_id = f"workflow_{uuid.uuid4()}" + await client.execute_workflow( + TracingWorkflow.run, + TracingWorkflowParam( + actions=[ + TracingWorkflowAction( + activity=TracingWorkflowActionActivity( + param=TracingActivityParam(heartbeat=False), + ), + ), + TracingWorkflowAction( + child_workflow=TracingWorkflowActionChildWorkflow( + id=f"{workflow_id}_child", + param=TracingWorkflowParam(actions=[]), + ) + ), + ], + ), + id=workflow_id, + task_queue=task_queue, + ) + + # Verify no spans were created + spans = exporter.get_finished_spans() + assert len(spans) == 0, f"Expected no spans but got: {[s.name for s in spans]}" diff --git a/tests/contrib/test_sandbox_passthrough.py b/tests/contrib/test_sandbox_passthrough.py new file mode 100644 index 000000000..4de4a8d44 --- /dev/null +++ b/tests/contrib/test_sandbox_passthrough.py @@ -0,0 +1,665 @@ +"""Tests for OpenTelemetry sandbox passthrough behavior. + +These tests verify that OTEL context propagates correctly when opentelemetry +is configured as a passthrough module in the sandbox. +""" + +import asyncio +import concurrent.futures +import multiprocessing +import threading +import uuid +from datetime import timedelta +from multiprocessing import Queue +from typing import Any + +import pytest + +import opentelemetry.trace + +from temporalio import activity, workflow +from temporalio.client import Client +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner, SandboxRestrictions +from temporalio.contrib.opentelemetry import TracingInterceptor, OtelTracingPlugin + + +@workflow.defn +class SandboxSpanTestWorkflow: + """Workflow that tests get_current_span() inside the sandbox.""" + + @workflow.run + async def run(self) -> dict[str, Any]: + # This tests that get_current_span() returns the propagated span + # INSIDE the sandbox. Without passthrough, this returns INVALID_SPAN. + span = opentelemetry.trace.get_current_span() + span_context = span.get_span_context() + + # Also check if context was propagated via headers + headers = workflow.info().headers + has_tracer_header = "_tracer-data" in headers + + return { + "is_valid": span is not opentelemetry.trace.INVALID_SPAN, + "trace_id": span_context.trace_id if span_context else 0, + "span_id": span_context.span_id if span_context else 0, + "has_tracer_header": has_tracer_header, + "span_str": str(span), + } + + +@pytest.mark.asyncio +async def test_sandbox_context_propagation_without_passthrough( + tracer_provider_and_exporter, +): + """Test that context does NOT propagate without passthrough. + + This test verifies the problem: without opentelemetry passthrough, + get_current_span() returns INVALID_SPAN inside the sandbox. 
+ """ + provider, exporter = tracer_provider_and_exporter + tracer = opentelemetry.trace.get_tracer(__name__) + + interceptor = TracingInterceptor() + + async with await WorkflowEnvironment.start_local() as env: + task_queue = f"test-queue-{uuid.uuid4()}" + + # Create a client with the tracing interceptor + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + interceptors=[interceptor], + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[SandboxSpanTestWorkflow], + interceptors=[interceptor], + # Default sandboxed runner - NO passthrough + ): + with tracer.start_as_current_span("client_root") as root: + result = await client.execute_workflow( + SandboxSpanTestWorkflow.run, + id=f"test-workflow-{uuid.uuid4()}", + task_queue=task_queue, + ) + + # Without passthrough, context is propagated via headers but + # get_current_span() still returns INVALID_SPAN because the sandbox + # re-imports opentelemetry creating a separate ContextVar + assert result["has_tracer_header"], "Tracer header should be present" + assert not result["is_valid"], "Without passthrough, span should be INVALID_SPAN" + + +@pytest.mark.asyncio +async def test_sandbox_context_propagation_with_passthrough( + tracer_provider_and_exporter, +): + """Test that context DOES propagate with passthrough. + + This test verifies the fix: with opentelemetry passthrough, + get_current_span() returns the propagated span inside the sandbox. + """ + provider, exporter = tracer_provider_and_exporter + tracer = opentelemetry.trace.get_tracer(__name__) + + interceptor = TracingInterceptor() + + async with await WorkflowEnvironment.start_local() as env: + task_queue = f"test-queue-{uuid.uuid4()}" + + # Create a client with the tracing interceptor + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + interceptors=[interceptor], + ) + + # Use SandboxedWorkflowRunner with opentelemetry passthrough + async with Worker( + client, + task_queue=task_queue, + workflows=[SandboxSpanTestWorkflow], + interceptors=[interceptor], + workflow_runner=SandboxedWorkflowRunner( + restrictions=SandboxRestrictions.default.with_passthrough_modules( + "opentelemetry" + ) + ), + ): + with tracer.start_as_current_span("client_root") as root: + root_context = root.get_span_context() + + result = await client.execute_workflow( + SandboxSpanTestWorkflow.run, + id=f"test-workflow-{uuid.uuid4()}", + task_queue=task_queue, + ) + + # With passthrough, span should be valid and have the same trace_id + print(f"Result: {result}") + print(f"Root trace_id: {root_context.trace_id}") + assert result["has_tracer_header"], "Tracer header should be present" + assert result["is_valid"], f"With passthrough, span should be valid. 
Got: {result['span_str']}" + assert result["trace_id"] == root_context.trace_id, "Trace ID should match" + + +@pytest.mark.asyncio +async def test_otel_tracing_plugin_provides_sandbox_restrictions( + tracer_provider_and_exporter, +): + """Test that OtelTracingPlugin provides correct sandbox restrictions.""" + provider, _ = tracer_provider_and_exporter + tracer = opentelemetry.trace.get_tracer(__name__) + + plugin = OtelTracingPlugin(tracer_provider=provider) + + # Verify the plugin provides sandbox_restrictions property + restrictions = plugin.sandbox_restrictions + assert "opentelemetry" in restrictions.passthrough_modules + + async with await WorkflowEnvironment.start_local() as env: + task_queue = f"test-queue-{uuid.uuid4()}" + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + plugins=[plugin], + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[SandboxSpanTestWorkflow], + workflow_runner=SandboxedWorkflowRunner( + restrictions=plugin.sandbox_restrictions + ), + ): + with tracer.start_as_current_span("client_root") as root: + root_context = root.get_span_context() + + result = await client.execute_workflow( + SandboxSpanTestWorkflow.run, + id=f"test-workflow-{uuid.uuid4()}", + task_queue=task_queue, + ) + + # With plugin's sandbox_restrictions, context should propagate + assert result["is_valid"], "With plugin restrictions, span should be valid" + assert result["trace_id"] == root_context.trace_id, "Trace ID should match" + + +@pytest.mark.asyncio +async def test_no_state_leakage_between_workflows( + tracer_provider_and_exporter, +): + """Test that context doesn't leak between sequential workflow runs.""" + provider, exporter = tracer_provider_and_exporter + tracer = opentelemetry.trace.get_tracer(__name__) + + interceptor = TracingInterceptor() + + async with await WorkflowEnvironment.start_local() as env: + task_queue = f"test-queue-{uuid.uuid4()}" + + # Create a client with the tracing interceptor + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + interceptors=[interceptor], + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[SandboxSpanTestWorkflow], + interceptors=[interceptor], + workflow_runner=SandboxedWorkflowRunner( + restrictions=SandboxRestrictions.default.with_passthrough_modules( + "opentelemetry" + ) + ), + ): + # Run first workflow with one trace + with tracer.start_as_current_span("trace_1") as span_1: + trace_1_id = span_1.get_span_context().trace_id + result_1 = await client.execute_workflow( + SandboxSpanTestWorkflow.run, + id=f"test-workflow-1-{uuid.uuid4()}", + task_queue=task_queue, + ) + + # Run second workflow with different trace + with tracer.start_as_current_span("trace_2") as span_2: + trace_2_id = span_2.get_span_context().trace_id + result_2 = await client.execute_workflow( + SandboxSpanTestWorkflow.run, + id=f"test-workflow-2-{uuid.uuid4()}", + task_queue=task_queue, + ) + + # Each workflow should see its own trace context + assert result_1["is_valid"] + assert result_2["is_valid"] + assert result_1["trace_id"] == trace_1_id + assert result_2["trace_id"] == trace_2_id + assert result_1["trace_id"] != result_2["trace_id"], "Traces should be independent" + + +# Global barrier for concurrent workflow synchronization +_concurrent_barrier: threading.Barrier | None = None + + +@activity.defn +def barrier_sync_activity(workflow_index: int) -> None: + """Activity that waits at barrier to 
ensure concurrent execution. + + All concurrent workflows will block here until all have reached this point, + guaranteeing true concurrent execution for the test. + """ + assert _concurrent_barrier is not None, "Barrier must be set before test" + _concurrent_barrier.wait() + + +@workflow.defn +class ConcurrentTraceContextWorkflow: + """Workflow that captures trace context before and after a barrier activity.""" + + @workflow.run + async def run(self, workflow_index: int) -> dict[str, Any]: + # Capture trace context BEFORE activity (set by interceptor at workflow start) + span_before = opentelemetry.trace.get_current_span() + ctx_before = span_before.get_span_context() + trace_id_before = ctx_before.trace_id if ctx_before else 0 + + # Call activity - all workflows block at barrier until all arrive + await workflow.execute_activity( + barrier_sync_activity, + workflow_index, + start_to_close_timeout=timedelta(seconds=30), + ) + + # Capture trace context AFTER activity (should still be the same) + span_after = opentelemetry.trace.get_current_span() + ctx_after = span_after.get_span_context() + trace_id_after = ctx_after.trace_id if ctx_after else 0 + + return { + "workflow_index": workflow_index, + "trace_id_before": trace_id_before, + "trace_id_after": trace_id_after, + "is_valid": span_before is not opentelemetry.trace.INVALID_SPAN, + } + + +@pytest.mark.asyncio +async def test_concurrent_workflows_isolated_trace_context( + tracer_provider_and_exporter, +): + """Test that concurrent workflows each see their own trace context. + + This test verifies that when multiple workflows run concurrently with + opentelemetry passthrough enabled, each workflow sees only its own trace + context and not another workflow's. Uses a threading.Barrier to ensure + all workflows are truly executing concurrently (not timing-based). 
+ """ + global _concurrent_barrier + num_workflows = 5 + _concurrent_barrier = threading.Barrier(num_workflows) + + provider, exporter = tracer_provider_and_exporter + tracer = opentelemetry.trace.get_tracer(__name__) + + interceptor = TracingInterceptor() + + async with await WorkflowEnvironment.start_local() as env: + task_queue = f"test-queue-{uuid.uuid4()}" + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + interceptors=[interceptor], + ) + + # Use ThreadPoolExecutor for sync activity + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_workflows + ) as executor: + async with Worker( + client, + task_queue=task_queue, + workflows=[ConcurrentTraceContextWorkflow], + activities=[barrier_sync_activity], + activity_executor=executor, + interceptors=[interceptor], + workflow_runner=SandboxedWorkflowRunner( + restrictions=SandboxRestrictions.default.with_passthrough_modules( + "opentelemetry" + ) + ), + ): + + async def run_workflow_with_trace(index: int) -> tuple[int, dict, int]: + """Start a workflow under its own trace context.""" + with tracer.start_as_current_span(f"trace_{index}") as span: + expected_trace_id = span.get_span_context().trace_id + result = await client.execute_workflow( + ConcurrentTraceContextWorkflow.run, + index, + id=f"concurrent-test-{index}-{uuid.uuid4()}", + task_queue=task_queue, + ) + return (index, result, expected_trace_id) + + # Launch all workflows concurrently + results = await asyncio.gather( + *[run_workflow_with_trace(i) for i in range(num_workflows)] + ) + + # Verify each workflow saw its own trace_id (not another workflow's) + for index, result, expected_trace_id in results: + assert result["is_valid"], f"Workflow {index} should see valid span" + assert result["trace_id_before"] == expected_trace_id, ( + f"Workflow {index} saw wrong trace_id before activity: " + f"expected {expected_trace_id}, got {result['trace_id_before']}" + ) + assert result["trace_id_after"] == expected_trace_id, ( + f"Workflow {index} saw wrong trace_id after activity: " + f"expected {expected_trace_id}, got {result['trace_id_after']}" + ) + + # Verify all trace_ids were different (each workflow had unique trace) + trace_ids = [r[2] for r in results] + assert len(set(trace_ids)) == num_workflows, ( + f"All {num_workflows} workflows should have unique traces, " + f"but got {len(set(trace_ids))} unique trace_ids" + ) + + +# ============================================================================= +# Cross-Process Trace Continuity Test +# ============================================================================= + + +@activity.defn +async def record_trace_with_child_span(label: str) -> dict[str, Any]: + """Activity that creates a child span and records parent span info. + + This activity captures the current span context (propagated via headers) + and creates a child span to verify the parent span_id is correct. 
+    """
+    from opentelemetry.sdk.trace import ReadableSpan
+
+    tracer = opentelemetry.trace.get_tracer(__name__)
+
+    # Get the current span context (propagated via Temporal headers)
+    current_span = opentelemetry.trace.get_current_span()
+    current_ctx = current_span.get_span_context()
+
+    # Create a child span to verify parent relationship
+    with tracer.start_as_current_span(f"child_{label}") as child:
+        child_ctx = child.get_span_context()
+        # Get the parent span context from the child
+        # ReadableSpan (from the SDK) has a parent attribute; NonRecordingSpan doesn't
+        if isinstance(child, ReadableSpan):
+            parent_ctx = child.parent
+        else:
+            # NonRecordingSpan - fall back to the current span context
+            parent_ctx = current_ctx
+
+    return {
+        "label": label,
+        "trace_id": current_ctx.trace_id if current_ctx else 0,
+        "current_span_id": current_ctx.span_id if current_ctx else 0,
+        "child_span_id": child_ctx.span_id if child_ctx else 0,
+        "child_parent_span_id": parent_ctx.span_id if parent_ctx else 0,
+        "is_valid": current_span is not opentelemetry.trace.INVALID_SPAN,
+    }
+
+
+@workflow.defn
+class CrossProcessTraceWorkflow:
+    """Workflow that records trace info in two parts, separated by a signal."""
+
+    def __init__(self) -> None:
+        self._continue = False
+
+    @workflow.signal
+    def continue_workflow(self) -> None:
+        """Signal to continue to part 2."""
+        self._continue = True
+
+    @workflow.run
+    async def run(self) -> dict[str, Any]:
+        # Part 1: Record trace on first worker
+        trace_part1 = await workflow.execute_activity(
+            record_trace_with_child_span,
+            "part1",
+            start_to_close_timeout=timedelta(seconds=30),
+        )
+
+        # Wait for signal (worker swap happens here)
+        await workflow.wait_condition(lambda: self._continue)
+
+        # Part 2: Record trace on second worker (different process)
+        trace_part2 = await workflow.execute_activity(
+            record_trace_with_child_span,
+            "part2",
+            start_to_close_timeout=timedelta(seconds=30),
+        )
+
+        return {"part1": trace_part1, "part2": trace_part2}
+
+
+def _run_worker_in_subprocess(
+    target_host: str,
+    namespace: str,
+    task_queue: str,
+    ready_queue: Queue,
+    stop_queue: Queue,
+) -> None:
+    """Entry point for the worker subprocess.
+
+    Runs in a completely separate process with a fresh Python interpreter.
+ """ + asyncio.run( + _run_worker_async(target_host, namespace, task_queue, ready_queue, stop_queue) + ) + + +async def _run_worker_async( + target_host: str, + namespace: str, + task_queue: str, + ready_queue: Queue, + stop_queue: Queue, +) -> None: + """Async worker runner for subprocess.""" + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + + # Set up TracerProvider in subprocess so we get real spans + provider = TracerProvider() + exporter = InMemorySpanExporter() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + opentelemetry.trace.set_tracer_provider(provider) + + # Fresh interceptor in this process - no shared state with parent + # Use create_spans=False to just propagate context (like OtelTracingPlugin does) + interceptor = TracingInterceptor(create_spans=False) + + client = await Client.connect( + target_host, + namespace=namespace, + interceptors=[interceptor], + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[CrossProcessTraceWorkflow], + activities=[record_trace_with_child_span], + interceptors=[interceptor], + workflow_runner=SandboxedWorkflowRunner( + restrictions=SandboxRestrictions.default.with_passthrough_modules( + "opentelemetry" + ) + ), + ): + # Signal that worker is ready + ready_queue.put("ready") + + # Run until stop signal + while stop_queue.empty(): + await asyncio.sleep(0.1) + + +@pytest.mark.asyncio +async def test_trace_continuity_across_worker_processes( + tracer_provider_and_exporter, +): + """Test trace context survives workflow moving between separate processes. + + This test verifies that trace context is properly serialized into Temporal + workflow headers and correctly deserialized when a completely different + worker process picks up the workflow. This is critical because: + - Workflows can be replayed on different workers + - Workers can be on different machines + - In-process ContextVars cannot be relied upon + + The test: + 1. Starts workflow from main process with trace_id X + 2. Worker 1 (subprocess) handles part 1, records trace info + 3. Worker 1 terminates + 4. Worker 2 (new subprocess) starts, handles part 2 + 5. 
Verifies both parts saw the same trace_id AND parent span_id
+    """
+    provider, exporter = tracer_provider_and_exporter
+    tracer = opentelemetry.trace.get_tracer(__name__)
+
+    # Use spawn context for true process isolation (not fork)
+    mp_context = multiprocessing.get_context("spawn")
+
+    async with await WorkflowEnvironment.start_local() as env:
+        task_queue = f"cross-process-{uuid.uuid4()}"
+        target_host = env.client.service_client.config.target_host
+        namespace = env.client.namespace
+
+        # Create queues for process coordination
+        ready_queue1: Queue = mp_context.Queue()
+        stop_queue1: Queue = mp_context.Queue()
+        ready_queue2: Queue = mp_context.Queue()
+        stop_queue2: Queue = mp_context.Queue()
+
+        # Start Worker 1 in subprocess
+        worker1_process = mp_context.Process(
+            target=_run_worker_in_subprocess,
+            args=(target_host, namespace, task_queue, ready_queue1, stop_queue1),
+        )
+        worker1_process.start()
+
+        try:
+            # Wait for worker 1 to be ready
+            ready_queue1.get(timeout=30)
+
+            # Create a client with the tracing interceptor. Use create_spans=False
+            # (as OtelTracingPlugin does) so only trace context is propagated.
+            interceptor = TracingInterceptor(create_spans=False)
+            client = await Client.connect(
+                target_host,
+                namespace=namespace,
+                interceptors=[interceptor],
+            )
+
+            # Start workflow under a trace context
+            with tracer.start_as_current_span("root_trace") as root_span:
+                root_ctx = root_span.get_span_context()
+                expected_trace_id = root_ctx.trace_id
+
+                handle = await client.start_workflow(
+                    CrossProcessTraceWorkflow.run,
+                    id=f"cross-process-{uuid.uuid4()}",
+                    task_queue=task_queue,
+                )
+
+            # Give part 1 time to complete; the workflow then blocks waiting for
+            # the continue signal, so the worker swap below is safe.
+            await asyncio.sleep(2)
+
+            # Stop Worker 1
+            stop_queue1.put("stop")
+            worker1_process.join(timeout=10)
+
+            # Start Worker 2 in NEW subprocess (completely fresh process)
+            worker2_process = mp_context.Process(
+                target=_run_worker_in_subprocess,
+                args=(target_host, namespace, task_queue, ready_queue2, stop_queue2),
+            )
+            worker2_process.start()
+
+            # Wait for worker 2 to be ready
+            ready_queue2.get(timeout=30)
+
+            # Signal workflow to continue
+            await handle.signal(CrossProcessTraceWorkflow.continue_workflow)
+
+            # Get result
+            result = await handle.result()
+
+            # Stop Worker 2
+            stop_queue2.put("stop")
+            worker2_process.join(timeout=10)
+
+        finally:
+            # Cleanup: ensure processes are terminated
+            if worker1_process.is_alive():
+                worker1_process.terminate()
+                worker1_process.join(timeout=5)
+            if "worker2_process" in locals() and worker2_process.is_alive():
+                worker2_process.terminate()
+                worker2_process.join(timeout=5)
+
+        # Verify trace continuity
+        part1 = result["part1"]
+        part2 = result["part2"]
+
+        # Both parts should see valid spans
+        assert part1["is_valid"], "Part 1 should see a valid span"
+        assert part2["is_valid"], "Part 2 should see a valid span"
+
+        # Both parts should see the same trace_id
+        assert part1["trace_id"] == expected_trace_id, (
+            f"Part 1 trace_id mismatch: expected {expected_trace_id}, "
+            f"got {part1['trace_id']}"
+        )
+        assert part2["trace_id"] == expected_trace_id, (
+            f"Part 2 trace_id mismatch: expected {expected_trace_id}, "
+            f"got {part2['trace_id']}"
+        )
+
+        # Both parts should see the same propagated span_id
+        assert part1["current_span_id"] == part2["current_span_id"], (
+            f"Current span_id should be same across processes: "
+            f"part1={part1['current_span_id']}, part2={part2['current_span_id']}"
+        )
+
+        # Child spans
created in both workers should have the same parent + assert part1["child_parent_span_id"] == part2["child_parent_span_id"], ( + f"Child spans should have same parent span_id: " + f"part1={part1['child_parent_span_id']}, " + f"part2={part2['child_parent_span_id']}" + ) + + # The child's parent should be the current span + assert part1["child_parent_span_id"] == part1["current_span_id"], ( + f"Child's parent should be current span: " + f"child_parent={part1['child_parent_span_id']}, " + f"current={part1['current_span_id']}" + ) diff --git a/tests/contrib/test_workflow_span_replay.py b/tests/contrib/test_workflow_span_replay.py new file mode 100644 index 000000000..1629e43b8 --- /dev/null +++ b/tests/contrib/test_workflow_span_replay.py @@ -0,0 +1,151 @@ +"""Tests for workflow_span helper during replay. + +This test verifies that workflow_span correctly prevents duplicate spans +when workflows replay (with max_cached_workflows=0). +""" + +from __future__ import annotations + +import asyncio +import uuid +from datetime import timedelta + +import pytest +from opentelemetry import trace +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + +from temporalio import activity, workflow +from temporalio.client import Client +from temporalio.contrib.openai_agents import workflow_span +from temporalio.contrib.opentelemetry import TracingInterceptor +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import UnsandboxedWorkflowRunner, Worker + +# OTEL tracer for activities +tracer = trace.get_tracer(__name__) + + +@activity.defn +async def research_activity(query: str) -> str: + """Simulates a research activity that creates OTEL spans.""" + with tracer.start_as_current_span("research_work") as span: + span.set_attribute("query", query) + await asyncio.sleep(0.01) + return f"Research results for: {query}" + + +@activity.defn +async def analyze_activity(data: str) -> str: + """Simulates an analysis activity.""" + with tracer.start_as_current_span("analyze_work") as span: + span.set_attribute("data_length", len(data)) + await asyncio.sleep(0.01) + return f"Analysis of: {data[:50]}..." + + +@activity.defn +async def write_report_activity(analysis: str) -> str: + """Simulates writing a report.""" + with tracer.start_as_current_span("write_report_work") as span: + span.set_attribute("input_length", len(analysis)) + await asyncio.sleep(0.01) + return f"Report based on: {analysis[:30]}..." + + +@workflow.defn +class WorkflowWithSpans: + """A workflow that creates a span and calls multiple activities.""" + + @workflow.run + async def run(self, query: str) -> str: + # Use workflow_span for replay-safe span creation + with workflow_span("research_workflow", query=query): + # Step 1: Research + research = await workflow.execute_activity( + research_activity, + query, + start_to_close_timeout=timedelta(seconds=30), + ) + + # Step 2: Analyze + analysis = await workflow.execute_activity( + analyze_activity, + research, + start_to_close_timeout=timedelta(seconds=30), + ) + + # Step 3: Write report + report = await workflow.execute_activity( + write_report_activity, + analysis, + start_to_close_timeout=timedelta(seconds=30), + ) + + return report + + +@pytest.mark.asyncio +async def test_workflow_span_no_duplication_during_replay(tracing: InMemorySpanExporter): + """Test that workflow_span prevents duplicate spans during replay. 
+
+    With max_cached_workflows=0:
+    - Each activity completion triggers a new workflow task
+    - Each workflow task replays from the beginning
+    - workflow_span should prevent duplicate workflow spans
+    - Activity spans should appear exactly once (no replay)
+    """
+    async with await WorkflowEnvironment.start_local() as env:
+        interceptor = TracingInterceptor(create_spans=False)
+        client_config = env.client.config()
+        client_config["interceptors"] = [interceptor]
+        client = Client(**client_config)
+
+        task_queue = f"test-e2e-replay-{uuid.uuid4()}"
+
+        async with Worker(
+            client,
+            task_queue=task_queue,
+            workflows=[WorkflowWithSpans],
+            activities=[research_activity, analyze_activity, write_report_activity],
+            interceptors=[interceptor],
+            workflow_runner=UnsandboxedWorkflowRunner(),
+            max_cached_workflows=0,  # Force replay on every task
+        ):
+            # Create a client-side root span to establish trace context
+            with tracer.start_as_current_span("client_request") as root:
+                root.set_attribute("query", "Test query")
+                result = await client.execute_workflow(
+                    WorkflowWithSpans.run,
+                    "Test query",
+                    id=f"wf-e2e-{uuid.uuid4()}",
+                    task_queue=task_queue,
+                )
+
+            assert "Report based on" in result
+
+            # Brief pause so spans ended around workflow completion have time to
+            # reach the in-memory exporter before we inspect them.
+            await asyncio.sleep(0.3)
+            spans = tracing.get_finished_spans()
+
+            # Count spans by name
+            span_counts: dict[str, int] = {}
+            for s in spans:
+                span_counts[s.name] = span_counts.get(s.name, 0) + 1
+
+            # Verify the workflow span appears exactly once (replay-safe)
+            assert span_counts.get("research_workflow", 0) == 1, (
+                f"Expected 1 research_workflow span, got {span_counts.get('research_workflow', 0)}. "
+                f"workflow_span may not be working correctly."
+            )
+
+            # Verify each activity span appears exactly once
+            assert span_counts.get("research_work", 0) == 1, "research_work should appear once"
+            assert span_counts.get("analyze_work", 0) == 1, "analyze_work should appear once"
+            assert span_counts.get("write_report_work", 0) == 1, "write_report_work should appear once"
+
+            # Verify no unexpected duplication
+            duplicated = {name: count for name, count in span_counts.items() if count > 1}
+            assert not duplicated, f"Unexpected span duplication: {duplicated}"
+
+            # Verify all spans share the same trace ID
+            trace_ids = set(s.context.trace_id for s in spans)
+            assert len(trace_ids) == 1, f"Expected 1 trace_id, got {len(trace_ids)}"