diff --git a/cmd/protoc-gen-connect-python/generator/generator.go b/cmd/protoc-gen-connect-python/generator/generator.go index ed968b7..f919064 100644 --- a/cmd/protoc-gen-connect-python/generator/generator.go +++ b/cmd/protoc-gen-connect-python/generator/generator.go @@ -180,7 +180,7 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { p.P(`from connect import (`) p.P(` Client,`) p.P(` ClientOptions,`) - p.P(` ConnectOptions,`) + p.P(` HandlerOptions,`) p.P(` Handler,`) p.P(` HandlerContext,`) p.P(` IdempotencyLevel,`) @@ -307,7 +307,7 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { } p.P() p.P() - p.P(`def create_`, upperSvcName, `_handlers`, `(`, `service: `, handler, `, options: ConnectOptions | None = None`, `) -> list[Handler]:`) + p.P(`def create_`, upperSvcName, `_handlers`, `(`, `service: `, handler, `, options: HandlerOptions | None = None`, `) -> list[Handler]:`) p.P(` handlers = [`) for _, meth := range sortedMap(p.services) { svc := p.services[meth] @@ -339,7 +339,7 @@ func (g *Generator) generate(gen *protogen.GeneratedFile, f *protogen.File) { p.P(` output=`, svc.output.method, `,`) if options := meth.Options; options != nil { if desc, ok := options.(*descriptorpb.MethodOptions); ok && desc.GetIdempotencyLevel() != descriptorpb.MethodOptions_IDEMPOTENCY_UNKNOWN { - p.P(` options=ConnectOptions(idempotency_level=IdempotencyLevel.`, desc.GetIdempotencyLevel().String(), `).merge(options),`) + p.P(` options=HandlerOptions(idempotency_level=IdempotencyLevel.`, desc.GetIdempotencyLevel().String(), `).merge(options),`) } else { p.P(` options=options,`) } diff --git a/cmd/protoc-gen-connect-python/generator/package.go b/cmd/protoc-gen-connect-python/generator/package.go index 077d6ad..1fe79fa 100644 --- a/cmd/protoc-gen-connect-python/generator/package.go +++ b/cmd/protoc-gen-connect-python/generator/package.go @@ -9,7 +9,7 @@ import ( // p.P() // p.P(`from connect.connect import UnaryRequest, 
UnaryResponse`) // p.P(`from connect.handler import UnaryHandler`) -// p.P(`from connect.options import ConnectOptions`) +// p.P(`from connect.options import HandlerOptions`) // p.P(`from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor`) // PythonIdent is a Python identifier, consisting of a name and import path. diff --git a/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py b/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py index c54e429..9f5f5e7 100644 --- a/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py +++ b/conformance/gen/connectrpc/conformance/v1/conformancev1connect/service_connect.py @@ -1,22 +1,43 @@ # Generated by the protoc-gen-connect-python. DO NOT EDIT! # source: connectrpc/conformance/v1/conformancev1connect/service.proto # Protobuf Python Version: (unknown) -# protoc-gen-connect-python version: v0.0.0-20250225131640-797060f503da+dirty +# protoc-gen-connect-python version: v0.0.0-20250710124620-c0c8871def6b+dirty """Generated connect code.""" +import abc from enum import Enum -from connect.client import Client -import connect.connect -from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler, BidiStreamHandler -from connect.handler_context import HandlerContext -from connect.options import ClientOptions, ConnectOptions +from connect import ( + Client, + ClientOptions, + HandlerOptions, + Handler, + HandlerContext, + IdempotencyLevel, + StreamRequest, + StreamResponse, +) +from connect import UnaryRequest as ConnectUnaryRequest +from connect import UnaryResponse as ConnectUnaryResponse from connect.connection_pool import AsyncConnectionPool +from connect.handler import BidiStreamHandler, ClientStreamHandler, ServerStreamHandler, UnaryHandler from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor -from connect.idempotency_level import IdempotencyLevel from .. 
import service_pb2 -from ..service_pb2 import UnaryRequest, UnaryResponse, ServerStreamRequest, ServerStreamResponse, ClientStreamRequest, ClientStreamResponse, BidiStreamRequest, BidiStreamResponse, UnimplementedRequest, UnimplementedResponse, IdempotentUnaryRequest, IdempotentUnaryResponse +from ..service_pb2 import ( + UnaryRequest, + UnaryResponse, + ServerStreamRequest, + ServerStreamResponse, + ClientStreamRequest, + ClientStreamResponse, + BidiStreamRequest, + BidiStreamResponse, + UnimplementedRequest, + UnimplementedResponse, + IdempotentUnaryRequest, + IdempotentUnaryResponse, +) class ConformanceServiceProcedures(Enum): @@ -60,33 +81,33 @@ def __init__(self, base_url: str, pool: AsyncConnectionPool, options: ClientOpti pool, base_url + ConformanceServiceProcedures.Unimplemented.value, UnimplementedRequest, UnimplementedResponse, options ).call_unary self.IdempotentUnary = Client[IdempotentUnaryRequest, IdempotentUnaryResponse]( - pool, base_url + ConformanceServiceProcedures.IdempotentUnary.value, IdempotentUnaryRequest, IdempotentUnaryResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options), + pool, base_url + ConformanceServiceProcedures.IdempotentUnary.value, IdempotentUnaryRequest, IdempotentUnaryResponse, ClientOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS, enable_get=True).merge(options) ).call_unary -class ConformanceServiceHandler: +class ConformanceServiceHandler(metaclass=abc.ABCMeta): """Handler for the conformanceService service.""" - async def Unary(self, request: connect.connect.UnaryRequest[UnaryRequest], context: HandlerContext) -> connect.connect.UnaryResponse[UnaryResponse]: + async def Unary(self, request: ConnectUnaryRequest[UnaryRequest], context: HandlerContext) -> ConnectUnaryResponse[UnaryResponse]: raise NotImplementedError() - async def ServerStream(self, request: connect.connect.StreamRequest[ServerStreamRequest], context: HandlerContext) -> 
connect.connect.StreamResponse[ServerStreamResponse]: + async def ServerStream(self, request: StreamRequest[ServerStreamRequest], context: HandlerContext) -> StreamResponse[ServerStreamResponse]: raise NotImplementedError() - async def ClientStream(self, request: connect.connect.StreamRequest[ClientStreamRequest], context: HandlerContext) -> connect.connect.StreamResponse[ClientStreamResponse]: + async def ClientStream(self, request: StreamRequest[ClientStreamRequest], context: HandlerContext) -> StreamResponse[ClientStreamResponse]: raise NotImplementedError() - async def BidiStream(self, request: connect.connect.StreamRequest[BidiStreamRequest], context: HandlerContext) -> connect.connect.StreamResponse[BidiStreamResponse]: + async def BidiStream(self, request: StreamRequest[BidiStreamRequest], context: HandlerContext) -> StreamResponse[BidiStreamResponse]: raise NotImplementedError() - async def Unimplemented(self, request: connect.connect.UnaryRequest[UnimplementedRequest], context: HandlerContext) -> connect.connect.UnaryResponse[UnimplementedResponse]: + async def Unimplemented(self, request: ConnectUnaryRequest[UnimplementedRequest], context: HandlerContext) -> ConnectUnaryResponse[UnimplementedResponse]: raise NotImplementedError() - async def IdempotentUnary(self, request: connect.connect.UnaryRequest[IdempotentUnaryRequest], context: HandlerContext) -> connect.connect.UnaryResponse[IdempotentUnaryResponse]: + async def IdempotentUnary(self, request: ConnectUnaryRequest[IdempotentUnaryRequest], context: HandlerContext) -> ConnectUnaryResponse[IdempotentUnaryResponse]: raise NotImplementedError() -def create_ConformanceService_handlers(service: ConformanceServiceHandler, options: ConnectOptions | None = None) -> list[Handler]: +def create_ConformanceService_handlers(service: ConformanceServiceHandler, options: HandlerOptions | None = None) -> list[Handler]: handlers = [ UnaryHandler( procedure=ConformanceServiceProcedures.Unary.value, @@ -128,7 +149,7 @@ 
def create_ConformanceService_handlers(service: ConformanceServiceHandler, optio unary=service.IdempotentUnary, input=IdempotentUnaryRequest, output=IdempotentUnaryResponse, - options=ConnectOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS).merge(options), + options=HandlerOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS).merge(options), ), ] return handlers diff --git a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py index cb31310..9227697 100644 --- a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py +++ b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect.py @@ -1,7 +1,7 @@ # Generated by the protoc-gen-connect-python. DO NOT EDIT! # source: examples/proto/connectrpc/eliza/v1/v1connect/eliza.proto # Protobuf Python Version: v5.29.3 -# protoc-gen-connect-python version: v0.0.0-20250708090951-d93686e5039f +# protoc-gen-connect-python version: v0.0.0-20250710124620-c0c8871def6b+dirty """Generated connect code.""" import abc @@ -10,7 +10,7 @@ from connect import ( Client, ClientOptions, - ConnectOptions, + HandlerOptions, Handler, HandlerContext, IdempotencyLevel, @@ -87,14 +87,14 @@ async def Reflect(self, request: StreamRequest[ReflectRequest], context: Handler raise NotImplementedError() -def create_ElizaService_handlers(service: ElizaServiceHandler, options: ConnectOptions | None = None) -> list[Handler]: +def create_ElizaService_handlers(service: ElizaServiceHandler, options: HandlerOptions | None = None) -> list[Handler]: handlers = [ UnaryHandler( procedure=ElizaServiceProcedures.Say.value, unary=service.Say, input=SayRequest, output=SayResponse, - options=ConnectOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS).merge(options), + options=HandlerOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS).merge(options), ), BidiStreamHandler( procedure=ElizaServiceProcedures.Converse.value, diff --git 
a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py index 0c888b8..6e3295e 100644 --- a/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py +++ b/examples/proto/connectrpc/eliza/v1/v1connect/eliza_connect_pb2.py @@ -10,9 +10,9 @@ from connect import ( Client, ClientOptions, - ConnectOptions, Handler, HandlerContext, + HandlerOptions, IdempotencyLevel, StreamRequest, StreamResponse, @@ -93,14 +93,14 @@ async def Reflect( raise NotImplementedError() -def create_ElizaService_handlers(service: ElizaServiceHandler, options: ConnectOptions | None = None) -> list[Handler]: +def create_ElizaService_handlers(service: ElizaServiceHandler, options: HandlerOptions | None = None) -> list[Handler]: handlers = [ UnaryHandler( procedure=ElizaServiceProcedures.Say.value, unary=service.Say, input=SayRequest, output=SayResponse, - options=ConnectOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS).merge(options), + options=HandlerOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS).merge(options), ), BidiStreamHandler( procedure=ElizaServiceProcedures.Converse.value, diff --git a/src/connect/__init__.py b/src/connect/__init__.py index 0e69959..260b428 100644 --- a/src/connect/__init__.py +++ b/src/connect/__init__.py @@ -23,7 +23,7 @@ from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel from connect.middleware import ConnectMiddleware -from connect.options import ClientOptions, ConnectOptions +from connect.options import ClientOptions, HandlerOptions from connect.protocol import Protocol from connect.request import Request from connect.response import Response as HTTPResponse @@ -43,7 +43,7 @@ "Compression", "ConnectError", "ConnectMiddleware", - "ConnectOptions", + "HandlerOptions", "GZipCompression", "Handler", "HandlerContext", diff --git a/src/connect/call_options.py b/src/connect/call_options.py index adf3f51..88983fa 100644 --- 
a/src/connect/call_options.py +++ b/src/connect/call_options.py @@ -1,4 +1,4 @@ -"""Options and configuration for making calls, including timeout and abort event support.""" +"""Call options configuration models.""" import asyncio @@ -6,7 +6,12 @@ class CallOptions(BaseModel): - """Options for configuring a call, such as timeout and abort event.""" + """Options for configuring a call. + + Attributes: + timeout (float | None): Timeout for the call in seconds. If None, no timeout is applied. + abort_event (asyncio.Event | None): Event to abort the call. If set, the call can be cancelled by setting this event. + """ model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/connect/client.py b/src/connect/client.py index b11c2b2..11b53df 100644 --- a/src/connect/client.py +++ b/src/connect/client.py @@ -1,7 +1,4 @@ -"""Provide the Client and ClientConfig classes for making unary calls. - -These classes allow making unary calls to a specified URL with given request and response types. -""" +"""Provides the main client implementation for making Connect protocol RPCs.""" import contextlib from collections.abc import AsyncGenerator, Awaitable, Callable @@ -36,17 +33,16 @@ def parse_request_url(raw_url: str) -> URL: - """Parse the given raw URL string and returns a URL object. + """Parses and validates a request URL. Args: - raw_url (str): The raw URL string to be parsed. + raw_url: The URL string to parse. Returns: - URL: The parsed URL object. + A validated URL object. Raises: - ConnectError: If the URL does not have a valid scheme (http or https). - + ConnectError: If the URL is missing a valid scheme (http or https). """ url = URL(raw_url) @@ -60,21 +56,23 @@ def parse_request_url(raw_url: str) -> URL: class ClientConfig: - """Configuration class for a client. - - Attributes: - url (URL): The URL of the client. - protocol (ProtocolConnect): The protocol used for connection. - procedure (str): The procedure path derived from the URL. 
- codec (Codec): The codec used for encoding/decoding. - request_compression_name (str | None): The name of the request compression method. - compressions (list[Compression]): List of compression methods. - descriptor (Any): The descriptor for the client. - idempotency_level (IdempotencyLevel): The idempotency level of the client. - compress_min_bytes (int): Minimum bytes for compression. - read_max_bytes (int): Maximum bytes to read. - send_max_bytes (int): Maximum bytes to send. - + """Configuration for a Connect client. + + This class holds all the configuration required to make a request with the client, + parsing the raw URL and client options into a structured format. + + url (URL): The parsed request URL. + protocol (Protocol): The protocol implementation to use (e.g., Connect, gRPC, gRPC-Web). + procedure (str): The full procedure string, including the service and method name (e.g., /acme.user.v1.UserService/GetUser). + codec (Codec): The codec for marshaling and unmarshaling request/response messages. + request_compression_name (str | None): The name of the compression algorithm to use for requests. + compressions (list[Compression]): A list of supported compression implementations. + descriptor (Any): The protobuf descriptor for the service. + idempotency_level (IdempotencyLevel): The idempotency level for the procedure. + compress_min_bytes (int): The minimum message size in bytes to be eligible for compression. + read_max_bytes (int): The maximum number of bytes to read from a response. + send_max_bytes (int): The maximum number of bytes to send in a request. + enable_get (bool): Whether to enable GET requests for idempotent procedures. """ url: URL @@ -91,25 +89,19 @@ class ClientConfig: enable_get: bool def __init__(self, raw_url: str, options: ClientOptions): - """Initialize the client with the given URL and options. + """Initializes a new client instance. + + This method configures the client based on the provided URL and options. 
+ It sets up the protocol (Connect, gRPC, or gRPC-Web), the message codec + (Protobuf binary or JSON), compression settings, and other operational + parameters. Args: - raw_url (str): The raw URL to connect to. - options (ClientOptions): The options for the client configuration. - - Attributes: - url (ParseResult): The parsed URL. - protocol (ProtocolConnect): The protocol used for connection. - procedure (str): The procedure path extracted from the URL. - codec (ProtoBinaryCodec): The codec used for encoding/decoding messages. - request_compression_name (str): The name of the request compression method. - compressions (list): The list of compression methods. - descriptor (Descriptor): The descriptor for the client. - idempotency_level (int): The idempotency level for requests. - compress_min_bytes (int): The minimum number of bytes to trigger compression. - read_max_bytes (int): The maximum number of bytes to read. - send_max_bytes (int): The maximum number of bytes to send. + raw_url (str): The full URL for the RPC endpoint. + options (ClientOptions): An object containing configuration options for the client. + Raises: + ConnectError: If an unknown compression algorithm is specified in the options. """ url = parse_request_url(raw_url) proto_path = url.path @@ -144,15 +136,17 @@ def __init__(self, raw_url: str, options: ClientOptions): self.enable_get = options.enable_get def spec(self, stream_type: StreamType) -> Spec: - """Generate a Spec object with the given stream type. + """Builds a specification for a given stream type. + + This method combines the procedure's general configuration (like procedure name, + descriptor, and idempotency level) with a specific stream type to create + a complete `Spec` object. Args: - stream_type (StreamType): The type of stream to be used in the Spec. + stream_type: The type of the stream for which to create the spec. 
Returns: - Spec: A Spec object initialized with the procedure, descriptor, - stream type, and idempotency level of the client. - + A `Spec` object tailored to the specified stream type. """ return Spec( procedure=self.procedure, @@ -163,21 +157,22 @@ def spec(self, stream_type: StreamType) -> Spec: class Client[T_Request, T_Response]: - """A client for making unary calls to a specified URL with given request and response types. + """A generic client for making Connect protocol RPCs. - Attributes: - config (ClientConfig): Configuration for the client. - protocol_client (ProtocolClient): The protocol client used for communication. - _call_unary (Callable[[UnaryRequest[T_Request]], Awaitable[UnaryResponse[T_Response]]]): - Internal method to handle unary calls. + This client is responsible for making unary and streaming RPCs to a Connect-compliant server. + It is initialized with a connection pool, a server URL, request and response message types, + and optional configurations. It abstracts the underlying protocol details, allowing users + to make different types of RPC calls (unary, server-stream, client-stream, bidi-stream) + through a unified interface. - Methods: - __init__(url: str, input: type[T_Request], output: type[T_Response], options: ClientOptions | None = None): - Initialize the client with the given parameters. + Type Parameters: + T_Request: The type of the request message. + T_Response: The type of the response message. - call_unary(request: UnaryRequest[T_Request]) -> UnaryResponse[T_Response]: - Make a unary call with the given request. + Attributes: + config (ClientConfig): The configuration used by the client. + protocol_client (ProtocolClient): The underlying protocol-specific client. """ config: ClientConfig @@ -194,20 +189,21 @@ def __init__( input: type[T_Request], output: type[T_Response], options: ClientOptions | None = None, - ): - """Initialize the client with the given URL, request and response types, and optional client options. 
- - Args: - pool (AsyncConnectionPool): The connection pool to use for making requests. - url (str): The URL of the server to connect to. - input (type[T_Request]): The type of the request object. - output (type[T_Response]): The type of the response object. - options (ClientOptions | None, optional): Optional client configuration options. Defaults to None. + ) -> None: + """Initializes a client for a specific RPC method. - Raises: - TypeError: If the request method is not ASCII encoded. - ConnectError: If the request or response type is incorrect. + This constructor sets up the necessary components for making RPC calls to a single method. + It configures the protocol client based on the provided options and prepares wrapped, + interceptor-aware functions for both unary and streaming calls. These internal call + functions handle request/response validation, header manipulation, and the actual + network communication. + Args: + pool: The asynchronous connection pool to use for HTTP requests. + url: The full URL of the RPC method. + input: The expected type of the request message object. + output: The expected type of the response message object. + options: Optional client configuration. """ options = options or ClientOptions() config = ClientConfig(url, options) @@ -312,15 +308,18 @@ async def call_stream( async def call_unary( self, request: UnaryRequest[T_Request], call_options: CallOptions | None = None ) -> UnaryResponse[T_Response]: - """Asynchronously calls a unary RPC (Remote Procedure Call) with the given request. + """Calls a unary RPC method. + + This method sends a single request to the server and receives a single + response. It is a simple request-response pattern. Args: - request (UnaryRequest[T_Request]): The request object containing the data to be sent to the server. - call_options (CallOptions | None, optional): Optional call options for the request. Defaults to None. + request: The unary request object containing the message to be sent. 
+ call_options: Optional configuration for the call, such as timeouts + or metadata. Returns: - UnaryResponse[T_Response]: The response object containing the data received from the server. - + An awaitable that resolves to the unary response from the server. """ return await self._call_unary(request, call_options) @@ -328,25 +327,17 @@ async def call_unary( async def call_server_stream( self, request: StreamRequest[T_Request], call_options: CallOptions | None = None ) -> AsyncGenerator[StreamResponse[T_Response]]: - """Initiate a server-streaming RPC call and returns an asynchronous generator that yields responses from the server. + """Calls a server-streaming RPC. Args: - request (StreamRequest[T_Request]): The request object containing the - data to be sent to the server. - call_options (CallOptions | None, optional): Optional call options for the request. Defaults to None. + request (StreamRequest[T_Request]): The request object for the RPC. + call_options (CallOptions | None): Optional call options for the RPC. Yields: - StreamResponse[T_Response]: The response objects received from the server. - - Raises: - Any exceptions that occur during the streaming process. - - Notes: - - This method ensures that the response stream is properly closed - after the generator is exhausted or an exception occurs. - - The type parameters `T_Request` and `T_Response` represent the - request and response types, respectively. - + AsyncGenerator[StreamResponse[T_Response]]: An asynchronous generator that yields + the response stream object. The caller is responsible for iterating over this + object to receive messages from the server. The stream is automatically closed + when the generator context is exited. 
""" response = await self._call_stream(StreamType.ServerStream, request, call_options) try: @@ -358,24 +349,26 @@ async def call_server_stream( async def call_client_stream( self, request: StreamRequest[T_Request], call_options: CallOptions | None = None ) -> AsyncGenerator[StreamResponse[T_Response]]: - """Initiate a client-streaming RPC call and returns an asynchronous generator for streaming responses from the server. - - Args: - request (StreamRequest[T_Request]): The request object containing the - client-streaming data to be sent to the server. - call_options (CallOptions | None, optional): Optional call options for the request. Defaults to None. + """Initiates a client-streaming RPC. - Yields: - StreamResponse[T_Response]: An asynchronous generator that yields - responses from the server. + In a client-streaming RPC, the client sends a sequence of messages to the + server using a provided stream. Once the client has finished writing the + messages, it waits for the server to read them and return a single response. - Raises: - Any exceptions raised during the streaming call. + This method returns an async generator that yields a single `StreamResponse` + object. The generator pattern is used to ensure that the underlying stream + is properly closed after use. You should use this method in an `async for` + loop to correctly manage the stream's lifecycle. - Notes: - - The `response.aclose()` method is called in the `finally` block to - ensure proper cleanup of the response stream. + Args: + request: The `StreamRequest` object, which includes the RPC method + and an async iterable of request messages to be sent. + call_options: Optional configuration for the call, such as timeout + or metadata. + Yields: + A single `StreamResponse` object that can be used to receive the + server's final response message. 
""" response = await self._call_stream(StreamType.ClientStream, request, call_options) try: @@ -387,27 +380,23 @@ async def call_client_stream( async def call_bidi_stream( self, request: StreamRequest[T_Request], call_options: CallOptions | None = None ) -> AsyncGenerator[StreamResponse[T_Response]]: - """Initiate a bidirectional streaming call with the server. + """Calls a bidirectional streaming method. + + This method initiates a bidirectional streaming call and returns an async generator + that yields a single `StreamResponse` object. The caller is then responsible for + iterating over the yielded `StreamResponse` to receive the response messages from + the server. - This method sends a stream request to the server and returns an asynchronous - generator that yields stream responses from the server. The connection is - automatically closed when the generator is exhausted or an exception occurs. + The stream is automatically closed when the context manager exits. Args: - request (StreamRequest[T_Request]): The stream request object containing - the data to be sent to the server. - call_options (CallOptions | None, optional): Optional call options for + request (StreamRequest[T_Request]): The request object, containing the + method details and an async iterable of request messages. + call_options (CallOptions | None): Optional call-specific configurations. Yields: - StreamResponse[T_Response]: The stream response object received from the server. - - Raises: - Any exceptions raised during the streaming call. - - Notes: - Ensure to consume the generator properly to avoid resource leaks, as the - connection is closed in the `finally` block. - + StreamResponse[T_Response]: An async iterable response object that can be + iterated over to receive messages from the server. 
""" response = await self._call_stream(StreamType.BiDiStream, request, call_options) try: diff --git a/src/connect/client_interceptor.py b/src/connect/client_interceptor.py index 70f1536..6d19f2b 100644 --- a/src/connect/client_interceptor.py +++ b/src/connect/client_interceptor.py @@ -1,4 +1,4 @@ -"""Defines interceptors and request/response classes for unary and streaming RPC calls.""" +"""Client-side interceptor utilities for modifying RPC behavior in Connect Python.""" import inspect from collections.abc import Awaitable, Callable @@ -12,24 +12,47 @@ class ClientInterceptor: - """Abstract base class for interceptors that can wrap unary functions.""" + """A client-side interceptor for modifying RPC behavior. + + Interceptors are a powerful mechanism for implementing cross-cutting concerns + like logging, metrics, authentication, and retries. They can inspect and + modify requests and responses for both unary and streaming RPCs. + + To create an interceptor, you can instantiate this class directly, providing + wrapper functions for `wrap_unary` and/or `wrap_stream`. These wrappers are + higher-order functions that take an RPC handler and return a new, wrapped + handler. The wrapped handler can then execute logic before and/or after + invoking the original RPC. + + Attributes: + wrap_unary (Callable[[UnaryFunc], UnaryFunc] | None): A function that + takes a unary RPC handler and returns a new unary RPC handler. + The returned handler is responsible for invoking the original + handler. This is used to intercept unary (request-response) RPCs. + wrap_stream (Callable[[StreamFunc], StreamFunc] | None): A function + that takes a streaming RPC handler and returns a new streaming RPC + handler. The returned handler is responsible for invoking the + original handler. This is used to intercept streaming RPCs. 
+ """ wrap_unary: Callable[[UnaryFunc], UnaryFunc] | None = None wrap_stream: Callable[[StreamFunc], StreamFunc] | None = None def is_unary_func(next: UnaryFunc | StreamFunc) -> TypeGuard[UnaryFunc]: - """Determine if the given function is a unary function. + """Type guard to determine if a function is a UnaryFunc. - A unary function is defined as a callable that takes a single parameter - whose type annotation has an origin of `UnaryRequest`. + This function inspects the signature of the provided callable `next` + to determine if it matches the expected signature of a unary RPC handler. + A function is considered a `UnaryFunc` if it is callable, accepts exactly + two parameters, and its first parameter is type-hinted as a `UnaryRequest`. Args: - next (UnaryFunc | StreamFunc): The function to be checked. + next: The function to inspect, which can be either a UnaryFunc or a StreamFunc. Returns: - TypeGuard[UnaryFunc]: True if the function is a unary function, False otherwise. - + True if the function signature matches `UnaryFunc`, False otherwise. + This allows type checkers to narrow the type of `next` to `UnaryFunc`. """ signature = inspect.signature(next) parameters = list(signature.parameters.values()) @@ -41,17 +64,16 @@ def is_unary_func(next: UnaryFunc | StreamFunc) -> TypeGuard[UnaryFunc]: def is_stream_func(next: UnaryFunc | StreamFunc) -> TypeGuard[StreamFunc]: - """Determine if the given function is a StreamFunc. + """Determines if the given function is a stream function. - This function checks if the provided function `next` is callable, has exactly one parameter, - and if the annotation of that parameter has an origin of `StreamRequest`. + A stream function is identified by being callable, having exactly two parameters, + and the first parameter's annotation having an `__origin__` attribute equal to `StreamRequest`. Args: - next (UnaryFunc | StreamFunc): The function to be checked. + next (UnaryFunc | StreamFunc): The function to check. 
Returns: - TypeGuard[StreamFunc]: True if `next` is a StreamFunc, False otherwise. - + TypeGuard[StreamFunc]: True if the function is a stream function, False otherwise. """ signature = inspect.signature(next) parameters = list(signature.parameters.values()) @@ -73,20 +95,21 @@ def apply_interceptors(next: StreamFunc, interceptors: list[ClientInterceptor] | def apply_interceptors( next: UnaryFunc | StreamFunc, interceptors: list[ClientInterceptor] | None ) -> UnaryFunc | StreamFunc: - """Apply a list of interceptors to a given function. + """Applies a list of client interceptors to a unary or stream function. + + This function wraps the provided `next` function (either unary or stream) with the corresponding + interceptor wrappers from the `interceptors` list. If an interceptor does not provide a wrapper + for the function type, the wrapping process stops at that interceptor. Args: - next (UnaryFunc | StreamFunc): The function to which interceptors will be applied. - It can be either a unary function or a stream function. - interceptors (list[Interceptor] | None): A list of interceptors to apply. If None, the original function is returned. + next (UnaryFunc | StreamFunc): The function to be wrapped by interceptors. Can be either a unary or stream function. + interceptors (list[ClientInterceptor] | None): A list of client interceptors to apply. If None, the original function is returned. Returns: - UnaryFunc | StreamFunc: The function wrapped with the provided interceptors. + UnaryFunc | StreamFunc: The wrapped function with all applicable interceptors applied. Raises: - ValueError: If an interceptor does not implement the required wrap method for the function type, - or if the provided function type is invalid. - + ValueError: If the provided function is neither a unary nor a stream function. 
""" if interceptors is None: return next diff --git a/src/connect/code.py b/src/connect/code.py index d4680a8..6b495b9 100644 --- a/src/connect/code.py +++ b/src/connect/code.py @@ -27,7 +27,6 @@ class Code(enum.IntEnum): UNAVAILABLE (int): The service is currently unavailable. DATA_LOSS (int): Unrecoverable data loss or corruption. UNAUTHENTICATED (int): The request does not have valid authentication credentials for the operation. - """ CANCELED = 1 @@ -90,7 +89,6 @@ def string(self) -> str: - "data_loss" - "unauthenticated" - "code_{self}": For any other value not explicitly matched. - """ match self: case Code.CANCELED: diff --git a/src/connect/codec.py b/src/connect/codec.py index 84d2932..4f8ae68 100644 --- a/src/connect/codec.py +++ b/src/connect/codec.py @@ -9,13 +9,16 @@ class CodecNameType: - """CodecNameType is a class that defines constants for different codec types. + """Defines the standard codec names used in the Connect Protocol. - Attributes: - PROTO (str): Represents the "proto" codec type. - JSON (str): Represents the "json" codec type. - JSON_CHARSET_UTF8 (str): Represents the "json; charset=utf-8" codec type. + These names are used in the `Content-Type` header to specify the + serialization format of the request and response bodies. + Attributes: + PROTO: The codec name for Protocol Buffers binary format ("proto"). + JSON: The codec name for JSON format ("json"). + JSON_CHARSET_UTF8: The codec name for JSON format with UTF-8 charset + ("json; charset=utf-8"). """ PROTO = "proto" @@ -24,122 +27,135 @@ class CodecNameType: class Codec(abc.ABC): - """Abstract base class for codecs. + """Defines the interface for a message codec. - This class defines the interface for codecs that can serialize and deserialize - protobuf messages. Subclasses must implement the following methods. + A Codec is responsible for serializing (marshaling) and deserializing (unmarshaling) + messages between their Python object representation and their wire format as bytes. 
+ This is an abstract base class. Subclasses must implement the `name`, `marshal`, + and `unmarshal` methods to provide a concrete implementation for a specific + serialization format, such as Protocol Buffers or JSON. """ @property @abc.abstractmethod def name(self) -> str: - """Return the name of the codec. + """Returns the name of the codec. - Returns: - str: The name of the codec. + This is an abstract method that must be implemented by subclasses. + Returns: + The name of the codec. """ raise NotImplementedError() @abc.abstractmethod def marshal(self, message: Any) -> bytes: - """Serialize a protobuf message to bytes. + """Marshal a message into bytes. Args: - message (Any): The protobuf message to serialize. + message: The message to marshal. Returns: - bytes: The serialized message as bytes. - - Raises: - ValueError: If the message is not a protobuf message. - + The marshaled message as bytes. """ raise NotImplementedError() @abc.abstractmethod def unmarshal(self, data: bytes, message: Any) -> Any: - """Unmarshals the given byte data into the specified message format. + """Unmarshals binary data into a message object. - Args: - data (bytes): The byte data to be unmarshaled. - message (Any): The message format to unmarshal the data into. + This method must be implemented by subclasses to define how to + deserialize a byte string into a given message structure. - Returns: - Any: The unmarshaled message. + Args: + data (bytes): The raw binary data to be deserialized. + message (Any): The target message object to populate with the + deserialized data. Raises: - NotImplementedError: This method should be implemented by subclasses. + NotImplementedError: This method is not implemented in the base class + and must be overridden in a subclass. + Returns: + Any: The populated message object. """ raise NotImplementedError() class StableCodec(Codec): - """StableCodec is an abstract base class that defines the interface for codecs. 
+ """Abstract base class for codecs that provide a stable byte representation. - This class can marshal messages into a stable binary format. + This class defines the interface for codecs that can serialize messages into + a canonical, stable byte format. This is useful for scenarios like signing + messages, where the byte representation must be consistent across different + systems and executions. """ @abc.abstractmethod def marshal_stable(self, message: Any) -> bytes: - """Serialize the given message into a stable byte representation. + """Marshals a message into a stable byte representation. + + "Stable" means that the marshaling is deterministic: given the same + message, the same bytes will be returned. This is important for + use cases like cryptographic signing, where the exact byte sequence + is critical. Args: - message (Any): The message to be serialized. + message: The message to be marshaled. Returns: - bytes: The serialized byte representation of the message. + The stable byte representation of the message. Raises: - NotImplementedError: This method must be implemented by subclasses. - + NotImplementedError: This is an abstract method and must be + implemented by a subclass. """ raise NotImplementedError() @abc.abstractmethod def is_binary(self) -> bool: - """Determine if the codec is binary. - - This method should be implemented by subclasses to indicate whether the codec - handles binary data. + """Checks if the codec handles binary data. Returns: bool: True if the codec is binary, False otherwise. - - Raises: - NotImplementedError: If the method is not implemented by a subclass. - """ raise NotImplementedError() class ProtoBinaryCodec(StableCodec): - """ProtoBinaryCodec is a codec for serializing and deserializing protobuf messages.""" + """Codec for handling Protocol Buffers (protobuf) messages. + + This class implements the StableCodec interface to provide serialization + and deserialization for protobuf messages. 
It converts protobuf message + objects into byte strings and vice versa. + + The `marshal_stable` method provides a deterministic serialization, but it's + important to note that protobuf's deterministic output is not guaranteed + to be consistent across different library implementations or versions, + especially when unknown fields are present. + """ @property def name(self) -> str: - """Return the name of the codec. + """Returns the name of the codec. Returns: - str: The name of the codec, which is 'PROTO'. - + The name of the codec. """ return CodecNameType.PROTO def marshal(self, message: Any) -> bytes: - """Serialize a protobuf message to a byte string. + """Serializes a protobuf message into a byte string. Args: - message (Any): The protobuf message to serialize. + message: The protobuf message to serialize. Returns: - bytes: The serialized byte string of the protobuf message. + The serialized message as a byte string. Raises: - ValueError: If the provided message is not an instance of google.protobuf.message.Message. - + ValueError: If the provided message is not a protobuf message. """ if not isinstance(message, google.protobuf.message.Message): raise ValueError("Data is not a protobuf message") @@ -147,18 +163,17 @@ def marshal(self, message: Any) -> bytes: return message.SerializeToString() def unmarshal(self, data: bytes, message: Any) -> Any: - """Unmarshals the given byte data into a protobuf message. + """Unmarshals bytes into a protobuf message. Args: - data (bytes): The byte data to be unmarshaled. - message (Any): The protobuf message class to unmarshal the data into. + data: The bytes to unmarshal. + message: The protobuf message type to unmarshal into. Returns: - Any: The unmarshaled protobuf message object. + An instance of the message class populated with the given data. Raises: - ValueError: If the provided message is not a protobuf message. - + ValueError: If the given message is not a protobuf message type. 
""" obj = message() if not isinstance(obj, google.protobuf.message.Message): @@ -168,24 +183,21 @@ def unmarshal(self, data: bytes, message: Any) -> Any: return obj def marshal_stable(self, message: Any) -> bytes: - """Serialize a given protobuf message to a deterministic byte string. + """Serializes a protobuf message into a deterministic byte string. - Protobuf does not offer a canonical output today, so this format is not - guaranteed to match deterministic output from other protobuf libraries. - In addition, unknown fields may cause inconsistent output for otherwise - equal messages. - https://github.com/golang/protobuf/issues/1121 + This method ensures that serializing the same message multiple times + will produce the exact same byte string. This is useful for applications + requiring a stable binary representation, such as cryptographic signing + or hashing. Args: - message (Any): The protobuf message to be serialized. It must be an - instance of `google.protobuf.message.Message`. + message: The protobuf message to serialize. Returns: - bytes: The serialized byte string representation of the protobuf message. + A byte string representing the deterministically serialized message. Raises: - ValueError: If the provided message is not a protobuf message. - + ValueError: If the provided data is not a protobuf message. """ if not isinstance(message, google.protobuf.message.Message): raise ValueError("Data is not a protobuf message") @@ -193,56 +205,55 @@ def marshal_stable(self, message: Any) -> bytes: return message.SerializeToString(deterministic=True) def is_binary(self) -> bool: - """Check if the codec is binary. - - Returns: - bool: Always returns True indicating the codec is binary. - - """ + """Check if the codec handles binary data.""" return True class ProtoJSONCodec(StableCodec): - """A codec for encoding and decoding Protocol Buffers messages to and from JSON format. + """A codec for serializing and deserializing protobuf messages to and from JSON. 
- Attributes: - _name (str): The name of the codec. + This class implements the StableCodec interface to handle conversions between + protobuf message objects and their JSON string representation. It leverages the + `google.protobuf.json_format` library for the core conversion logic. + + The `marshal` and `unmarshal` methods provide standard serialization and + deserialization. The `marshal_stable` method ensures a deterministic output + by re-parsing the generated JSON and re-serializing it with compact + separators. This guarantees a consistent byte representation, which is crucial + for operations like request signing. + This is a text-based codec, and as such, `is_binary()` will always return False. + + Attributes: + name (str): The name of the codec (e.g., "json"). """ _name: str def __init__(self, name: str) -> None: - """Initialize the codec with a given name. + """Initializes the codec. Args: - name (str): The name to initialize the codec with. - + name: The name of the codec. """ self._name = name @property def name(self) -> str: - """Return the name of the codec. - - Returns: - str: The name of the codec. - - """ + """The name of the codec, e.g. "proto", "json".""" return self._name def marshal(self, message: Any) -> bytes: - """Serialize a protobuf message to a JSON string encoded as UTF-8 bytes. + """Marshals a protobuf message to its JSON representation. Args: - message (Any): The protobuf message to be serialized. Must be an instance of google.protobuf.message.Message. + message: The protobuf message to marshal. Returns: - bytes: The serialized JSON string encoded as UTF-8 bytes. + The JSON representation of the message, encoded as bytes. Raises: ValueError: If the provided message is not a protobuf message. 
- """ if not isinstance(message, google.protobuf.message.Message): raise ValueError("Data is not a protobuf message") @@ -252,18 +263,22 @@ def marshal(self, message: Any) -> bytes: return json_str.encode() def unmarshal(self, data: bytes, message: Any) -> Any: - """Unmarshal the given byte data into a protobuf message. + """Unmarshals byte data into a protobuf message. + + This method decodes the byte data as a UTF-8 string and then parses it + as JSON into the provided protobuf message type. Args: - data (bytes): The byte data to unmarshal. - message (Any): The protobuf message class to unmarshal the data into. + data: The byte-encoded JSON data to unmarshal. + message: The protobuf message class to instantiate and populate. Returns: - Any: The unmarshaled protobuf message instance. + An instance of the provided protobuf message class, populated with + the data from the JSON. Raises: - ValueError: If the provided message is not a protobuf message. - + ValueError: If the `message` argument is not a protobuf message class. + google.protobuf.json_format.ParseError: If the data is not valid JSON. """ obj = message() if not isinstance(obj, google.protobuf.message.Message): @@ -272,23 +287,24 @@ def unmarshal(self, data: bytes, message: Any) -> Any: return json_format.Parse(data.decode(), obj, ignore_unknown_fields=True) def marshal_stable(self, message: Any) -> bytes: - """Serialize a protobuf message to a JSON string encoded as UTF-8 bytes in a deterministic way. + """Marshals a protobuf message into a stable, compact JSON byte string. - protojson does not offer a "deterministic" field ordering, but fields - are still ordered consistently by their index. However, protojson can - output inconsistent whitespace for some reason, therefore it is - suggested to use a formatter to ensure consistent formatting. - https://github.com/golang/protobuf/issues/1373 + This method provides a deterministic way to serialize a protobuf message + to a compact JSON format. 
It converts the message to a JSON string, + then re-serializes it to remove whitespace and ensure a consistent + output. The resulting byte string is stable, meaning the same message + will always produce the exact same byte output. Args: - message (Any): The protobuf message to be serialized. Must be an instance of google.protobuf.message.Message. + message (Any): The protobuf message to be marshaled. Although typed as + Any, this must be an instance of `google.protobuf.message.Message`. Returns: - bytes: The serialized JSON string encoded as UTF-8 bytes. + bytes: A compact, stable JSON representation of the message, encoded + as a byte string. Raises: - ValueError: If the provided message is not a protobuf message. - + ValueError: If the input `message` is not a protobuf message. """ if not isinstance(message, google.protobuf.message.Message): raise ValueError("Data is not a protobuf message") @@ -301,104 +317,110 @@ def marshal_stable(self, message: Any) -> bytes: return compacted_json.encode() def is_binary(self) -> bool: - """Determine if the codec is binary. + """Check if the codec handles binary data. Returns: - bool: Always returns False, indicating the codec is not binary. - + bool: Always False, indicating this codec does not handle binary data. """ return False class ReadOnlyCodecs(abc.ABC): - """Abstract base class for read-only codecs. + """Defines the interface for a read-only collection of codecs. - This class defines the interface for read-only codecs, which are responsible for - encoding and decoding data. Implementations of this class must provide concrete - implementations for the following methods. + This abstract base class provides a standard way to retrieve registered codecs + by name and to list the names of all available codecs. It is designed to be + a non-modifiable view of the available encoding and decoding mechanisms. + Subclasses must implement the `get`, `protobuf`, and `names` methods to provide + concrete functionality. 
""" @abc.abstractmethod def get(self, name: str) -> Codec | None: - """Retrieve a codec by its name. + """Gets a codec by its name. Args: - name (str): The name of the codec to retrieve. + name: The name of the codec to retrieve. Returns: - Codec: The codec associated with the given name. - + The codec instance if found, otherwise None. """ raise NotImplementedError() @abc.abstractmethod def protobuf(self) -> Codec | None: - """Encode data using the Protocol Buffers (protobuf) codec. + """Returns the Protobuf codec, if available. - Returns: - Codec: An instance of the Codec class configured for protobuf encoding. + This method should be implemented by subclasses to provide a codec + for handling Protobuf-encoded messages. + Raises: + NotImplementedError: If the method is not implemented by a subclass. + + Returns: + Codec | None: A Codec instance for Protobuf, or None if not supported. """ raise NotImplementedError() @abc.abstractmethod def names(self) -> list[str]: - """Return a list of names. + """Get the names of the supported codecs. - Returns: - list[str]: A list of names as strings. + This is an abstract method. Subclasses should implement this to return the + names of the codecs they support (e.g., "gzip", "identity"). + Returns: + A list of strings, where each string is a supported codec name. """ raise NotImplementedError() class CodecMap(ReadOnlyCodecs): - """CodecMap is a class that provides a mapping from codec names to their corresponding Codec objects. It extends the ReadOnlyCodecs class.""" + """Manages a collection of codecs, mapping their names to Codec instances. + + This class provides a way to store and retrieve different encoding/decoding + mechanisms (codecs) used in communication protocols. It allows looking up a + specific codec by its registered name. + + Attributes: + name_to_codec (dict[str, Codec]): A dictionary where keys are codec names + and values are the corresponding Codec objects. 
+ """ name_to_codec: dict[str, Codec] def __init__(self, name_to_codec: dict[str, Codec]) -> None: - """Initialize the codec mapping. + """Initializes the codec registry. Args: - name_to_codec (dict[str, Codec]): A dictionary mapping codec names to their corresponding Codec objects. - + name_to_codec: A dictionary mapping codec names to Codec instances. """ self.name_to_codec = name_to_codec def get(self, name: str) -> Codec | None: - """Retrieve a codec by its name. + """Gets a codec by its registered name. Args: - name (str): The name of the codec to retrieve. + name: The name of the codec to retrieve. Returns: - Codec: The codec associated with the given name. - - Raises: - KeyError: If the codec with the specified name does not exist. - + The codec instance if found, otherwise None. """ return self.name_to_codec.get(name) def protobuf(self) -> Codec | None: - """Return the Codec instance associated with Protocol Buffers (PROTO). - - This method retrieves the Codec instance that corresponds to the - Protocol Buffers (PROTO) codec type. + """A convenience method for retrieving the protobuf codec. Returns: - Codec: The Codec instance for Protocol Buffers. - + The protobuf codec, or None if it is not available. """ return self.get(CodecNameType.PROTO) def names(self) -> list[str]: - """Retrieve a list of codec names. + """Get the names of all registered codecs. Returns: - list[str]: A list of codec names. - + A list of the registered codec names. """ return list(self.name_to_codec.keys()) diff --git a/src/connect/compression.py b/src/connect/compression.py index 4856cb9..5cca446 100644 --- a/src/connect/compression.py +++ b/src/connect/compression.py @@ -9,80 +9,94 @@ class Compression(abc.ABC): - """Abstract base class for compression algorithms. - - This class defines the interface for compression and decompression methods - that must be implemented by any concrete compression class. + """Abstract base class for defining compression and decompression logic. 
+ This class provides a standard interface for different compression algorithms + used in Connect. Subclasses are expected to implement the `name` property, + and the `compress` and `decompress` methods. """ @property @abc.abstractmethod def name(self) -> str: - """Return the name of the compression algorithm. + """Gets the name of the compression algorithm. - Returns: - str: The name of the compression algorithm. + This is an abstract method that must be implemented by subclasses. + Raises: + NotImplementedError: This method is not implemented in the base class. + + Returns: + The name of the compression algorithm. """ raise NotImplementedError() @abc.abstractmethod def compress(self, data: bytes) -> bytes: - """Compress the given data using a specified compression algorithm. + """Compresses the given data. + + This is an abstract method that must be implemented by a subclass. Args: - data (bytes): The data to be compressed. + data: The bytes to be compressed. Returns: - bytes: The compressed data. - + The compressed data as bytes. """ raise NotImplementedError() @abc.abstractmethod def decompress(self, data: bytes, read_max_bytes: int) -> bytes: - """Decompress the given data. + """Decompresses the given data. Args: - data (bytes): The compressed data to be decompressed. - read_max_bytes (int): The maximum number of bytes to read from the decompressed data. + data: The compressed byte string. + read_max_bytes: The maximum number of bytes to read from the + decompressed data. This is a safeguard against decompression + bombs. Returns: - bytes: The decompressed data. + The decompressed byte string. + Raises: + NotImplementedError: This method must be implemented by a subclass. """ raise NotImplementedError() class GZipCompression(Compression): - """A class to handle GZip compression and decompression.""" + """Handles data compression and decompression using the GZip algorithm. 
+ + This class implements the `Compression` interface, providing methods to compress + and decompress byte data using the standard GZip format. + + Attributes: + name (str): The identifier for this compression method, 'gzip'. + """ _name: str def __init__(self) -> None: - """Initialize the compression object with the default compression method.""" + """Initializes the compression algorithm with the GZIP name.""" self._name = COMPRESSION_GZIP @property def name(self) -> str: - """Return the name attribute of the object. + """The name of the compression algorithm. Returns: - str: The name attribute. - + The name of the compression algorithm. """ return self._name def compress(self, data: bytes) -> bytes: - """Compress the given data using gzip compression. + """Compresses data using gzip. Args: - data (bytes): The data to be compressed. + data: The bytes to be compressed. Returns: - bytes: The compressed data. - + The gzip-compressed data as bytes. """ buf = io.BytesIO() with gzip.GzipFile(fileobj=buf, mode="wb") as f: @@ -91,16 +105,16 @@ def compress(self, data: bytes) -> bytes: return buf.getvalue() def decompress(self, data: bytes, read_max_bytes: int) -> bytes: - """Decompress the given gzip-compressed data. + """Decompresses a gzip-compressed byte string. Args: - data (bytes): The gzip-compressed data to decompress. - read_max_bytes (int): The maximum number of bytes to read from the decompressed data. - If read_max_bytes is less than or equal to 0, all decompressed data will be read. + data: The compressed byte string to decompress. + read_max_bytes: The maximum number of bytes to read from the + decompressed stream. If this value is zero or negative, the + entire stream is read. Returns: - bytes: The decompressed data. - + The decompressed data as a byte string. 
""" read_max_bytes = read_max_bytes if read_max_bytes > 0 else -1 @@ -112,15 +126,14 @@ def decompress(self, data: bytes, read_max_bytes: int) -> bytes: def get_compression_from_name(name: str | None, compressions: list[Compression]) -> Compression | None: - """Retrieve a Compression object from a list of compressions by its name. + """Finds a compression algorithm by its name from a list of available compressions. Args: - name (str): The name of the compression to retrieve. - compressions (list[Compression]): A list of Compression objects to search through. + name: The name of the compression algorithm to search for. + compressions: A list of available `Compression` objects. Returns: - Compression | None: The Compression object with the matching name, or None if not found. - + The matching `Compression` object if found, otherwise `None`. """ compression = ( next( diff --git a/src/connect/connect.py b/src/connect/connect.py index bfb35b2..ec6b6c3 100644 --- a/src/connect/connect.py +++ b/src/connect/connect.py @@ -1,4 +1,4 @@ -"""Defines the streaming handler connection interfaces and related utilities.""" +"""Core components and abstractions for the Connect protocol.""" import abc import asyncio @@ -18,7 +18,21 @@ class StreamType(Enum): - """Enum for the type of stream.""" + """Enumeration of the different types of RPC streams. + + This enum categorizes the communication patterns between a client and a server + for a specific RPC method, mirroring the concepts found in frameworks like gRPC. + + Attributes: + Unary: A simple RPC where the client sends a single request and receives a + single response. + ClientStream: An RPC where the client sends a stream of messages and the + server sends back a single response. + ServerStream: An RPC where the client sends a single request and receives a + stream of messages in response. + BiDiStream: An RPC where both the client and the server send a stream of + messages to each other. 
+ """ Unary = "Unary" ClientStream = "ClientStream" @@ -27,7 +41,14 @@ class StreamType(Enum): class Spec(BaseModel): - """Spec class.""" + """Defines the specification for a remote procedure call. + + Args: + procedure: The fully qualified name of the procedure to be called. + descriptor: The descriptor for the procedure, which may contain schema or other metadata. + stream_type: The streaming behavior of the procedure. + idempotency_level: The idempotency level, which determines how the procedure handles retries. + """ procedure: str descriptor: Any @@ -36,14 +57,25 @@ class Spec(BaseModel): class Address(BaseModel): - """Address class.""" + """Represents a network address, consisting of a host and a port. + + Attributes: + host (str): The hostname or IP address. + port (int): The port number. + """ host: str port: int class Peer(BaseModel): - """Peer class.""" + """Represents a peer in the network. + + Attributes: + address: The network address of the peer. + protocol: The communication protocol used by the peer (e.g., 'http', 'ws'). + query: A mapping of query parameters for the peer connection. + """ address: Address | None protocol: str @@ -51,22 +83,18 @@ class Peer(BaseModel): class RequestCommon: - """A common base class for handling request-related functionality. + """Represents the common context for a Connect RPC request or response. - This class encapsulates the common properties and behaviors shared across - different types of requests, including specification details, peer information, - headers, and HTTP method configuration. + This class encapsulates information that is shared between requests and responses + in the Connect protocol, such as the RPC specification, peer details, HTTP + headers, and the HTTP method used. Attributes: - _spec (Spec): The specification for the request containing procedure details, - descriptor, stream type, and idempotency level. - _peer (Peer): The peer information including address, protocol, and query parameters. 
- _headers (Headers): The request headers as a collection of key-value pairs. - _method (str): The HTTP method used for the request (defaults to POST). - - The class provides property accessors for all attributes with appropriate getters - and setters where modification is allowed. Default values are provided for all - parameters during initialization to ensure the object is always in a valid state. + spec (Spec): The RPC specification, including procedure name, stream type, + and idempotency level. + peer (Peer): Information about the network peer, such as address and protocol. + headers (Headers): The HTTP headers associated with the request or response. + method (str): The HTTP method used for the request (e.g., 'POST'). """ _spec: Spec @@ -81,19 +109,17 @@ def __init__( headers: Headers | None = None, method: str | None = None, ) -> None: - """Initialize a Connect request/response context. + """Initializes the RPC context. Args: - spec: The RPC specification containing procedure name, descriptor, stream type, - and idempotency level. If None, creates a default Spec with empty procedure, - no descriptor, unary stream type, and idempotent level. - peer: The peer information including address, protocol, and query parameters. - If None, creates a default Peer with no address, empty protocol, and empty query. - headers: HTTP headers for the request/response. If None, creates an empty Headers object. - method: HTTP method to use for the request. If None, defaults to POST. - - Returns: - None + spec (Spec | None, optional): The specification for the RPC. + If None, a default Spec is created. Defaults to None. + peer (Peer | None, optional): Information about the network peer. + If None, a default Peer is created. Defaults to None. + headers (Headers | None, optional): The request headers. + If None, an empty Headers object is created. Defaults to None. + method (str | None, optional): The HTTP method of the request. + Defaults to POST. 
""" self._spec = ( spec @@ -111,46 +137,87 @@ def __init__( @property def spec(self) -> Spec: - """Return the request specification.""" + """Gets the service specification. + + Returns: + Spec: The service specification object. + """ return self._spec @spec.setter def spec(self, value: Spec) -> None: - """Set the request specification.""" + """Sets the specification for the Connect instance. + + Args: + value: The specification object. + """ self._spec = value @property def peer(self) -> Peer: - """Return the request peer.""" + """Gets the peer object for this connection. + + Returns: + Peer: The peer object. + """ return self._peer @peer.setter def peer(self, value: Peer) -> None: - """Set the request peer.""" + """Sets the peer for the connection. + + Args: + value: The Peer instance to set. + """ self._peer = value @property def headers(self) -> Headers: - """Return the request headers.""" + """Gets the headers for the message. + + Returns: + The headers for the message. + """ return self._headers @property def method(self) -> str: - """Return the request method.""" + """Gets the method name. + + Returns: + str: The name of the method. + """ return self._method @method.setter def method(self, value: str) -> None: - """Set the request method.""" + """Sets the HTTP method for the request. + + Args: + value: The HTTP method (e.g., "GET", "POST"). + """ self._method = value class StreamRequest[T](RequestCommon): - """StreamRequest class represents a request that can handle streaming messages. + """Represents a streaming request, containing an asynchronous iterable of messages. + + This class is used for RPCs where the client sends a stream of messages, + such as client streaming or bidirectional streaming calls. It provides + access to the messages as an async iterable and helper methods to consume them. + + Type Parameters: + T: The type of the messages in the stream. 
- This class provides a unified interface for handling both single and multiple - messages in streaming requests. It automatically determines the appropriate - method based on the stream type and usage context. + content: The content to be processed, which can be a single item of type T + or an async iterable of items. + + Attributes: + messages (AsyncIterable[T]): The request messages as an async iterable. + spec (Spec | None): The specification for the RPC. + peer (Peer | None): Information about the remote peer. + headers (Headers | None): The request headers. + method (str | None): The HTTP method used for the request. """ _messages: AsyncIterable[T] @@ -163,53 +230,67 @@ def __init__( headers: Headers | None = None, method: str | None = None, ) -> None: - """Initialize a new instance. + """Initializes a new request. Args: - content: The content to be processed, either a single item of type T or an async iterable of items. - spec: Optional specification object defining the behavior or configuration. - peer: Optional peer object representing the connection endpoint. - headers: Optional headers dictionary for metadata or configuration. - method: Optional string specifying the method or operation type. - - Returns: - None + content: The main content of the request. Can be a single message + or an asynchronous iterable of messages. + spec: The specification for the request. Defaults to None. + peer: The peer that initiated the request. Defaults to None. + headers: The headers associated with the request. Defaults to None. + method: The method name for the request. Defaults to None. """ super().__init__(spec, peer, headers, method) self._messages = content if isinstance(content, AsyncIterable) else aiterate([content]) @property def messages(self) -> AsyncIterable[T]: - """Return the request messages as an async iterable. + """An asynchronous iterator over received messages. - Use this when you expect multiple messages (client streaming, bidi streaming). 
+        This allows you to iterate through messages from the client as they arrive
+        using an ``async for`` loop.
 
-        Example:
-            async for message in request.messages:
-                process(message)
+        Yields:
+            The next available message from the connection.
         """
         return self._messages
 
     async def single(self) -> T:
-        """Return a single message from the request.
+        """Asynchronously waits for and returns the single expected message.
 
-        Use this when you expect exactly one message (server-side handlers for client streaming).
-        Raises ConnectError if there are zero or multiple messages.
+        This method is used when exactly one message is expected from the
+        underlying asynchronous message source.
 
-        Example:
-            message = await request.single()
-            process(message)
+        Returns:
+            T: The one and only message received.
+
+        Raises:
+            ConnectError: If the number of messages received is not equal to one
+                (i.e., zero or more than one).
         """
         return await ensure_single(self._messages)
 
 
 class UnaryRequest[T](RequestCommon):
-    """A unary request wrapper that extends RequestCommon functionality.
+    """Represents a unary (non-streaming) request.
+
+    This class encapsulates a single request message along with its associated
+    metadata, such as headers and peer information. It is used for interactions
+    where a single request message is sent and a single response is expected.
+
+    Type Parameters:
+        T: The type of the request message/content.
 
-    This class encapsulates a single message/content of type T along with common request
-    metadata such as specifications, peer information, headers, and HTTP method.
+    Attributes:
+        message (T): The request message or payload.
+        spec (Spec | None): Specification object defining behavior or configuration.
+        peer (Peer | None): Peer object representing the remote endpoint.
+        headers (Headers | None): Metadata associated with the request.
+        method (str | None): The RPC method being called.
""" + _message: T + def __init__( self, content: T, @@ -218,34 +299,37 @@ def __init__( headers: Headers | None = None, method: str | None = None, ) -> None: - """Initialize a new instance with content and optional parameters. + """Initializes the request object. Args: - content (T): The main content/message to be stored in this instance. - spec (Spec | None, optional): Specification object defining behavior or configuration. Defaults to None. - peer (Peer | None, optional): Peer object representing the remote endpoint or connection. Defaults to None. - headers (Headers | None, optional): HTTP headers or metadata associated with the request/response. Defaults to None. - method (str | None, optional): HTTP method or operation type (e.g., 'GET', 'POST'). Defaults to None. - - Returns: - None + content (T): The content of the message. + spec (Spec | None, optional): The request specification. Defaults to None. + peer (Peer | None, optional): Information about the peer. Defaults to None. + headers (Headers | None, optional): The request headers. Defaults to None. + method (str | None, optional): The request method. Defaults to None. """ super().__init__(spec, peer, headers, method) self._message = content @property def message(self) -> T: - """Return the request message.""" + """Get the underlying message. + + Returns: + The message object. + """ return self._message class ResponseCommon: - """ResponseCommon is a class that encapsulates common response attributes such as headers and trailers. + """A base class representing common properties for all Connect response types. - Attributes: - _headers (Headers): The headers of the response. - _trailers (Headers): The trailers of the response. + This class encapsulates the headers and trailers that are common to both + unary and streaming responses. + Attributes: + headers (Headers): The response headers. + trailers (Headers): The response trailers. 
""" _headers: Headers @@ -256,23 +340,53 @@ def __init__( headers: Headers | None = None, trailers: Headers | None = None, ) -> None: - """Initialize the response with a message.""" + """Initializes the instance. + + Args: + headers: Optional initial headers. + trailers: Optional initial trailers. + """ self._headers = headers if headers is not None else Headers() self._trailers = trailers if trailers is not None else Headers() @property def headers(self) -> Headers: - """Return the response headers.""" + """Returns the headers for the request. + + Returns: + Headers: The headers for the request. + """ return self._headers @property def trailers(self) -> Headers: - """Return the response trailers.""" + """Returns the trailers of the response. + + Trailers are headers sent after the message body. They are only available + after the entire response body has been read. + + Returns: + Headers: The trailers. An empty Headers object if no trailers were sent. + """ return self._trailers class UnaryResponse[T](ResponseCommon): - """Response class for handling responses.""" + """Represents a unary response from a Connect RPC. + + This class encapsulates a single response message, along with its + associated headers and trailers. + + Args: + content (T): The deserialized response message. + headers (Headers | None): The response headers. + trailers (Headers | None): The response trailers. + + Attributes: + message (T): The deserialized response message. + headers (Headers): The response headers. + trailers (Headers): The response trailers. + """ _message: T @@ -282,22 +396,44 @@ def __init__( headers: Headers | None = None, trailers: Headers | None = None, ) -> None: - """Initialize the response with a message.""" + """Initializes the message object. + + Args: + content: The message content. + headers: Optional initial headers. + trailers: Optional initial trailers. 
+        """
         super().__init__(headers, trailers)
         self._message = content
 
     @property
     def message(self) -> T:
-        """Return the response message."""
+        """Returns the message associated with the response.
+
+        Returns:
+            The message of type T.
+        """
         return self._message
 
 
 class StreamResponse[T](ResponseCommon):
-    """Response class for handling streaming responses.
+    """Represents a streaming response from a Connect RPC.
+
+    This class encapsulates the response headers, trailers, and the asynchronous
+    stream of response messages. It is used for server-streaming and
+    bidirectional-streaming RPCs where the server sends multiple messages over time.
 
-    This class provides a unified interface for handling both single and multiple
-    messages from streaming responses. It automatically determines the appropriate
-    method based on the stream type and usage context.
+    The primary way to interact with a `StreamResponse` is to iterate over its
+    `messages` property to consume the stream of incoming data. For RPCs that are
+    expected to return exactly one message in the stream (like client-streaming),
+    the `single()` method can be used for convenience.
+
+    Type Parameters:
+        T: The type of the messages in the response stream.
+
+    Attributes:
+        headers (Headers | None): The response headers.
+        trailers (Headers | None): The response trailers.
     """
 
     _messages: AsyncIterable[T]
 
@@ -308,37 +444,39 @@ def __init__(
         self,
         content: T | AsyncIterable[T],
         headers: Headers | None = None,
         trailers: Headers | None = None,
     ) -> None:
-        """Initialize the response with content.
+        """Initializes the response.
 
         Args:
-            content: Either a single message or an async iterable of messages
-            headers: Optional response headers
-            trailers: Optional response trailers
+            content: The content of the response. Can be a single message or an async iterable of messages.
+            headers: The headers of the response.
+            trailers: The trailers of the response.
         """
         super().__init__(headers, trailers)
         self._messages = content if isinstance(content, AsyncIterable) else aiterate([content])
 
     @property
     def messages(self) -> AsyncIterable[T]:
-        """Return the response messages as an async iterable.
+        """An asynchronous iterator over the messages received from the server.
 
-        Use this when you expect multiple messages (server streaming, bidi streaming).
+        This method provides a way to consume messages from the server as they
+        arrive. It is intended to be used with an `async for` loop.
 
-        Example:
-            async for message in response.messages:
-                print(message)
+        Yields:
+            T: The next message received from the server.
         """
         return self._messages
 
     async def single(self) -> T:
-        """Return a single message from the response.
+        """Asynchronously gets the single message from the stream.
 
-        Use this when you expect exactly one message (client streaming results).
-        Raises ConnectError if there are zero or multiple messages.
+        This method consumes the underlying asynchronous message stream and
+        ensures that it contains exactly one message.
 
-        Example:
-            message = await response.single()
-            print(message)
+        Returns:
+            The single message from the stream.
+
+        Raises:
+            ConnectError: If the stream is empty or contains more than one message.
         """
         return await ensure_single(self._messages)
 
@@ -350,24 +488,21 @@ async def aclose(self) -> None:
 
 
 async def ensure_single[T](iterable: AsyncIterable[T], aclose: Callable[[], Awaitable[None]] | None = None) -> T:
-    """Asynchronously ensures that the given async iterable yields exactly one item.
+    """Ensures an async iterable yields exactly one item and returns it.
 
-    Iterates over the provided async iterable (after validating its content stream)
-    and returns the single item if present. Raises a ConnectError if the iterable
-    is empty or contains more than one item. Optionally closes the iterable by calling
-    the provided aclose function after processing.
+ This is a helper function for handling unary responses in a streaming context. + It consumes the iterable to verify its cardinality. Args: - iterable (AsyncIterable[T]): An asynchronous iterable expected to yield exactly one item. - aclose (Callable[[], Awaitable[None]] | None, optional): A callable that asynchronously - closes the stream when invoked. If provided, will be called in a finally block. + iterable: The asynchronous iterable to consume. + aclose: An optional awaitable callable to be executed for cleanup + in a finally block. Returns: - T: The single item yielded by the iterable. + The single item from the iterable. Raises: - ConnectError: If the iterable yields no items or more than one item. - + ConnectError: If the iterable contains zero or more than one item. """ try: iterator = iterable.__aiter__() @@ -386,245 +521,492 @@ async def ensure_single[T](iterable: AsyncIterable[T], aclose: Callable[[], Awai class StreamingHandlerConn(abc.ABC): - """Abstract base class for a streaming handler connection. + """Abstract base class for handling streaming connections. - This class defines the interface for handling streaming connections, including - methods for specifying the connection, handling peer communication, receiving - and sending messages, and managing request and response headers and trailers. + This class defines the interface for a streaming handler, which is responsible + for managing the lifecycle of a streaming request and response. It includes + methods for sending and receiving data streams, accessing request and response + metadata (headers and trailers), and handling errors. + Concrete implementations must provide logic for all abstract methods and + properties defined in this class to facilitate a specific communication + protocol or transport layer. """ @abc.abstractmethod def parse_timeout(self) -> float | None: - """Parse the timeout value.""" + """Abstract method to parse the timeout from the configuration. 
+ + Subclasses must implement this method to extract the timeout value + from their specific configuration source. + + Raises: + NotImplementedError: If the method is not implemented by a subclass. + + Returns: + The request timeout in seconds as a float, or None if no timeout + is configured. + """ raise NotImplementedError() @property @abc.abstractmethod def spec(self) -> Spec: - """Return the specification details. + """Returns the specification of the connector. - Returns: - Spec: The specification details. + This is an abstract method that must be implemented by subclasses. + It should return a `Spec` object that defines the connector's + metadata, capabilities, and configuration schema. + Returns: + Spec: An object containing the connector's specification. """ raise NotImplementedError() @property @abc.abstractmethod def peer(self) -> Peer: - """Establish a connection to a peer in the network. + """Gets the peer of the connection. - Returns: - Any: The result of the connection attempt. The exact type and structure - of the return value will depend on the implementation details. + This is an abstract method that must be implemented by subclasses. + Raises: + NotImplementedError: This method is not implemented. + + Returns: + Peer: An object representing the connected peer. """ raise NotImplementedError() @abc.abstractmethod def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message and returns an asynchronous content stream. + """Asynchronously receive messages. + + This method is intended to be implemented by subclasses to handle the + reception of a stream of messages. It should be an asynchronous generator + that yields messages as they are received. Args: - message (Any): The message to be processed. + message: The initial message or subscription request that triggers + the stream of incoming messages. - Returns: - AsyncContentStream[Any]: An asynchronous stream of content resulting from processing the message. 
+ Yields: + Messages of any type received from the source. Raises: - NotImplementedError: This method should be implemented by subclasses. - + NotImplementedError: This base method is not implemented and must + be overridden in a subclass. """ raise NotImplementedError() @property @abc.abstractmethod def request_headers(self) -> Headers: - """Generate and return the request headers. + """Abstract method to get request headers. - Returns: - Headers: The request headers. + Subclasses must implement this method to provide the necessary headers + for making API requests. This typically includes headers for + authentication, content type, etc. + Raises: + NotImplementedError: If the method is not overridden in a subclass. + + Returns: + Headers: A dictionary-like object representing the HTTP headers. """ raise NotImplementedError() @abc.abstractmethod async def send(self, messages: AsyncIterable[Any]) -> None: - """Send a stream of messages asynchronously. + """Asynchronously sends a stream of messages. + + This method takes an asynchronous iterable of messages and sends them + over the connection. Args: - messages (AsyncIterable[Any]): The messages to be sent. - For unary operations, this should be an iterable with a single item. + messages: An asynchronous iterable yielding messages to be sent. Raises: - NotImplementedError: This method should be implemented by subclasses. - + NotImplementedError: This method is not implemented. """ raise NotImplementedError() @property @abc.abstractmethod def response_headers(self) -> Headers: - """Retrieve the response headers. + """Gets the response headers. - Returns: - Headers: The response headers. + This is an abstract method that must be implemented by subclasses. + + Raises: + NotImplementedError: If the method is not implemented by a subclass. + Returns: + Headers: A dictionary-like object containing the response headers. 
""" raise NotImplementedError() @property @abc.abstractmethod def response_trailers(self) -> Headers: - """Handle response trailers. + """Returns the response trailers. - This method is intended to be overridden in subclasses to provide - specific functionality for processing response trailers. + This method is called after the response body has been fully read. + It provides access to any trailing headers sent by the server. - Returns: - Headers: The response trailers. + Raises: + NotImplementedError: This method is not implemented in the base class + and must be overridden in a subclass. + Returns: + Headers: A Headers object containing the trailing headers of the response. """ raise NotImplementedError() @abc.abstractmethod async def send_error(self, error: ConnectError) -> None: - """Send an error message. + """Sends a ConnectError to the client. - This method should be implemented to handle the process of sending an error message - when a ConnectError occurs. + This is an abstract method that must be implemented by a subclass. + It is responsible for serializing the error and sending it over the + transport layer. Args: - error (ConnectError): The error that needs to be sent. + error: The ConnectError instance to send. Raises: - NotImplementedError: This method is not yet implemented. - + NotImplementedError: This method must be overridden by a subclass. """ raise NotImplementedError() class UnaryClientConn(abc.ABC): - """Abstract base class for a unary client connection.""" + """Abstract base class defining the interface for a unary client connection. + + This class outlines the contract for managing a single request-response + interaction with a server. Implementations of this class are responsible for + handling the specifics of the communication protocol. + + Attributes: + spec (Spec): The specification details for the RPC call. + peer (Peer): Information about the remote peer (server). + request_headers (Headers): The headers for the outgoing request. 
+ response_headers (Headers): The headers from the server's response. + response_trailers (Headers): The trailers from the server's response. + """ @property @abc.abstractmethod def spec(self) -> Spec: - """Return the specification details.""" + """Returns the service specification. + + This is an abstract method that must be implemented by subclasses. + + Returns: + Spec: The specification for the service. + """ raise NotImplementedError() @property @abc.abstractmethod def peer(self) -> Peer: - """Return the peer information.""" + """Returns the peer of the connection. + + This is an abstract method that must be implemented by subclasses. + + Returns: + Peer: The `Peer` instance representing the other side of the connection. + """ raise NotImplementedError() @abc.abstractmethod def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message and processes it.""" + """Receives a stream of messages in response to an initial message. + + This method is an asynchronous generator that sends an initial message + and then yields incoming messages as they are received. + + Args: + message: The initial message to send to initiate the stream. + + Yields: + An asynchronous iterator that provides messages as they are received. + """ raise NotImplementedError() @property @abc.abstractmethod def request_headers(self) -> Headers: - """Return the request headers.""" + """Get the headers for an API request. + + This is an abstract method that must be implemented by subclasses. + It is responsible for constructing and returning the headers required + for making requests, which may include authentication tokens. + + Raises: + NotImplementedError: This method must be overridden in a subclass. + + Returns: + Headers: A dictionary-like object containing the request headers. 
+ """ raise NotImplementedError() @abc.abstractmethod async def send(self, message: Any, timeout: float | None, abort_event: asyncio.Event | None) -> bytes: - """Send a message.""" + """Sends a message and waits for a response. + + This is an abstract method that must be implemented by a subclass. + + Args: + message: The message payload to send. + timeout: The maximum time in seconds to wait for a response. + If None, the call will wait indefinitely. + abort_event: An optional asyncio event that can be set to + prematurely abort the send operation. + + Returns: + The raw response received as bytes. + + Raises: + NotImplementedError: This is an abstract method. + asyncio.TimeoutError: If the timeout is reached before a response is received. + """ raise NotImplementedError() @property @abc.abstractmethod def response_headers(self) -> Headers: - """Return the response headers.""" + """Get the response headers. + + This is an abstract method that must be implemented by subclasses. + + Returns: + An object representing the response headers. + + Raises: + NotImplementedError: This method is not implemented in the base class. + """ raise NotImplementedError() @property @abc.abstractmethod def response_trailers(self) -> Headers: - """Return response trailers.""" + """Returns the response trailers. + + This method is called after the response body has been fully read. + It will not be called if the server does not send trailers. + + Raises: + NotImplementedError: This method is not implemented. + + Returns: + Headers: The response trailers. + """ raise NotImplementedError() @abc.abstractmethod def on_request_send(self, fn: Callable[..., Any]) -> None: - """Handle the request send event.""" + """Registers a callback function to be executed before a request is sent. + + This method is intended to be used as a decorator. The decorated function + will be called with the request details, allowing for inspection or + modification of the request just before it is dispatched. 
+ + Args: + fn: The callback function to execute when a request is about to be sent. + The arguments passed to this function will depend on the specific + implementation. + + Raises: + NotImplementedError: This method is not yet implemented and must be + overridden in a subclass. + """ raise NotImplementedError() @abc.abstractmethod async def aclose(self) -> None: - """Asynchronously close the connection.""" + """Asynchronously close the connection. + + Raises: + NotImplementedError: This method is not yet implemented. + """ raise NotImplementedError() class StreamingClientConn(abc.ABC): - """Abstract base class for a streaming client connection.""" + """Abstract base class defining the interface for a streaming client connection. + + This class outlines the contract that all concrete streaming client connection + implementations must adhere to. It provides a standardized way to handle + bidirectional streaming communication, including sending and receiving data streams, + accessing headers and trailers, and managing the connection lifecycle. + + Attributes: + spec (Spec): The specification details for the connection. + peer (Peer): Information about the connected peer. + request_headers (Headers): The headers for the outgoing request. + response_headers (Headers): The headers from the incoming response. + response_trailers (Headers): The trailers from the incoming response. + """ @property @abc.abstractmethod def spec(self) -> Spec: - """Return the specification details.""" + """Returns the component specification. + + This is an abstract method that must be implemented by subclasses. + + Returns: + Spec: The component specification object. + """ raise NotImplementedError() @property @abc.abstractmethod def peer(self) -> Peer: - """Return the peer information.""" + """Gets the peer for this connection. + + A peer represents the remote endpoint of the connection. + + Returns: + Peer: An object representing the connected peer. 
+ """ raise NotImplementedError() @abc.abstractmethod def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: - """Receives a message and processes it.""" + """Asynchronously receives a stream of messages. + + This method sends an initial message and then listens for a stream of + responses. It is an async generator that yields messages as they arrive. + + Args: + message (Any): The initial message to send to start the stream. + abort_event (asyncio.Event | None): An optional event that can be set + to signal the termination of the receive operation. + + Yields: + Any: Messages received from the stream. + + Raises: + NotImplementedError: This method must be implemented by a subclass. + """ raise NotImplementedError() @property @abc.abstractmethod def request_headers(self) -> Headers: - """Return the request headers.""" + """Abstract method to get the request headers. + + This method should be implemented by subclasses to provide the + necessary headers for making requests. + + Raises: + NotImplementedError: This is an abstract method that must be + implemented by a subclass. + + Returns: + Headers: A dictionary-like object containing the HTTP headers. + """ raise NotImplementedError() @abc.abstractmethod async def send( self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None ) -> None: - """Send a stream of messages.""" + """Asynchronously sends a stream of messages. + + This is an abstract method that must be implemented by a subclass. It is + designed to handle sending an asynchronous stream of messages, with support + for timeouts and external cancellation. + + Args: + messages: An asynchronous iterable of messages to send. + timeout: The maximum time in seconds to wait for the send operation + to complete. If None, there is no timeout. + abort_event: An asyncio.Event that, if set, will signal the + operation to abort. 
+ + Raises: + NotImplementedError: This method must be implemented by a subclass. + """ raise NotImplementedError() @property @abc.abstractmethod def response_headers(self) -> Headers: - """Return the response headers.""" + """Gets the HTTP response headers. + + This is an abstract method that must be implemented by subclasses. + + Returns: + Headers: An object containing the response headers. + """ raise NotImplementedError() @property @abc.abstractmethod def response_trailers(self) -> Headers: - """Return response trailers.""" + """Get the response trailers. + + This should only be called after the response body has been fully read. + Not all responses will have trailers. + + Returns: + Headers: A collection of the response trailer headers. + """ raise NotImplementedError() @abc.abstractmethod def on_request_send(self, fn: Callable[..., Any]) -> None: - """Handle the request send event.""" + """Registers a callback function to be executed before a request is sent. + + This method is intended to be used as a decorator. The decorated function + will be called with the request object as its argument before the request + is sent. This allows for last-minute modifications, logging, or other + pre-request processing. + + Example: + @client.on_request_send + def add_custom_header(request): + request.headers['X-Custom-Header'] = 'my-value' + + Args: + fn (Callable[..., Any]): The callback function to be executed. It will + receive the request object as its argument. + + Raises: + NotImplementedError: This method is not implemented and should be + overridden in a subclass. + """ raise NotImplementedError() @abc.abstractmethod async def aclose(self) -> None: - """Asynchronously close the connection.""" + """Asynchronously close the connection and release all resources. + + This method is a coroutine. 
+ """ raise NotImplementedError() async def receive_unary_request[T](conn: StreamingHandlerConn, t: type[T]) -> UnaryRequest[T]: - """Receives a unary request from the given connection and returns a UnaryRequest object. + """Receives a single message from a streaming connection to form a unary request. + + This function reads from the provided connection's stream, ensuring that exactly + one message is present. It then packages this message along with metadata from the + connection (e.g., headers, peer, HTTP method) into a UnaryRequest object. Args: - conn (StreamingHandlerConn): The connection from which to receive the unary request. - t (type[T]): The type of the message to be received. + conn (StreamingHandlerConn): The streaming connection to receive from. + t (type[T]): The type to which the incoming message should be deserialized. Returns: - UnaryRequest[T]: A UnaryRequest object containing the received message. + UnaryRequest[T]: An object representing the complete unary request, including + the deserialized message and connection metadata. + Raises: + Exception: If the stream does not contain exactly one message. """ stream = conn.receive(t) message = await ensure_single(stream) @@ -644,16 +1026,25 @@ async def receive_unary_request[T](conn: StreamingHandlerConn, t: type[T]) -> Un async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> StreamRequest[T]: - """Receive a stream request and returns a StreamRequest object. + """Constructs a StreamRequest from an incoming streaming connection. + + This function adapts the raw message stream from a StreamingHandlerConn into a + standardized StreamRequest object. It intelligently handles different stream types: + - For Server Streams, it awaits and wraps a single incoming message into an + async iterator. + - For Client and Bidirectional Streams, it uses the incoming async iterator + of messages directly. + + Generic Parameters: + T: The data type of the message(s) in the stream. 
Args: - conn (StreamingHandlerConn): The connection handler for the streaming request. - t (type[T]): The type of the messages expected in the stream. + conn: The active streaming connection handler from which to receive data. + t: The expected type of the incoming message(s) for deserialization. Returns: - StreamRequest[T]: An object containing the stream messages, connection specifications, - peer information, request headers, and HTTP method. - + A StreamRequest object containing the message content as an async + iterator, along with connection metadata like headers and peer info. """ if conn.spec.stream_type == StreamType.ServerStream: message = await ensure_single(conn.receive(t)) @@ -678,24 +1069,26 @@ async def receive_stream_request[T](conn: StreamingHandlerConn, t: type[T]) -> S async def receive_unary_response[T]( conn: StreamingClientConn, t: type[T], abort_event: asyncio.Event | None ) -> UnaryResponse[T]: - """Receives a unary response message from a streaming client connection. + """Receives a single message from a streaming connection for a unary-style RPC. - This asynchronous function waits for a unary message of the specified type from the given - streaming client connection. It also handles optional abortion via an asyncio event. - The response, along with any headers and trailers from the connection, is wrapped in a - UnaryResponse object and returned. + This helper function waits for a single message from the given streaming + connection, ensuring the stream closes after one message is received. It's + intended for use with unary RPCs that are transported over a streaming protocol. Args: - conn (StreamingClientConn): The streaming client connection to receive the message from. - t (type[T]): The expected type of the message to be received. - abort_event (asyncio.Event | None): Optional event to signal abortion of the receive operation. + conn (StreamingClientConn): The streaming client connection to receive from. 
+ t (type[T]): The expected type of the response message for deserialization. + abort_event (asyncio.Event | None): An optional event to signal cancellation + of the receive operation. Returns: - UnaryResponse[T]: The received message and associated response metadata. + UnaryResponse[T]: A response object containing the single deserialized + message, along with the response headers and trailers from the + connection. Raises: - Any exceptions raised by `receive_unary_message` or connection errors. - + Exception: If the stream is closed before a message is received, or if + more than one message is received. """ message = await ensure_single(conn.receive(t, abort_event), conn.aclose) @@ -705,26 +1098,22 @@ async def receive_unary_response[T]( async def receive_stream_response[T]( conn: StreamingClientConn, t: type[T], spec: Spec, abort_event: asyncio.Event | None ) -> StreamResponse[T]: - """Handle receiving a stream response from a streaming client connection. + """Receives a streaming response from the server. + + This function adapts the behavior based on the stream type defined in the + specification. For client streams, it awaits a single response message. For + server or bidirectional streams, it returns an async iterator for the + incoming messages. Args: - conn (StreamingClientConn): The streaming client connection used to receive the stream. - t (type[T]): The type of the messages expected in the stream. - spec (Spec): The specification of the stream, including its type. - abort_event (asyncio.Event | None): An optional event to signal abortion of the stream. + conn (StreamingClientConn): The streaming client connection. + t (type[T]): The expected type of the response message(s). + spec (Spec): The RPC method's specification. + abort_event (asyncio.Event | None): An event to signal abortion of the receive operation. Returns: - StreamResponse[T]: A response object containing the received stream, response headers, - and response trailers. 
- - Raises: - Any exceptions raised during the reception of the stream or processing of the messages. - - Notes: - - If the stream type is `StreamType.ClientStream`, it expects exactly one message - and wraps it in a single-message stream. - - For other stream types, it directly returns the received stream. - + StreamResponse[T]: A stream response object containing the data stream, + headers, and trailers. """ if spec.stream_type == StreamType.ClientStream: single_message = await ensure_single(conn.receive(t, abort_event)) diff --git a/src/connect/connection_pool.py b/src/connect/connection_pool.py index 030e742..8818794 100644 --- a/src/connect/connection_pool.py +++ b/src/connect/connection_pool.py @@ -1,3 +1,3 @@ -"""Provides connection pool functionality using httpcore's AsyncConnectionPool.""" +"""Provides an asynchronous connection pool for making HTTP requests.""" from httpcore import AsyncConnectionPool as AsyncConnectionPool diff --git a/src/connect/content_stream.py b/src/connect/content_stream.py index 5644d58..5bf8441 100644 --- a/src/connect/content_stream.py +++ b/src/connect/content_stream.py @@ -1,4 +1,4 @@ -"""Asynchronous byte stream utilities for HTTP core response handling.""" +"""Provides classes for handling asynchronous content and data streams.""" from collections.abc import ( AsyncIterable, @@ -14,62 +14,85 @@ class AsyncByteStream(AsyncIterable[bytes]): - """An abstract base class for asynchronous byte streams. + """Abstract base class for asynchronous byte streams. - This class defines the interface for an asynchronous byte stream, which - includes methods for iterating over the stream and closing it. + This class defines the interface for an asynchronous iterable that yields bytes. + It is intended to be subclassed to implement specific byte stream sources, + such as file I/O, network connections, or in-memory buffers. + Subclasses must implement the `__aiter__` method to provide the core + asynchronous iteration logic. 
The `aclose` method can be overridden to + release any underlying resources. """ async def __aiter__(self) -> AsyncIterator[bytes]: - """Asynchronous iterator method. + """Asynchronously iterates over the content stream. - This method should be implemented to provide asynchronous iteration - over the object. It must return an asynchronous iterator that yields - bytes. + This allows the object to be used in an `async for` loop, yielding + the content in chunks. - Raises: - NotImplementedError: If the method is not implemented. + Yields: + bytes: A chunk of the content from the stream. + Returns: + AsyncIterator[bytes]: An asynchronous iterator over the content stream. """ raise NotImplementedError("The '__aiter__' method must be implemented.") # pragma: no cover yield b"" async def aclose(self) -> None: - """Asynchronously close the byte stream.""" + """Closes the stream and the underlying connection. + + This method is an asynchronous generator and should be used with `async for`. + It will close the stream and the underlying connection when the generator is exhausted. + """ pass class BoundAsyncStream(AsyncByteStream): - """An asynchronous byte stream wrapper that binds to an existing async iterable of bytes. + """A wrapper for an asynchronous byte stream that ensures proper resource management. - This class provides an asynchronous iterator interface for reading byte chunks from the given stream, - and ensures proper resource cleanup by closing the underlying stream when needed. + This class takes an asynchronous iterable of bytes and provides an `AsyncByteStream` + interface. It is responsible for iterating over the underlying stream and ensuring + that it is properly closed, even in the event of an error during iteration. - Args: - stream (AsyncIterable[bytes]): The asynchronous iterable byte stream to wrap. + The `aclose` method is idempotent, meaning it can be called multiple times without + causing an error. 
Attributes: - stream (AsyncIterable[bytes]): The wrapped asynchronous byte stream. - _closed (bool): Indicates whether the stream has been closed. - + _stream (AsyncIterable[bytes] | None): The underlying asynchronous stream. + It is set to `None` once the stream is closed. + _closed (bool): A flag to indicate whether the stream has been closed. """ _stream: AsyncIterable[bytes] | None _closed: bool def __init__(self, stream: AsyncIterable[bytes]) -> None: - """Initialize the object with an asynchronous iterable stream of bytes. + """Initialize the content stream. Args: - stream (AsyncIterable[bytes]): An asynchronous iterable that yields bytes. - + stream: An asynchronous iterable of bytes representing the content stream. """ self._stream = stream self._closed = False async def __aiter__(self) -> AsyncIterator[bytes]: - """Asynchronous iterator method to read byte chunks from the stream.""" + """Asynchronously iterates over the response content. + + This method allows the response body to be consumed in chunks, which is + useful for handling large files or streaming data. It ensures that the + underlying stream is closed, even if an error occurs during iteration. + + Yields: + bytes: A chunk of the response body. + + Raises: + Exception: An exception that occurred during streaming. The stream + is closed before the exception is re-raised. + ExceptionGroup: Raised if an error occurs during iteration and a + separate error occurs while attempting to close the stream. + """ if self._stream is None: return @@ -85,7 +108,15 @@ async def __aiter__(self) -> AsyncIterator[bytes]: raise async def aclose(self) -> None: - """Asynchronously close the stream.""" + """Asynchronously closes the content stream. + + This method ensures that the underlying stream is properly closed and + resources are released. It is idempotent, meaning it can be called + e multiple times without raising an error or causing issues. 
+ + Any exceptions that occur during the closing of the underlying `httpcore` + stream are caught and re-raised as `connect.exceptions.ConnectError`. + """ if self._closed: return @@ -101,15 +132,22 @@ async def aclose(self) -> None: class AsyncDataStream[T]: - """An asynchronous data stream wrapper that provides iteration and cleanup functionality. + """Wraps an asynchronous iterable to provide a uniform interface for iteration and closure. + + This class is designed to handle various asynchronous data sources, such as streaming + API responses, ensuring that the underlying resources are properly released after + consumption or in case of an error. + + It can be used directly in an `async for` loop. The `aclose()` method should be + called explicitly to ensure the stream is closed and resources are released. Type Parameters: T: The type of items yielded by the stream. Attributes: - _stream (AsyncIterable[T]): The underlying asynchronous iterable data stream. - aclose_func (Callable[..., Awaitable[None]] | None): Optional asynchronous cleanup function to be called on close. - + _stream (AsyncIterable[T] | None): The underlying asynchronous iterable. + _aclose_func (Callable[..., Awaitable[None]] | None): An optional custom close function. + _closed (bool): A flag indicating whether the stream has been closed. """ _stream: AsyncIterable[T] | None @@ -117,26 +155,31 @@ class AsyncDataStream[T]: _closed: bool def __init__(self, stream: AsyncIterable[T], aclose_func: Callable[..., Awaitable[None]] | None = None) -> None: - """Initialize the object with an asynchronous iterable stream and an optional asynchronous close function. + """Initializes the ContentStream. Args: - stream (AsyncIterable[T]): The asynchronous iterable stream to be wrapped. - aclose_func (Callable[..., Awaitable[None]], optional): An optional asynchronous function to be called when closing the stream. Defaults to None. - + stream: The asynchronous iterable that provides the content. 
+ aclose_func: An optional asynchronous function to call when closing the stream. """ self._stream = stream self._aclose_func = aclose_func self._closed = False async def __aiter__(self) -> AsyncIterator[T]: - """Asynchronously iterates over the underlying stream, yielding each part. + """Asynchronously iterates over the content stream. + + This method allows the content stream to be used in an `async for` loop, + yielding each part of the stream as it is received. Yields: - T: The next part from the stream. + T: The next part of the content from the stream. Raises: - Propagates any exception raised during iteration after ensuring the stream is closed. - + Exception: Re-raises any exception encountered during stream iteration + after attempting to close the stream. + ExceptionGroup: Raised if an exception occurs during stream iteration + and another exception occurs while attempting to close the stream + in response to the first error. """ if self._stream is None: return @@ -152,14 +195,16 @@ async def __aiter__(self) -> AsyncIterator[T]: raise async def aclose(self) -> None: - """Asynchronously closes the underlying stream. + """Asynchronously closes the content stream and releases its resources. - If a custom asynchronous close function (`aclose_func`) is provided, it is awaited. - Otherwise, if the underlying stream has an `aclose` method, it is retrieved and awaited. - - Raises: - Any exception raised by the custom close function or the stream's `aclose` method. + This method marks the stream as closed to prevent further operations. + It will invoke the custom `_aclose_func` if one was provided during + initialization. Otherwise, it attempts to call the `aclose()` method + on the underlying stream object. + The method is idempotent, meaning calling it on an already closed + stream will have no effect. Finally, it clears internal references + to the stream and the close function. 
""" if self._closed: return diff --git a/src/connect/envelope.py b/src/connect/envelope.py index a8630f2..ac46d2e 100644 --- a/src/connect/envelope.py +++ b/src/connect/envelope.py @@ -13,12 +13,18 @@ class EnvelopeFlags(Flag): - """EnvelopeFlags is an enumeration that defines flags for an envelope. + """Flags for an envelope. - Attributes: - compressed (int): Flag indicating that the envelope is compressed. - end_stream (int): Flag indicating that the envelope marks the end of a stream. + This enumeration defines the bit flags that can be set on a Connect protocol + envelope to indicate special handling or metadata. + Attributes: + compressed: Indicates that the message is compressed. The compression + algorithm is determined by the `Content-Encoding` header. + end_stream: Signals the end of a stream. This is used in streaming RPCs + to indicate that no more messages will be sent. + trailer: Indicates that the envelope contains trailers instead of a + message. Trailers are sent as the last message in a stream. """ compressed = 0b00000001 @@ -27,13 +33,18 @@ class EnvelopeFlags(Flag): class Envelope: - """A class to represent an Envelope which contains data and flags. + """A class to represent a protocol message envelope. - Attributes: - data (bytes): The data contained in the envelope. - flags (EnvelopeFlags): The flags associated with the envelope. - _format (str): The format string used for struct packing and unpacking. + This class handles the encoding and decoding of messages, which consist of a + 5-byte header and a variable-length data payload. The header contains flags + and the length of the data payload. The structure of the header is defined + by the `_format` attribute, which is a struct format string '>BI' + (big-endian, 1-byte unsigned char for flags, 4-byte unsigned int for data length). + Attributes: + data (bytes): The payload of the envelope. + flags (EnvelopeFlags): An enum representing the flags associated with the envelope. 
+ _format (str): The struct format string for encoding/decoding the header. """ data: bytes @@ -41,49 +52,59 @@ class Envelope: _format: str = ">BI" def __init__(self, data: bytes, flags: EnvelopeFlags) -> None: - """Initialize a new instance of the class. + """Initializes a new Envelope instance. Args: - data (bytes): The data to be processed. - flags (EnvelopeFlags): The flags associated with the envelope. - + data: The raw byte data of the envelope. + flags: The flags associated with the envelope, indicating its type. """ self.data = data self.flags = flags def encode(self) -> bytes: - """Encode the header and data into a byte sequence. + """Serializes the envelope into a byte representation. - Returns: - bytes: The encoded byte sequence consisting of the header and data. + The resulting byte string is a concatenation of the message header + and the message data. The header contains the flags and the length + of the data. + Returns: + The serialized envelope as a bytes object. """ return self.encode_header(self.flags.value, self.data) + self.data def encode_header(self, flags: int, data: bytes) -> bytes: - """Encode the header for a protocol message. + """Encodes the header for a message envelope. + + This method packs the given flags and the length of the data into a + binary structure according to the format defined in `self._format`. Args: - flags (int): The flags to include in the header. - data (bytes): The data to be sent, used to determine the length. + flags: An integer representing the message flags. + data: The byte string payload of the message. The length of this + data will be encoded in the header. Returns: - bytes: The encoded header as a byte string. - + The encoded header as a byte string. """ return struct.pack(self._format, flags, len(data)) @staticmethod def decode_header(data: bytes) -> tuple[EnvelopeFlags, int] | None: - """Decode the header from the given byte data. + """Decodes an envelope header from a byte string. 
+ + This function reads the first 5 bytes of the provided data to extract + the envelope flags and the length of the main data payload. Args: - data (bytes): The byte data containing the header to decode. + data: The byte string containing the envelope header. Returns: - tuple[EnvelopeFlags, int] | None: A tuple containing the decoded EnvelopeFlags and data length if the data is valid, - otherwise None if the data length is less than 5 bytes. + A tuple containing the `EnvelopeFlags` and the data length as an + integer if the header is successfully decoded. Returns `None` if + the input data is too short to contain a valid header (i.e., less + than 5 bytes). """ if len(data) < 5: return None @@ -93,15 +114,20 @@ def decode_header(data: bytes) -> tuple[EnvelopeFlags, int] | None: @staticmethod def decode(data: bytes) -> "tuple[Envelope | None, int]": - """Decode the given byte data into an Envelope object and its length. + """Decodes a byte stream into an Envelope object. + + This method reads the envelope header to determine the payload size, + then attempts to construct an Envelope object from the payload. Args: - data (bytes): The byte data to decode. + data: The raw byte data to be decoded. Returns: - tuple[Envelope | None, int]: A tuple containing the decoded Envelope object (or None if decoding fails) - and the length of the data. If the data is insufficient to decode, returns (None, data_len). - + A tuple containing the decoded Envelope and the payload length. + - If decoding is successful, returns `(Envelope, payload_length)`. + - If the data is insufficient to contain the full payload as + indicated by the header, returns `(None, expected_payload_length)`. + - If the header is invalid or cannot be decoded, returns `(None, 0)`. 
""" header = Envelope.decode_header(data) if header is None: @@ -114,38 +140,33 @@ def decode(data: bytes) -> "tuple[Envelope | None, int]": return Envelope(data[5 : 5 + data_len], flags), data_len def is_set(self, flag: EnvelopeFlags) -> bool: - """Check if a specific flag is set in the envelope. + """Checks if a specific flag is set in the envelope's flags. Args: - flag (EnvelopeFlags): The flag to check. + flag: The flag to check for. Returns: - bool: True if the flag is set, False otherwise. - + True if the flag is set, False otherwise. """ return flag in self.flags class EnvelopeWriter: - """EnvelopeWriter is responsible for marshaling messages, optionally compressing them, and writing them into envelopes for transmission. - - Attributes: - codec (Codec | None): The codec used for encoding and decoding messages. - send_max_bytes (int): The maximum number of bytes allowed per message. - compression (Compression | None): The compression method to use, or None for no compression. - - Methods: - __init__(codec, compression, compress_min_bytes, send_max_bytes): - Initializes the EnvelopeWriter with the specified codec, compression, and size constraints. - - async _marshal(messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: - Asynchronously marshals and optionally compresses messages from an async iterable, yielding encoded envelope bytes. - Raises ConnectError if marshaling fails or message size exceeds the allowed limit. + """Manages the process of marshaling, compressing, and framing messages into envelopes. - write_envelope(data: bytes, flags: EnvelopeFlags) -> Envelope: - Writes an envelope, optionally compressing its data if conditions are met, and updates envelope flags accordingly. - Raises ConnectError if the (compressed) message size exceeds the allowed maximum. + This class is responsible for taking application-level messages, encoding them into + bytes using a specified codec, and then optionally compressing them. 
The resulting + data is wrapped in an Envelope object, which includes flags and the payload, + ready for transmission. It also enforces size limits on outgoing messages. + Attributes: + codec (Codec | None): The codec used for marshaling messages. + compress_min_bytes (int): The minimum size in bytes a message must be + before compression is applied. + send_max_bytes (int): The maximum allowed size in bytes for a message + payload after any compression. + compression (Compression | None): The compression algorithm to use. If None, + compression is disabled. """ codec: Codec | None @@ -156,14 +177,13 @@ class EnvelopeWriter: def __init__( self, codec: Codec | None, compression: Compression | None, compress_min_bytes: int, send_max_bytes: int ) -> None: - """Initialize the ProtocolConnect instance. + """Initializes the Envelope. Args: - codec (Codec): The codec to be used for encoding and decoding. - compression (Compression | None): The compression method to be used, or None if no compression is to be applied. - compress_min_bytes (int): The minimum number of bytes before compression is applied. - send_max_bytes (int): The maximum number of bytes that can be sent in a single message. - + codec: The codec to use for encoding messages. + compression: The compression algorithm to use. + compress_min_bytes: The minimum number of bytes a message must be to be compressed. + send_max_bytes: The maximum number of bytes for a message to be sent. """ self.codec = codec self.compress_min_bytes = compress_min_bytes @@ -171,17 +191,21 @@ def __init__( self.compression = compression async def marshal(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: - """Asynchronously marshals and compresses messages from an asynchronous iterator. + """Marshals an asynchronous stream of messages into Connect envelopes. 
+ + This asynchronous generator takes an iterable of messages, marshals each one + using the configured codec, wraps it in a Connect envelope, and yields + the encoded envelope as bytes. Args: - messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled. + messages: An asynchronous iterable of messages to be marshaled. Yields: - AsyncIterator[bytes]: An asynchronous iterator of marshaled and optionally compressed messages in bytes. + The next marshaled and enveloped message as bytes. Raises: - ConnectError: If there is an error during marshaling or if the message size exceeds the allowed limit. - + ConnectError: If the codec is not set or if an error occurs + during message marshaling. """ if self.codec is None: raise ConnectError("codec is not set", Code.INTERNAL) @@ -196,23 +220,28 @@ async def marshal(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: yield env.encode() def write_envelope(self, data: bytes, flags: EnvelopeFlags) -> Envelope: - """Write an envelope containing the provided data, applying compression if required. + """Creates an Envelope from the given data, handling compression. + + This method takes raw byte data and prepares it for sending. It will + attempt to compress the data if a compression algorithm is configured, + the data is larger than `compress_min_bytes`, and the `compressed` + flag is not already set. + + If the data is compressed, the `EnvelopeFlags.compressed` flag is added. + The method also validates the final data size against the `send_max_bytes` + limit, raising an error if it's exceeded. Args: - data (bytes): The message payload to be written into the envelope. - flags (EnvelopeFlags): Flags indicating envelope properties, such as compression. + data (bytes): The raw message data to be enveloped. + flags (EnvelopeFlags): The initial flags for the envelope. Returns: - Envelope: An envelope object containing the (optionally compressed) data and updated flags. 
+ Envelope: An envelope containing the potentially compressed data and + updated flags. Raises: - ConnectError: If the (compressed or uncompressed) data size exceeds the configured send_max_bytes limit. - - Notes: - - Compression is applied only if the flags do not already indicate compression, - compression is enabled, and the data size exceeds the minimum threshold. - - The flags are updated to include the compressed flag if compression is performed. - + ConnectError: If the size of the data (either raw or compressed) + exceeds the configured `send_max_bytes` limit. """ if EnvelopeFlags.compressed in flags or self.compression is None or len(data) < self.compress_min_bytes: if self.send_max_bytes > 0 and len(data) > self.send_max_bytes: @@ -238,14 +267,22 @@ def write_envelope(self, data: bytes, flags: EnvelopeFlags) -> Envelope: class EnvelopeReader: - """A class to handle the unmarshaling of streaming data. + """Reads and decodes enveloped messages from an asynchronous byte stream. - Attributes: - codec (Codec): The codec used for unmarshaling data. - compression (Compression | None): The compression method used, if any. - stream (AsyncIterable[bytes] | None): The asynchronous byte stream to read from. - buffer (bytes): The buffer to store incoming data chunks. + This class is responsible for processing the Connect protocol's envelope format. + It reads data from a stream, parses envelopes (which consist of a flag byte, + a 4-byte length prefix, and the message data), handles decompression, and + uses a specified codec to unmarshal the message data into Python objects. + Attributes: + codec (Codec | None): The codec used for unmarshaling message data. + read_max_bytes (int): The maximum permitted size in bytes for a single message. + compression (Compression | None): The algorithm used for decompressing message data. + stream (AsyncIterable[bytes] | None): The source asynchronous byte stream. 
+ buffer (bytes): An internal buffer for accumulating data from the stream. + bytes_read (int): A counter for the total number of bytes read. + last (Envelope | None): Stores the final envelope, which typically contains + end-of-stream metadata. """ codec: Codec | None @@ -263,14 +300,13 @@ def __init__( stream: AsyncIterable[bytes] | None = None, compression: Compression | None = None, ) -> None: - """Initialize the protocol connection. + """Initializes the EnvelopeReader. Args: - codec (Codec): The codec to use for encoding and decoding data. - read_max_bytes (int): The maximum number of bytes to read from the stream. - stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read from. Defaults to None. - compression (Compression | None, optional): The compression method to use. Defaults to None. - + codec: The codec to use for decoding messages. + read_max_bytes: The maximum number of bytes to read from the stream. + stream: The asynchronous stream of bytes to read from. + compression: The compression algorithm to use for decompression. """ self.codec = codec self.read_max_bytes = read_max_bytes @@ -281,18 +317,27 @@ def __init__( self.last = None async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: - """Asynchronously unmarshals messages from the stream. + """Unmarshals a stream of enveloped messages according to the Connect protocol. + + This asynchronous generator reads byte chunks from the underlying stream, + buffering them until a complete message envelope can be decoded. It handles + message framing, decompression, and unmarshaling of the payload. Args: - message (Any): The message type to unmarshal. + message (Any): The target message type (e.g., a protobuf message class) + into which the payload will be unmarshaled. Yields: - Any: The unmarshaled message object. + tuple[Any, bool]: An async iterator yielding tuples where the first element + is the unmarshaled message object and the second is a boolean flag. 
+ The flag is `True` if this is the final message (i.e., an end-of-stream + envelope), otherwise `False`. Raises: - ConnectError: If the stream is not set, if there is an error in the - unmarshaling process, or if there is a protocol error. - + ConnectError: If the stream or codec is not configured, if a message + size exceeds the configured `read_max_bytes`, if a compressed + message is received without a configured decompressor, or if any + other protocol, decompression, or unmarshaling error occurs. """ if self.stream is None: raise ConnectError("stream is not set", Code.INTERNAL) @@ -351,16 +396,11 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: raise ConnectError(message, Code.INVALID_ARGUMENT) async def aclose(self) -> None: - """Asynchronously closes the stream if it has an `aclose` method. - - This method checks if the `self.stream` object has an asynchronous - `aclose` method. If the method exists, it is invoked to close the stream. - - The bytes_read counter is not reset when closing the stream. - - Returns: - None + """Asynchronously closes the underlying stream. + This method checks for an `aclose` callable on the stream + and awaits it if found, ensuring proper resource cleanup in an + asynchronous environment. """ aclose = get_acallable_attribute(self.stream, "aclose") if aclose: diff --git a/src/connect/error.py b/src/connect/error.py index fd39f59..cbbd91e 100644 --- a/src/connect/error.py +++ b/src/connect/error.py @@ -1,7 +1,7 @@ # Copyright 2024 Gaudiy, Inc. 
# SPDX-License-Identifier: Apache-2.0 -"""Error represents an error in the Connect protocol.""" +"""Defines Connect protocol errors and related utilities.""" import google.protobuf.any_pb2 as any_pb2 import google.protobuf.symbol_database as symbol_database @@ -14,7 +14,25 @@ def type_url_to_message(type_url: str) -> Message: - """Return a message instance corresponding to a given type URL.""" + """Converts a Protobuf `Any` type URL into an empty message instance. + + This function takes a type URL, as found in the `type_url` field of a + `google.protobuf.any_pb2.Any` message, and returns an empty instance of the + corresponding Protobuf message class. It uses the default symbol database + to look up the message type. + + Args: + type_url: The type URL string to resolve. It must start with + 'type.googleapis.com/'. + + Returns: + An empty instance of the resolved Protobuf message. + + Raises: + ValueError: If the `type_url` does not have the expected prefix. + KeyError: If the message type for the given `type_url` cannot be + found in the symbol database. + """ if not type_url.startswith(DEFAULT_ANY_RESOLVER_PREFIX): raise ValueError(f"Type URL has to start with a prefix {DEFAULT_ANY_RESOLVER_PREFIX}: {type_url}") @@ -31,13 +49,20 @@ def type_url_to_message(type_url: str) -> Message: class ErrorDetail: - """ErrorDetail class represents the details of an error. + """Represents a detailed error message from a Connect RPC. - Attributes: - pb_any (any_pb2.Any): A protobuf Any type containing the error details. - pb_inner (Message): A protobuf Message containing the inner error details. - wire_json (str | None): A JSON string representation of the error, if available. + Connect errors can include a list of details, which are Protobuf messages + that provide more context about the error. These details are serialized as + `google.protobuf.any_pb2.Any` messages. 
This class provides a wrapper + around an `Any` message, allowing for lazy unpacking of the specific, + underlying error message. + Attributes: + pb_any (any_pb2.Any): The raw `google.protobuf.any_pb2.Any` message. + pb_inner (Message | None): The unpacked, specific Protobuf error message. + This is lazily populated when `get_inner` is called for the first time. + wire_json (str | None): The raw JSON representation of the error detail, + as received over the wire. """ pb_any: any_pb2.Any @@ -45,13 +70,33 @@ class ErrorDetail: wire_json: str | None = None def __init__(self, pb_any: any_pb2.Any, pb_inner: Message | None = None, wire_json: str | None = None) -> None: - """Initialize an ErrorDetail.""" + """Initializes a new ConnectErrorDetail. + + Args: + pb_any (any_pb2.Any): The Protobuf Any message containing the error detail. + pb_inner (Message | None): The specific, deserialized Protobuf message from the detail. + wire_json (str | None): The raw JSON representation of the detail from the wire. + """ self.pb_any = pb_any self.pb_inner = pb_inner self.wire_json = wire_json def get_inner(self) -> Message: - """Get the inner error message.""" + """Unpacks and returns the inner protobuf message from the error detail. + + This method deserializes the `google.protobuf.Any` message contained + within the error detail into its specific message type. The result is + cached, so subsequent calls to this method will not re-unpack the + message. + + Returns: + The unpacked protobuf message. + + Raises: + ValueError: If the type URL in the `Any` field does not match the + type of the packed message, indicating a data corruption or + mismatch. 
+ """ if self.pb_inner: return self.pb_inner @@ -65,21 +110,39 @@ def get_inner(self) -> Message: return msg -# Helper function to create error messages with code prefix def create_message(message: str, code: Code) -> str: - """Create an error message with a code prefix.""" + """Creates a formatted error message from a code and a detail message. + + If the `message` is empty, this function returns the string representation + of the `code`. Otherwise, it returns a string formatted as + ": ". + + Args: + message: The specific error message. + code: The error code enum instance. + + Returns: + The formatted error message string. + """ return code.string() if message == "" else f"{code.string()}: {message}" class ConnectError(Exception): - """Exception raised for errors that occur within the Connect system. + """Represents an error in the Connect protocol. - Attributes: - raw_message (str): The original error message. - code (Code): The error code, default is Code.UNKNOWN. - metadata (MutableMapping[str, str]): Additional metadata related to the error. - details (list[ErrorDetail]): Detailed information about the error. + Connect errors are sent by servers when a request fails. They have a code, + a message, and optional binary details. This exception is raised by clients + when they receive an error from a server. It may also be raised by the + framework to indicate a client-side problem (e.g., a network error). + Attributes: + raw_message (str): The original, unformatted error message from the server or client. + code (Code): The Connect error code. + metadata (Headers): Any metadata (headers) associated with the error. + details (list[ErrorDetail]): A list of structured, typed error details. + wire_error (bool): True if the error was raised due to a protocol-level issue + (e.g., malformed response, network error), rather than an + error returned by the application logic. 
""" raw_message: str @@ -96,7 +159,17 @@ def __init__( details: list[ErrorDetail] | None = None, wire_error: bool = False, ) -> None: - """Initialize a Error.""" + """Initializes a new ConnectError. + + Args: + message (str): The error message. + code (Code): The Connect error code. Defaults to Code.UNKNOWN. + metadata (Headers | None): Any metadata to attach to the error. Defaults to None. + details (list[ErrorDetail] | None): A list of protobuf Any messages to attach as error details. + Defaults to None. + wire_error (bool): Whether this error was created from a serialized error on the wire. + Defaults to False. + """ super().__init__(create_message(message, code)) self.raw_message = message self.code = code diff --git a/src/connect/handler.py b/src/connect/handler.py index d14bee6..82a7c9b 100644 --- a/src/connect/handler.py +++ b/src/connect/handler.py @@ -1,4 +1,4 @@ -"""Module provides handler configurations and implementations for unary procedures and stream types.""" +"""Defines the server-side handlers for Connect, gRPC, and gRPC-Web RPCs.""" import asyncio from collections.abc import Awaitable, Callable @@ -27,7 +27,7 @@ from connect.handler_interceptor import apply_interceptors from connect.headers import Headers from connect.idempotency_level import IdempotencyLevel -from connect.options import ConnectOptions +from connect.options import HandlerOptions from connect.protocol import ( HEADER_CONTENT_LENGTH, HEADER_CONTENT_TYPE, @@ -55,18 +55,23 @@ class HandlerConfig: - """HandlerConfig encapsulates the configuration for a handler in the Connect framework. + """Configuration for an RPC handler. + + This class encapsulates all the configuration required to execute a specific RPC + procedure. It includes details about the procedure itself, serialization codecs, + compression algorithms, and various protocol-level settings. Attributes: - codecs (dict[str, Codec]): A mapping of codec names to codec instances supported by the handler. 
- compressions (list[Compression]): A list of compression algorithms supported by the handler. - descriptor (Any): The descriptor providing metadata about the procedure. - compress_min_bytes (int): The minimum message size (in bytes) before compression is applied. - read_max_bytes (int): The maximum number of bytes allowed to be read in a single message. - send_max_bytes (int): The maximum number of bytes allowed to be sent in a single message. - require_connect_protocol_header (bool): Whether the Connect protocol header is required. + procedure (str): The full name of the RPC procedure (e.g., /acme.foo.v1.FooService/Bar). + stream_type (StreamType): The type of stream for the procedure (unary, client, server, or bidi). + codecs (dict[str, Codec]): A dictionary mapping codec names to their respective Codec implementations. + compressions (list[Compression]): A list of supported compression algorithms. + descriptor (Any): The protobuf message or service descriptor. + compress_min_bytes (int): The minimum number of bytes a message must have to be considered for compression. + read_max_bytes (int): The maximum number of bytes to read for a single message. + send_max_bytes (int): The maximum number of bytes to send for a single message. + require_connect_protocol_header (bool): Whether to require the `Connect-Protocol-Version` header. idempotency_level (IdempotencyLevel): The idempotency level of the procedure. - """ procedure: str @@ -80,14 +85,25 @@ class HandlerConfig: require_connect_protocol_header: bool idempotency_level: IdempotencyLevel - def __init__(self, procedure: str, stream_type: StreamType, options: ConnectOptions): - """Initialize a new handler instance with the specified procedure, stream type, and options. + def __init__(self, procedure: str, stream_type: StreamType, options: HandlerOptions): + """Initializes a new Handler. Args: - procedure (str): The name of the procedure to handle. 
- stream_type (StreamType): The type of stream (e.g., unary, server streaming, etc.). - options (ConnectOptions): Configuration options for the handler, including descriptor, compression, and protocol settings. - + procedure (str): The full name of the RPC procedure. + stream_type (StreamType): The type of stream for the procedure. + options (HandlerOptions): Configuration options for the handler. + + Attributes: + procedure (str): The full name of the RPC procedure. + stream_type (StreamType): The type of stream for the procedure. + codecs (dict[str, Codec]): A dictionary of supported codecs, keyed by name. + compressions (list[Compression]): A list of supported compression algorithms. + descriptor: The protobuf method descriptor. + compress_min_bytes (int): The minimum number of bytes for a response to be compressed. + read_max_bytes (int): The maximum number of bytes to read for a request message. + send_max_bytes (int): The maximum number of bytes to send for a response message. + require_connect_protocol_header (bool): Whether to require the Connect protocol header. + idempotency_level: The idempotency level of the procedure. """ self.procedure = procedure self.stream_type = stream_type @@ -106,11 +122,11 @@ def __init__(self, procedure: str, stream_type: StreamType, options: ConnectOpti self.idempotency_level = options.idempotency_level def spec(self) -> Spec: - """Create and returns a Spec object initialized with the current handler's procedure, descriptor, stream type, and idempotency level. + """Get the specification for the handler. Returns: - Spec: An instance of the Spec class containing the handler's configuration. - + Spec: A `Spec` object containing the handler's specification, + including procedure, descriptor, stream type, and idempotency level. 
""" return Spec( procedure=self.procedure, @@ -121,15 +137,19 @@ def spec(self) -> Spec: def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: - """Create and returns a list of protocol handlers based on the provided configuration. + """Creates and configures protocol handlers based on the provided configuration. + + This function initializes handlers for the Connect, gRPC, and gRPC-Web protocols. + Each handler is configured with parameters extracted from the `config` object, + such as codecs, compression algorithms, message size limits, and other + protocol-specific settings. Args: - config (HandlerConfig): The configuration object containing settings for codecs, compressions, - byte limits, protocol requirements, and idempotency level. + config: A HandlerConfig object containing the configuration + for the protocol handlers. Returns: - list[ProtocolHandler]: A list of initialized protocol handler instances for each supported protocol. - + A list of initialized ProtocolHandler instances. """ protocols = [ProtocolConnect(), ProtocolGRPC(web=False), ProtocolGRPC(web=True)] @@ -156,14 +176,22 @@ def create_protocol_handlers(config: HandlerConfig) -> list[ProtocolHandler]: class Handler: - """Handler is an abstract base class for handling HTTP requests in a protocol-agnostic way, supporting both unary and streaming RPCs. + """A base handler for a single RPC procedure. - Attributes: - protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): Mapping of HTTP methods to their protocol handlers. - allow_methods (str): String specifying allowed HTTP methods. - accept_post (str): String specifying accepted content types for POST requests. - protocol_handler (ProtocolHandler): The protocol handler selected for the current request. + This class is responsible for routing an incoming HTTP request to the correct + protocol-specific handler (e.g., Connect, gRPC, gRPC-Web) based on the + HTTP method and the Content-Type header. 
It manages the request lifecycle, + including validation, asynchronous processing, and error handling. + Subclasses must implement the `implementation` method to define the + procedure's business logic. + + Attributes: + procedure (str): The fully-qualified name of the procedure (e.g., /acme.foo.v1.FooService/Bar). + protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): A mapping of HTTP methods to the protocol handlers that support them. + allow_methods (str): A comma-separated string of allowed HTTP methods, used in the `Allow` header for 405 responses. + accept_post (str): A comma-separated string of supported `Content-Type` values for POST requests, used in the `Accept-Post` header for 415 responses. + protocol_handler (ProtocolHandler): The specific protocol handler chosen to handle the current request. This is set within the `handle` method. """ procedure: str @@ -179,14 +207,15 @@ def __init__( allow_methods: str, accept_post: str, ) -> None: - """Initialize the handler with the specified procedure, protocol handlers, and HTTP method configurations. + """Initializes a handler for a specific RPC procedure. Args: - procedure (str): The name of the procedure to be handled. - protocol_handlers (dict[HTTPMethod, list[ProtocolHandler]]): A mapping of HTTP methods to their corresponding protocol handlers. - allow_methods (str): A string specifying which HTTP methods are allowed. - accept_post (str): A string specifying the accepted content types for POST requests. - + procedure: The full name of the procedure. + protocol_handlers: A dictionary mapping HTTP methods to a list of + protocol-specific handlers that can process requests for this procedure. + allow_methods: The value for the 'Allow' HTTP header, listing supported methods. + accept_post: The value for the 'Accept-Post' HTTP header, listing supported + content types for POST requests. 
""" self.procedure = procedure self.protocol_handlers = protocol_handlers @@ -194,42 +223,45 @@ def __init__( self.accept_post = accept_post async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: - """Abstract method to be implemented by subclasses to handle streaming connections. + """The actual implementation of the streaming handler logic. - Args: - conn (StreamingHandlerConn): The streaming connection handler instance. - timeout (float | None): Optional timeout value in seconds for the operation. - - Raises: - NotImplementedError: If the method is not implemented by a subclass. + This method must be overridden by subclasses to define the specific + behavior of the handler. It is called to process the streaming + connection. + Args: + conn (StreamingHandlerConn): The connection object for the streaming + session, used to send and receive messages. + timeout (float | None): An optional timeout in seconds for the + entire handling operation. """ raise NotImplementedError() async def handle(self, request: Request) -> Response: - """Handle an incoming HTTP request and returns an appropriate response. - - This method determines the correct protocol handler based on the HTTP method and content type, - validates the request (including checking for unsupported methods or media types), and manages - asynchronous processing of the request and response writing. + """Handles an incoming HTTP request and routes it to the appropriate Connect protocol handler. + + This method acts as the main entry point for the Connect service. It performs the + following steps: + 1. Validates the HTTP method. If the method is not supported, it returns a + 405 Method Not Allowed response. + 2. Determines the correct protocol handler (e.g., Connect, gRPC-Web) based on + the request's Content-Type header. If no suitable handler is found, it + returns a 415 Unsupported Media Type response. + 3. 
For GET requests, it ensures there is no request body, returning a 415 + response if a body is present, as per the Connect protocol specification. + 4. It creates two concurrent tasks: + - One to execute the actual RPC logic (`_handle`). + - One to wait for the response headers to be written by the logic task. + 5. It waits for the first task to complete. The response is typically generated + as soon as the headers are available, allowing for streaming responses. + 6. Ensures proper cleanup by cancelling any lingering tasks. + 7. Returns the generated `Response` object to the web server. Args: - request (Request): The incoming HTTP request to be handled. + request: The incoming Starlette Request object. Returns: - Response: The HTTP response generated by the handler. - - Raises: - Exception: Propagates any exception raised during request handling. - asyncio.CancelledError: If the handling task is cancelled. - - Behavior: - - Returns 405 Method Not Allowed if the HTTP method is not supported. - - Returns 415 Unsupported Media Type if no protocol handler can handle the request's content type. - - For GET requests, returns 415 if a request body is present. - - Handles the request asynchronously, ensuring proper cleanup of tasks. - - Returns a 500 Internal Server Error if no response is generated. - + A Starlette Response object to be sent to the client. 
""" response_headers = Headers(encoding="latin-1") response_trailers = Headers(encoding="latin-1") @@ -274,7 +306,7 @@ async def handle(self, request: Request) -> Response: writer = ServerResponseWriter() - handle_task = asyncio.create_task(self._handle(request, response_headers, response_trailers, writer)) + handle_task = asyncio.create_task(self._handle_rpc(request, response_headers, response_trailers, writer)) writer_task = asyncio.create_task(writer.receive()) response: Response | None = None @@ -308,23 +340,24 @@ async def handle(self, request: Request) -> Response: return response - async def _handle( + async def _handle_rpc( self, request: Request, response_headers: Headers, response_trailers: Headers, writer: ServerResponseWriter ) -> None: - """Handle an incoming request by establishing a connection, parsing timeout values, and invoking the implementation logic. - - Args: - request (Request): The incoming request object. - response_headers (Headers): Headers to be sent in the response. - response_trailers (Headers): Trailers to be sent in the response. - writer (ServerResponseWriter): The writer used to send responses to the client. - - Returns: - None + """Handles a single RPC request. - Raises: - Sends an appropriate ConnectError to the client if an exception occurs during processing, including timeout, unimplemented, or internal errors. + This internal method orchestrates the processing of a request by: + 1. Initializing a protocol-specific connection handler. + 2. Parsing the request timeout. + 3. Executing the user-provided service implementation within the timeout. + 4. Catching any exceptions, including timeouts, and mapping them to + the appropriate Connect protocol error. + 5. Sending the error back to the client if one occurs. + Args: + request: The incoming request object. + response_headers: The headers for the response. + response_trailers: The trailers for the response. + writer: The server response writer to send data to the client. 
""" conn = await self.protocol_handler.conn(request, response_headers, response_trailers, writer) if conn is None: @@ -352,12 +385,17 @@ async def _handle( class UnaryHandler[T_Request, T_Response](Handler): - """UnaryHandler is a generic handler class for unary RPC procedures. + """A concrete implementation of the `Handler` class for unary RPCs. - Type Parameters: - T_Request: The type of the request message. - T_Response: The type of the response message. + This handler is responsible for processing RPCs that involve a single request message + and a single response message. It is generic over the request and response types. + Attributes: + stream_type (StreamType): The type of stream, fixed to `StreamType.Unary`. + input (type[T_Request]): The type of the input request message. + output (type[T_Response]): The type of the output response message. + call (UnaryFunc[T_Request, T_Response]): The asynchronous function that implements the RPC logic, + potentially wrapped with interceptors. """ stream_type: StreamType = StreamType.Unary @@ -371,21 +409,22 @@ def __init__( unary: UnaryFunc[T_Request, T_Response], input: type[T_Request], output: type[T_Response], - options: ConnectOptions | None = None, + options: HandlerOptions | None = None, ) -> None: - """Initialize a handler for a unary RPC procedure. - - Args: - procedure (str): The name of the RPC procedure. - unary (UnaryFunc[T_Request, T_Response]): The asynchronous function implementing the unary RPC logic. - input (type[T_Request]): The expected input type for the request. - output (type[T_Response]): The expected output type for the response. - options (ConnectOptions | None, optional): Optional configuration for the handler, such as interceptors. Defaults to None. + """Initializes a new unary handler. - Calls the superclass initializer with the configured protocol handlers and method options. 
+ This sets up the necessary components for handling a unary RPC call, + including protocol-specific handlers (Connect, gRPC, gRPC-Web) and + any configured interceptors. + Args: + procedure: The full name of the procedure, e.g., "/package.Service/Method". + unary: The asynchronous function that implements the RPC logic. + input: The type of the request message. + output: The type of the response message. + options: Optional configuration for the handler, including interceptors. """ - options = options if options is not None else ConnectOptions() + options = options if options is not None else HandlerOptions() config = HandlerConfig(procedure=procedure, stream_type=StreamType.Unary, options=options) protocol_handlers = create_protocol_handlers(config) @@ -409,20 +448,16 @@ async def _call(request: UnaryRequest[T_Request], context: HandlerContext) -> Un ) async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: - """Handle the implementation of a streaming handler connection. + """Implementation of the unary handler. - This asynchronous method receives a unary request from the given connection, - optionally sets a timeout on the request, invokes the handler's call method, - and sends the response message back through the connection. It also updates - the connection's response headers and trailers, excluding protocol-specific headers. + This method orchestrates the handling of a single incoming request. It + receives the request message, invokes the user-defined RPC logic via + `self.call`, and sends back the resulting response message. It also + propagates headers and trailers from the response to the connection. Args: - conn (StreamingHandlerConn): The streaming handler connection to process. - timeout (float | None): Optional timeout value to set on the request. - - Returns: - None - + conn (StreamingHandlerConn): The connection object for the stream. + timeout (float | None): The timeout for the request, in seconds. 
""" request = await receive_unary_request(conn, self.input) context = HandlerContext(timeout=timeout) @@ -435,17 +470,21 @@ async def implementation(self, conn: StreamingHandlerConn, timeout: float | None class ServerStreamHandler[T_Request, T_Response](Handler): - """ServerStreamHandler is a handler class for server-streaming RPC procedures. + """Handler for server-streaming RPCs. - This generic class manages the lifecycle and protocol handling for server-streaming RPCs, - where a single request from the client results in a stream of responses from the server. - It sets up protocol handlers, applies interceptors, and provides an asynchronous implementation - method to process incoming streaming requests and send responses. + This class manages the lifecycle of a server-streaming RPC. It is responsible for + receiving a single request message from the client, invoking the user-defined stream + function to generate a stream of response messages, and sending these messages back + to the client. - Type Parameters: - T_Request: The type of the request message. - T_Response: The type of the response message. + It is generic over the request type `T_Request` and the response type `T_Response`. + Attributes: + stream_type (StreamType): The type of stream, always `StreamType.ServerStream`. + input (type[T_Request]): The protobuf message type for the request. + output (type[T_Response]): The protobuf message type for the response. + call (StreamFunc[T_Request, T_Response]): The wrapped, user-provided stream function, + including any configured interceptors. """ stream_type: StreamType = StreamType.ServerStream @@ -459,22 +498,20 @@ def __init__( stream: StreamFunc[T_Request, T_Response], input: type[T_Request], output: type[T_Response], - options: ConnectOptions | None = None, + options: HandlerOptions | None = None, ) -> None: - """Initialize a server-streaming handler for a given procedure. + """Initializes a new server streaming handler. 
Args: - procedure (str): The name of the RPC procedure. - stream (StreamFunc[T_Request, T_Response]): The asynchronous stream function handling the server-streaming logic. - input (type[T_Request]): The expected request message type. - output (type[T_Response]): The expected response message type. - options (ConnectOptions | None, optional): Additional configuration options for the handler. Defaults to None. - - Raises: - Any exceptions raised by the parent class initializer. - + procedure (str): The full name of the procedure, e.g., /my.service.v1.MyService/MyMethod. + stream (StreamFunc[T_Request, T_Response]): The async function that implements the server + streaming logic. It takes a request stream and a context, and returns a response stream. + input (type[T_Request]): The type of the request message. + output (type[T_Response]): The type of the response message. + options (HandlerOptions | None, optional): Optional configuration for the handler. + Defaults to None. """ - options = options if options is not None else ConnectOptions() + options = options if options is not None else HandlerOptions() config = HandlerConfig(procedure=procedure, stream_type=StreamType.ServerStream, options=options) protocol_handlers = create_protocol_handlers(config) @@ -496,22 +533,18 @@ async def _call(request: StreamRequest[T_Request], context: HandlerContext) -> S ) async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: - """Handle the implementation of a streaming handler. + """Handles the logic for a streaming RPC. - This asynchronous method receives a stream request, optionally sets a timeout, - invokes the handler's call method, updates the connection's response headers and trailers, - and sends the response messages through the connection. + This method orchestrates the handling of a streaming request. 
It receives + the request data, invokes the user-defined service logic via the `call` + method, and sends the resulting response back to the client. It also + manages the transfer of headers and trailers. Args: - conn (StreamingHandlerConn): The streaming connection handler. - timeout (float | None): Optional timeout value for the request in seconds. - - Returns: - None - - Raises: - Any exceptions raised by `receive_stream_request`, `self.call`, or `conn.send` will propagate. - + conn (StreamingHandlerConn): The connection object representing the + bidirectional stream with the client. + timeout (float | None): An optional timeout in seconds for the handler's + execution. """ request = await receive_stream_request(conn, self.input) context = HandlerContext(timeout=timeout) @@ -524,19 +557,19 @@ async def implementation(self, conn: StreamingHandlerConn, timeout: float | None class ClientStreamHandler[T_Request, T_Response](Handler): - """ClientStreamHandler is a handler class for client-streaming RPC procedures. - - This generic class manages the lifecycle of a client-streaming RPC, including request handling, - stream invocation, interceptor application, and response transmission. It is parameterized by - the request and response message types. - - Type Parameters: - T_Request: The type of the input message for the stream. - T_Response: The type of the output message for the stream. + """A handler for client streaming RPCs. - stream_type (StreamType): The type of stream handled (ClientStream). - call (StreamFunc[T_Request, T_Response]): The wrapped stream call function with applied interceptors. + This handler manages RPCs where the client sends a stream of messages and the + server responds with a single message. It orchestrates receiving the client's + stream, invoking the user-defined implementation, and sending the final + response. + Attributes: + stream_type (StreamType): The type of stream, always `StreamType.ClientStream`. 
+ input (type[T_Request]): The protobuf message class for the request. + output (type[T_Response]): The protobuf message class for the response. + call (StreamFunc[T_Request, T_Response]): The wrapped, interceptor-aware + asynchronous function that implements the RPC logic. """ stream_type: StreamType = StreamType.ClientStream @@ -550,22 +583,24 @@ def __init__( stream: StreamFunc[T_Request, T_Response], input: type[T_Request], output: type[T_Response], - options: ConnectOptions | None = None, + options: HandlerOptions | None = None, ) -> None: - """Initialize a handler for a client-streaming RPC procedure. - - Args: - procedure (str): The name of the RPC procedure. - stream (StreamFunc[T_Request, T_Response]): The asynchronous stream function handling the client-streaming logic. - input (type[T_Request]): The expected input message type. - output (type[T_Response]): The expected output message type. - options (ConnectOptions | None, optional): Additional configuration options for the handler. Defaults to None. + """Initializes a client streaming RPC handler. - Raises: - Any exceptions raised by the parent class initializer or protocol handler creation. + This handler is responsible for processing a client streaming RPC, where the + client sends a stream of messages and the server responds with a single message. + Args: + procedure: The full name of the RPC procedure. + stream: The asynchronous function that implements the RPC logic. + It receives a `StreamRequest` (an async iterator of request + messages) and a `HandlerContext`, and returns a `StreamResponse` + containing the single response message. + input: The protobuf message class for the request. + output: The protobuf message class for the response. + options: Optional configuration for the handler, including interceptors. 
""" - options = options if options is not None else ConnectOptions() + options = options if options is not None else HandlerOptions() config = HandlerConfig(procedure=procedure, stream_type=StreamType.ClientStream, options=options) protocol_handlers = create_protocol_handlers(config) @@ -587,22 +622,18 @@ async def _call(request: StreamRequest[T_Request], context: HandlerContext) -> S ) async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: - """Handle the implementation of a streaming handler. + """The core implementation for the streaming handler. - This asynchronous method receives a streaming request, optionally sets a timeout, - calls the handler logic, updates the connection's response headers and trailers, - and sends the response messages back through the connection. + This method orchestrates the handling of a streaming request. It receives + the request from the connection, invokes the user-defined call logic + with a context object, and sends the resulting response headers, + trailers, and messages back to the client. Args: - conn (StreamingHandlerConn): The streaming connection handler. - timeout (float | None): Optional timeout value for the request in seconds. - - Returns: - None - - Raises: - Any exceptions raised by `receive_stream_request` or `self.call` will propagate. - + conn (StreamingHandlerConn): The connection object for the streaming RPC, + used for receiving the request and sending the response. + timeout (float | None): The maximum time in seconds to allow for the + handler's execution. """ request = await receive_stream_request(conn, self.input) context = HandlerContext(timeout=timeout) @@ -616,20 +647,22 @@ async def implementation(self, conn: StreamingHandlerConn, timeout: float | None class BidiStreamHandler[T_Request, T_Response](Handler): - """BidiStreamHandler is a handler class for bidirectional streaming procedures. + """Handler for bidirectional streaming procedures. 
- This generic class manages the lifecycle of a bidirectional streaming RPC, including request/response type validation, - application of interceptors, and integration with protocol-specific handlers. It wraps the provided stream function, - applies any configured interceptors, and exposes an asynchronous implementation method to process streaming requests - and send responses. + This class manages the lifecycle of a bidirectional streaming RPC, where both the client + and the server can send a stream of messages to each other. It wraps the user-provided + stream function with necessary protocol logic and interceptors. - Type Parameters: + Generic Types: T_Request: The type of the request messages. T_Response: The type of the response messages. - stream_type (StreamType): The type of stream handled (always StreamType.BiDiStream). - call (StreamFunc[T_Request, T_Response]): The wrapped stream call function with applied interceptors. - + Attributes: + stream_type (StreamType): The type of stream, set to BiDiStream. + input (type[T_Request]): The expected type for request messages. + output (type[T_Response]): The expected type for response messages. + call (StreamFunc[T_Request, T_Response]): The wrapped, user-provided stream function + that processes the request and generates the response. """ stream_type: StreamType = StreamType.BiDiStream @@ -643,22 +676,18 @@ def __init__( stream: StreamFunc[T_Request, T_Response], input: type[T_Request], output: type[T_Response], - options: ConnectOptions | None = None, + options: HandlerOptions | None = None, ) -> None: - """Initialize a handler for a bidirectional streaming procedure. + """Initializes a bi-directional streaming handler. Args: - procedure (str): The name of the procedure to handle. - stream (StreamFunc[T_Request, T_Response]): The asynchronous stream function handling requests and responses. - input (type[T_Request]): The expected input type for requests. 
- output (type[T_Response]): The expected output type for responses. - options (ConnectOptions | None, optional): Configuration options for the handler. Defaults to None. - - Raises: - Any exceptions raised by the parent class initializer. - + procedure: The full name of the procedure (e.g., /acme.foo.v1.FooService/Bar). + stream: The async function that implements the bi-directional stream logic. + input: The type of the request message. + output: The type of the response message. + options: Handler-specific options. """ - options = options if options is not None else ConnectOptions() + options = options if options is not None else HandlerOptions() config = HandlerConfig(procedure=procedure, stream_type=StreamType.BiDiStream, options=options) protocol_handlers = create_protocol_handlers(config) @@ -680,22 +709,18 @@ async def _call(request: StreamRequest[T_Request], context: HandlerContext) -> S ) async def implementation(self, conn: StreamingHandlerConn, timeout: float | None) -> None: - """Handle the implementation of a streaming handler. + """Handles the logic for a streaming RPC. - This asynchronous method receives a streaming request, optionally sets a timeout, - calls the main processing function, updates the connection's response headers and trailers, - and sends the response messages back through the connection. + This method orchestrates the handling of a streaming request. It receives + the request data from the connection, invokes the user-defined `call` + method with the request and a context object, and then sends the + resulting response, including headers and trailers, back to the client. Args: - conn (StreamingHandlerConn): The streaming connection handler. - timeout (float | None): Optional timeout value for the request in seconds. - - Returns: - None - - Raises: - Any exceptions raised by `receive_stream_request` or `self.call` will propagate. 
- + conn (StreamingHandlerConn): The connection object representing the + bidirectional stream with the client. + timeout (float | None): An optional timeout in seconds for the + handler's execution. """ request = await receive_stream_request(conn, self.input) context = HandlerContext(timeout=timeout) diff --git a/src/connect/handler_context.py b/src/connect/handler_context.py index 522a6cd..33f9450 100644 --- a/src/connect/handler_context.py +++ b/src/connect/handler_context.py @@ -1,33 +1,36 @@ -"""Provides the HandlerContext class for managing operation timeouts and tracking remaining time.""" +"""Manages the context for a handler, particularly for handling timeouts.""" import time class HandlerContext: - """HandlerContext manages an optional timeout for operations, allowing tracking of the remaining time until a deadline. + """Manages the context for a handler, particularly for handling timeouts. - Attributes: - _deadline (float | None): The UNIX timestamp representing the deadline, or None if no timeout is set. + This class allows setting a deadline upon initialization and provides a method + to check the remaining time until that deadline. + Attributes: + _deadline (float | None): The timestamp for the deadline, or None if no timeout is set. """ _deadline: float | None def __init__(self, timeout: float | None) -> None: - """Initialize HandlerContext with an optional timeout. + """Initializes a new handler context. Args: - timeout (float | None): The timeout duration in seconds, or None for no timeout. - + timeout: The timeout in seconds. If None, no deadline is set. """ self._deadline = time.time() + timeout if timeout else None def timeout_remaining(self) -> float | None: - """Return the remaining time in seconds until the deadline, or None if no deadline is set. + """Calculates the remaining time in seconds until the handler's deadline. - Returns: - float | None: The number of seconds remaining until the deadline, or None if no deadline is set. 
+ If the request has no deadline, this method returns None. Otherwise, it + returns the difference between the deadline and the current time. + Returns: + float | None: The remaining time in seconds, or None if no deadline is set. """ if self._deadline is None: return None diff --git a/src/connect/handler_interceptor.py b/src/connect/handler_interceptor.py index 45f2d18..a7b7090 100644 --- a/src/connect/handler_interceptor.py +++ b/src/connect/handler_interceptor.py @@ -1,4 +1,4 @@ -"""Defines interceptors and request/response classes for unary and streaming RPC calls.""" +"""Defines handler-side interceptors for the Connect RPC framework.""" import inspect from collections.abc import Awaitable, Callable @@ -12,24 +12,44 @@ class HandlerInterceptor: - """Abstract base class for interceptors that can wrap unary functions.""" + """A handler-side interceptor for wrapping and modifying RPC handlers. + + Interceptors are a powerful mechanism for observing and modifying RPCs without + changing the core application logic. They can be used for tasks like logging, + authentication, metrics collection, or adding custom headers. + + A HandlerInterceptor instance is configured with wrapper functions that are + applied to the RPC handlers. The Connect framework will call these wrappers + before invoking the actual handler. + + Attributes: + wrap_unary (Callable[[UnaryFunc], UnaryFunc] | None): A callable that + takes a unary RPC handler and returns a new, wrapped handler. The + framework calls this function to build the handler chain. + wrap_stream (Callable[[StreamFunc], StreamFunc] | None): A callable that + takes a streaming RPC handler and returns a new, wrapped handler. The + framework calls this function to build the handler chain. 
+ """ wrap_unary: Callable[[UnaryFunc], UnaryFunc] | None = None wrap_stream: Callable[[StreamFunc], StreamFunc] | None = None def is_unary_func(next: UnaryFunc | StreamFunc) -> TypeGuard[UnaryFunc]: - """Determine if the given function is a unary function. + """Type guard to determine if a handler function is a unary function. - A unary function is defined as a callable that takes a single parameter - whose type annotation has an origin of `UnaryRequest`. + This function inspects the signature of the provided callable (`next`) to + determine if it matches the expected signature of a `UnaryFunc`. A function + is considered a unary function if it is callable, accepts exactly two + parameters, and the type annotation of its first parameter is `UnaryRequest`. Args: - next (UnaryFunc | StreamFunc): The function to be checked. + next: The handler function to be checked, which can be either a + `UnaryFunc` or a `StreamFunc`. Returns: - TypeGuard[UnaryFunc]: True if the function is a unary function, False otherwise. - + True if the function signature corresponds to a `UnaryFunc`, + False otherwise. """ signature = inspect.signature(next) parameters = list(signature.parameters.values()) @@ -41,17 +61,16 @@ def is_unary_func(next: UnaryFunc | StreamFunc) -> TypeGuard[UnaryFunc]: def is_stream_func(next: UnaryFunc | StreamFunc) -> TypeGuard[StreamFunc]: - """Determine if the given function is a StreamFunc. + """Type guard to determine if a handler function is a StreamFunc. - This function checks if the provided function `next` is callable, has exactly one parameter, - and if the annotation of that parameter has an origin of `StreamRequest`. + A function is considered a StreamFunc if it is a callable that accepts + two arguments, and the first argument is annotated as a StreamRequest. Args: - next (UnaryFunc | StreamFunc): The function to be checked. + next: The function to inspect. Returns: - TypeGuard[StreamFunc]: True if `next` is a StreamFunc, False otherwise. 
- + True if the function signature matches StreamFunc, False otherwise. """ signature = inspect.signature(next) parameters = list(signature.parameters.values()) @@ -73,20 +92,24 @@ def apply_interceptors(next: StreamFunc, interceptors: list[HandlerInterceptor] def apply_interceptors( next: UnaryFunc | StreamFunc, interceptors: list[HandlerInterceptor] | None ) -> UnaryFunc | StreamFunc: - """Apply a list of interceptors to a given function. + """Applies a list of interceptors to a handler function. + + This function takes a handler function (either unary or streaming) and wraps it + with the provided interceptors. The interceptors are applied in the order they + appear in the list, meaning the first interceptor in the list will be the + outermost wrapper and the last to execute before the actual handler. Args: - next (UnaryFunc | StreamFunc): The function to which interceptors will be applied. - It can be either a unary function or a stream function. - interceptors (list[Interceptor] | None): A list of interceptors to apply. If None, the original function is returned. + next: The handler function (either UnaryFunc or StreamFunc) to be wrapped. + interceptors: An optional list of HandlerInterceptor instances. If None, + the original handler function is returned unmodified. Returns: - UnaryFunc | StreamFunc: The function wrapped with the provided interceptors. + The wrapped handler function, which is of the same type as the input `next` + function. Raises: - ValueError: If an interceptor does not implement the required wrap method for the function type, - or if the provided function type is invalid. - + ValueError: If the `next` function is not a valid UnaryFunc or StreamFunc. 
""" if interceptors is None: return next diff --git a/src/connect/headers.py b/src/connect/headers.py index 7bed0dc..2f3fb93 100644 --- a/src/connect/headers.py +++ b/src/connect/headers.py @@ -1,4 +1,4 @@ -"""Provides a Headers class for managing HTTP headers.""" +"""Provides a `Headers` class for managing HTTP headers and a utility function for request headers.""" from collections.abc import AsyncIterable, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence from typing import Any, Union @@ -38,14 +38,30 @@ def _normalize_header_value(value: str | bytes, encoding: str | None = None) -> class Headers(MutableMapping[str, str]): - """A class to represent HTTP headers. + """A case-insensitive, multi-valued dictionary for HTTP headers. - Attributes: - raw : list[tuple[bytes, bytes]] - A list of tuples containing the raw header keys and values. - encoding : str - The encoding used for the headers. + This class provides a dictionary-like interface for managing HTTP headers, + with special handling for their unique characteristics. Keys are treated in a + case-insensitive manner, and it's possible to have multiple values for the + same key, which are handled gracefully. + + Internally, headers are stored as bytes to manage character encodings + correctly. The class can auto-detect the encoding or use a specified one. + + When accessing a header that has multiple values via standard dictionary + lookup (`headers['key']`), the values are concatenated into a single, + comma-separated string. To retrieve all key-value pairs, including + duplicates, the `multi_items()` method should be used. + + It inherits from `collections.abc.MutableMapping`, providing standard + dictionary methods like `keys()`, `items()`, `__getitem__`, `__setitem__`, + and `__delitem__`. + Attributes: + encoding (str): The character encoding used for header keys and values. + It can be set manually or is auto-detected. 
+ raw (list[tuple[bytes, bytes]]): A list of the raw (key, value) byte + pairs, preserving the original case of the keys. """ _list: list[tuple[bytes, bytes, bytes]] @@ -55,18 +71,14 @@ def __init__( headers: HeaderTypes | None = None, encoding: str | None = None, ) -> None: - """Initialize the Headers object. + """Initializes a new Headers object. Args: - headers (HeaderTypes | None): Initial headers to populate the object. - Can be another Headers object, a mapping of header key-value pairs, - or an iterable of key-value pairs. - encoding (str | None): The encoding to use for header keys and values. - If None, defaults to the system's default encoding. - - Returns: - None - + headers: An optional initial set of headers. Can be provided as + another `Headers` instance, a mapping (e.g., a dictionary), + or an iterable of (key, value) tuples. + encoding: The character encoding to use for converting string + header keys and values to bytes. """ self._list = [] @@ -87,39 +99,38 @@ def __init__( @property def raw(self) -> list[tuple[bytes, bytes]]: - """Return the raw headers as a list of tuples. - - Each tuple contains the raw key and value as bytes. + """Get the raw headers as a list of key-value pairs. Returns: - list[tuple[bytes, bytes]]: A list of tuples where each tuple contains - the raw key and value as bytes. - + list[tuple[bytes, bytes]]: A list of (key, value) tuples, + where both the key and the value are bytes. """ return [(raw_key, value) for raw_key, _, value in self._list] def keys(self) -> KeysView[str]: - """Return a view of the dictionary's keys. + """Return a new view of the header keys. - This method decodes the keys from the internal list using the specified encoding - and returns a view of these decoded keys. + The keys are decoded from bytes into strings using the configured encoding. Returns: - KeysView[str]: A view object that displays a list of the dictionary's keys. - + A view object displaying a list of all header keys. 
""" return {key.decode(self.encoding): None for _, key, value in self._list}.keys() @property def encoding(self) -> str: - """Determine and returns the encoding used for the headers. + """Determine and return the encoding for the headers. + + The method iterates through a list of preferred encodings ('ascii', 'utf-8') + and attempts to decode all header keys and values. The first encoding + that successfully decodes all headers without a `UnicodeDecodeError` is + chosen and cached for subsequent calls. - This method attempts to decode the headers using "ascii" and "utf-8" encodings. - If neither of these encodings work, it defaults to "iso-8859-1" which can decode - any byte sequence. + If neither 'ascii' nor 'utf-8' is suitable, it falls back to 'iso-8859-1', + which can represent any byte value and is thus a safe default. Returns: - str: The encoding used for the headers. + str: The name of the determined encoding for the headers. """ if self._encoding is None: @@ -146,32 +157,39 @@ def encoding(self, value: str) -> None: self._encoding = value def copy(self) -> "Headers": - """Return a copy of the Headers object.""" + """Returns a copy of the Headers object. + + Returns: + Headers: A new Headers instance. + """ return Headers(self, encoding=self.encoding) def multi_items(self) -> list[tuple[str, str]]: - """Return a list of tuples containing decoded key-value pairs from the internal list. + """Returns a list of all header key-value pairs. - The keys and values are decoded using the specified encoding. + The keys and values are decoded to strings using the specified encoding. + This method is useful for headers that can appear multiple times, + as it returns all occurrences. Returns: - list[tuple[str, str]]: A list of tuples where each tuple contains a decoded key and value. - + list[tuple[str, str]]: A list of (key, value) tuples. 
""" return [(key.decode(self.encoding), value.decode(self.encoding)) for _, key, value in self._list] def __getitem__(self, key: str) -> str: - """Retrieve the value associated with the given key in the headers. + """Retrieves a header value by its case-insensitive key. + + If multiple headers share the same key, their values are concatenated + into a single string, separated by a comma and a space. Args: - key (str): The key to look up in the headers. + key: The case-insensitive name of the header to retrieve. Returns: - str: The value(s) associated with the key, joined by ", " if multiple values exist. + The corresponding header value. Raises: - KeyError: If the key is not found in the headers. - + KeyError: If no header with the given key is found. """ normalized_key = key.lower().encode(self.encoding) @@ -187,18 +205,18 @@ def __getitem__(self, key: str) -> str: raise KeyError(key) def __setitem__(self, key: str, value: str) -> None: - """Set the value for a given key in the headers list. + """Sets a header value, treating the header name case-insensitively. - If the key already exists, update its value. If the key appears multiple times, remove all but the first - occurrence before updating. + If a header with the same name (case-insensitively) already exists, + its value is updated. The original casing of the new key is preserved. - Args: - key (str): The header key to set. - value (str): The header value to set. - - Returns: - None + If multiple headers with the same name exist, all subsequent occurrences + are removed, and the first one is updated with the new value. If the + header does not exist, it is added. + Args: + key: The name of the header. + value: The value for the header. 
""" set_key = key.encode(self._encoding or "utf-8") set_value = value.encode(self._encoding or "utf-8") @@ -216,14 +234,16 @@ def __setitem__(self, key: str, value: str) -> None: self._list.append((set_key, lookup_key, set_value)) def __delitem__(self, key: str) -> None: - """Remove the item with the specified key from the list. + """Delete all headers matching the given key. + + The key matching is case-insensitive. If multiple headers have the same + name, all of them will be removed. Args: - key (str): The key of the item to be removed. + key: The case-insensitive name of the header(s) to remove. Raises: - KeyError: If the key is not found in the list. - + KeyError: If no header with the given key is found. """ del_key = key.lower().encode(self.encoding) @@ -260,21 +280,26 @@ def include_request_headers( content: bytes | Iterable[bytes] | AsyncIterable[bytes] | None, method: str | None = None, ) -> Headers: - """Include necessary request headers if they are not already present. + """Adds required request headers like 'Host' and 'Content-Length' if not present. + + This function inspects the request details (URL, content, method) and + populates essential HTTP headers if they are missing. - This function ensures that the "Host" and "Content-Length" headers are included in the request headers. - If the "Host" header is missing, it will be set based on the URL's host and port. - If the "Content-Length" header is missing and content is provided, it will be set to the length of the content. + - It sets the 'Host' header based on the target URL, including the port + if it's non-standard. + - It determines whether to use 'Content-Length' (for fixed-size byte content) + or 'Transfer-Encoding: chunked' (for streaming content) for the request + body. This is skipped for 'GET' requests or if these headers are already set. Args: - headers (Headers): The original request headers. - url (URL): The URL object containing the scheme, host, and port. 
- content (bytes | None): The request content, if any. - method (str): The HTTP method of the request. + headers: The mutable headers dictionary-like object for the request. + url: The URL object of the request, used to determine the 'Host' header. + content: The request body, which can be bytes, an iterable of bytes, + or an async iterable of bytes. + method: The HTTP method of the request (e.g., "GET", "POST"). Returns: - Headers: The updated request headers with the necessary headers included. - + The updated headers object. """ if headers.get("Host") is None: default_port = DEFAULT_PORTS.get(url.scheme.encode()) diff --git a/src/connect/idempotency_level.py b/src/connect/idempotency_level.py index c821d82..43d9a5e 100644 --- a/src/connect/idempotency_level.py +++ b/src/connect/idempotency_level.py @@ -4,13 +4,24 @@ class IdempotencyLevel(IntEnum): - """IdempotencyLevel is an enumeration that represents different levels of idempotency. + """Defines the idempotency level of an API operation. - Attributes: - IDEMPOTENCY_UNKNOWN (int): Represents an unknown idempotency level. - NO_SIDE_EFFECTS (int): Indicates that the operation has no side effects. - IDEMPOTENT (int): Indicates that the operation is idempotent, meaning it can be performed multiple times without changing the result. + Idempotency is the property of certain operations that can be applied + multiple times without changing the result beyond the initial application. + In the context of APIs, this means that making the same request multiple + times will have the same effect as making it once. This is crucial for + building robust systems that can safely retry requests in case of + network failures or other transient errors. + Attributes: + IDEMPOTENCY_UNKNOWN: The idempotency level is not specified or known. + This is the default value. + NO_SIDE_EFFECTS: The operation has no side effects on the server state. + It is safe to retry indefinitely. 
This typically corresponds to + read operations like HTTP GET. + IDEMPOTENT: The operation is idempotent. It can be safely retried as + multiple identical requests will produce the same result as a single + request. This typically corresponds to operations like HTTP PUT or DELETE. """ IDEMPOTENCY_UNKNOWN = 0 diff --git a/src/connect/middleware.py b/src/connect/middleware.py index 9c7f03b..b4e004b 100644 --- a/src/connect/middleware.py +++ b/src/connect/middleware.py @@ -1,4 +1,4 @@ -"""Middleware for handling HTTP requests.""" +"""Provides ASGI middleware for routing requests to Connect protocol handlers.""" from collections.abc import Awaitable, Callable @@ -13,44 +13,48 @@ class ConnectMiddleware: - """Middleware for handling ASGI applications with unary handlers. + """ASGI middleware for routing requests to Connect-style handlers. - Attributes: - app (ASGIApp): The ASGI application to wrap. - handlers (list[Handler]): A list of unary handlers to process requests. + This middleware intercepts incoming HTTP requests and attempts to match them + against a list of registered `Handler` instances based on the request path. + If a matching handler is found for the request's route, it processes the + request and sends a response. If no handler matches the route, the request + is forwarded to the next ASGI application in the stack. + + This allows for integrating Connect-protocol services within a standard + ASGI application framework, such as Starlette or FastAPI. + Attributes: + app (ASGIApp): The next ASGI application in the middleware stack. + handlers (list[Handler]): A list of Connect handlers to which requests + can be routed. """ app: ASGIApp handlers: list[Handler] def __init__(self, app: ASGIApp, handlers: list[Handler]) -> None: - """Initialize the middleware with the given ASGI application and handlers. + """Initializes the middleware. Args: - app (ASGIApp): The ASGI application instance. 
- handlers (list[Handler]): A list of unary handlers to process requests. - + app: The ASGI application. + handlers: A list of handlers to be used by the middleware. """ self.app = app self.handlers = handlers async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Asynchronous callable method to handle incoming ASGI requests. + """The ASGI application entry point. - This method intercepts HTTP requests, determines the appropriate handler - based on the route path, and delegates the request to the handler if found. - If no handler is found, the request is passed to the next application in the - middleware stack. + This method is called for each request. It checks if the request is an HTTP + request and if the path matches any of the registered handlers. If a match + is found, the request is handled by the corresponding handler. Otherwise, + the request is passed on to the next ASGI application in the stack. Args: - scope (Scope): The ASGI scope dictionary containing request information. - receive (Receive): The ASGI receive callable to receive messages. - send (Send): The ASGI send callable to send messages. - - Returns: - None - + scope (Scope): The ASGI connection scope. + receive (Receive): The ASGI receive channel. + send (Send): The ASGI send channel. """ if scope["type"] == "http": route_path = get_route_path(scope) diff --git a/src/connect/options.py b/src/connect/options.py index 9c8770c..89460e0 100644 --- a/src/connect/options.py +++ b/src/connect/options.py @@ -1,4 +1,4 @@ -"""Options for the UniversalHandler class.""" +"""Defines configuration options for Connect RPC clients and handlers.""" from typing import Any, Literal @@ -9,8 +9,26 @@ from connect.idempotency_level import IdempotencyLevel -class ConnectOptions(BaseModel): - """Options for the connect command.""" +class HandlerOptions(BaseModel): + """Configuration options for a handler. 
+ + This class encapsulates various settings that control the behavior of a handler + in the Connect protocol implementation. It allows for customization of interceptors, + protocol requirements, and data handling limits. + + Attributes: + interceptors (list[HandlerInterceptor]): A list of interceptors to apply to the handler. + descriptor (Any): The descriptor for the RPC method. + idempotency_level (IdempotencyLevel): The idempotency level of the RPC method. + require_connect_protocol_header (bool): A boolean indicating whether requests + using the Connect protocol should include the protocol version header. + compress_min_bytes (int): The minimum number of bytes for a response to be + eligible for compression. A value of -1 disables compression. + read_max_bytes (int): The maximum number of bytes to read from a request body. + A value of -1 indicates no limit. + send_max_bytes (int): The maximum number of bytes to send in a response body. + A value of -1 indicates no limit. + """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -35,16 +53,19 @@ class ConnectOptions(BaseModel): send_max_bytes: int = Field(default=-1) """The maximum number of bytes to send.""" - def merge(self, override_options: "ConnectOptions | None" = None) -> "ConnectOptions": - """Merge this options object with an override options object. + def merge(self, override_options: "HandlerOptions | None" = None) -> "HandlerOptions": + """Merges this HandlerOptions instance with another, creating a new instance. + + The values from the `override_options` will take precedence over the + values in the current instance. Only the fields that are explicitly set + in `override_options` are used for the merge. Args: - override_options (ConnectOptions | None): Optional override options object. - If None, this options object is returned as is. + override_options: An optional HandlerOptions object to merge with. Returns: - ConnectOptions: A new instance with attributes merged from both objects. 
- + A new HandlerOptions instance with the merged options. If + `override_options` is None, the original instance is returned. """ if override_options is None: return self @@ -53,11 +74,27 @@ def merge(self, override_options: "ConnectOptions | None" = None) -> "ConnectOpt explicit_overrides = override_options.model_dump(exclude_unset=True) merged_data.update(explicit_overrides) - return ConnectOptions(**merged_data) + return HandlerOptions(**merged_data) class ClientOptions(BaseModel): - """Options for the Connect client.""" + """Configuration options for a client. + + This class holds settings that control the behavior of client-side RPC calls, + such as interceptors, compression, and protocol-specific details. + + Attributes: + interceptors (list[ClientInterceptor]): A list of interceptors to apply to the handler. + descriptor (Any): The descriptor for the RPC method. + idempotency_level (IdempotencyLevel): The idempotency level of the RPC method. + request_compression_name (str | None): The name of the compression method to use for requests. + compress_min_bytes (int): The minimum number of bytes to compress. + read_max_bytes (int): The maximum number of bytes to read. + send_max_bytes (int): The maximum number of bytes to send. + enable_get (bool): A boolean indicating whether to enable GET requests. + protocol (Literal["connect", "grpc", "grpc-web"]): The protocol to use for the request. + use_binary_format (bool): A boolean indicating whether to use binary format for the request. + """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -92,15 +129,22 @@ class ClientOptions(BaseModel): """A boolean indicating whether to use binary format for the request.""" def merge(self, override_options: "ClientOptions | None" = None) -> "ClientOptions": - """Merge this options object with an override options object. + """Creates a new ClientOptions instance by merging with override options. 
+ + If override_options is provided, this method returns a new ClientOptions + instance that is a copy of the current options, updated with any + explicitly set values from the override_options. If override_options + is None, it returns the current instance. Args: - override_options (ClientOptions | None): Optional override options object. - If None, this options object is returned as is. + override_options (ClientOptions | None, optional): + The options to merge with. Fields explicitly set in this + object will override the corresponding values in the current + options. Defaults to None. Returns: - ClientOptions: A new instance with attributes merged from both objects. - + ClientOptions: A new instance with the merged options, or the + current instance if no override options are provided. """ if override_options is None: return self diff --git a/src/connect/protocol.py b/src/connect/protocol.py index 30a4f3e..ded1d62 100644 --- a/src/connect/protocol.py +++ b/src/connect/protocol.py @@ -1,4 +1,4 @@ -"""Module defining the protocol handling classes and functions.""" +"""Defines the abstract interfaces and helpers for RPC protocol implementations.""" import abc from http import HTTPMethod @@ -37,16 +37,20 @@ class ProtocolHandlerParams(BaseModel): - """ProtocolHandlerParams is a data model that holds parameters for handling protocol operations. + """Parameters for configuring a protocol handler. - Attributes: - spec (Spec): The specification details for the protocol. - codecs (ReadOnlyCodecs): The codecs used for encoding and decoding data. - compressions (list[Compression]): A list of compression methods to be used. - compress_min_bytes (int): The minimum number of bytes required to trigger compression. - read_max_bytes (int): The maximum number of bytes that can be read at once. - send_max_bytes (int): The maximum number of bytes that can be sent at once. 
+ This class encapsulates all the configuration options needed to set up + a protocol handler for Connect RPC communication. + Attributes: + spec: The service specification defining available methods and types. + codecs: Read-only collection of codecs for message serialization/deserialization. + compressions: List of supported compression algorithms. + compress_min_bytes: Minimum message size in bytes before compression is applied. + read_max_bytes: Maximum number of bytes that can be read in a single operation. + send_max_bytes: Maximum number of bytes that can be sent in a single operation. + require_connect_protocol_header: Whether to require Connect protocol headers. + idempotency_level: The level of idempotency support for operations. """ model_config = ConfigDict( @@ -54,94 +58,165 @@ class ProtocolHandlerParams(BaseModel): ) spec: Spec + """The service specification defining available methods and types.""" codecs: ReadOnlyCodecs + """Read-only collection of codecs for message serialization/deserialization.""" compressions: list[Compression] + """List of supported compression algorithms.""" compress_min_bytes: int + """Minimum message size in bytes before compression is applied.""" read_max_bytes: int + """Maximum number of bytes that can be read in a single operation.""" send_max_bytes: int + """Maximum number of bytes that can be sent in a single operation.""" require_connect_protocol_header: bool + """Whether to require Connect protocol headers.""" idempotency_level: IdempotencyLevel + """The level of idempotency support for operations.""" class ProtocolClientParams(BaseModel): - """ProtocolClientParams is a data model for configuring protocol client parameters.""" + """Parameters for configuring a protocol client. + + This class defines the configuration parameters needed to create and operate + a protocol client for network communication. + + Attributes: + pool (AsyncConnectionPool): The connection pool for managing async connections. 
+ codec (Codec): The codec used for encoding/decoding messages. + url (URL): The target URL for the protocol client. + compression_name (str | None): The name of the compression algorithm to use. + Defaults to None if no compression is specified. + compressions (list[Compression]): List of available compression algorithms. + compress_min_bytes (int): Minimum number of bytes required before applying compression. + read_max_bytes (int): Maximum number of bytes that can be read in a single operation. + send_max_bytes (int): Maximum number of bytes that can be sent in a single operation. + enable_get (bool): Whether GET requests are enabled for this client. + """ model_config = ConfigDict( arbitrary_types_allowed=True, ) pool: AsyncConnectionPool + """The connection pool for managing async connections.""" codec: Codec + """The codec used for encoding/decoding messages.""" url: URL + """The target URL for the protocol client.""" compression_name: str | None = Field(default=None) + """The name of the compression algorithm to use. Defaults to None if no compression is specified.""" compressions: list[Compression] + """List of available compression algorithms.""" compress_min_bytes: int + """Minimum number of bytes required before applying compression.""" read_max_bytes: int + """Maximum number of bytes that can be read in a single operation.""" send_max_bytes: int + """Maximum number of bytes that can be sent in a single operation.""" enable_get: bool + """Whether GET requests are enabled for this client.""" class ProtocolClient(abc.ABC): - """Abstract base class for defining a protocol client.""" + """Abstract base class for protocol clients that handle communication with remote services. + + This class defines the interface for protocol clients that manage connections, + handle request headers, and provide peer information for different streaming protocols. 
+ + The ProtocolClient serves as a foundation for implementing specific protocol + handlers (such as HTTP/1.1, HTTP/2, or gRPC) while maintaining a consistent + interface for client operations. + """ @property @abc.abstractmethod def peer(self) -> Peer: - """Retern the peer for the client.""" + """Get the peer information for this connection. + + Returns: + Peer: The peer object containing information about the connected peer. + """ raise NotImplementedError() @abc.abstractmethod def write_request_headers(self, stream_type: StreamType, headers: Headers) -> None: - """Write the request headers.""" + """Write request headers to the stream. + + Args: + stream_type (StreamType): The type of stream being used for the request. + headers (Headers): The headers to be written to the request stream. + """ raise NotImplementedError() @abc.abstractmethod def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Return the connection for the client.""" + """Establish a connection to a streaming service. + + Args: + spec (Spec): The specification object containing connection details and configuration. + headers (Headers): HTTP headers to be included with the connection request. + + Returns: + StreamingClientConn: A streaming client connection object for communicating with the service. + """ raise NotImplementedError() class ProtocolHandler(abc.ABC): - """Abstract base class for handling different protocols.""" + """Abstract base class for defining protocol handlers. + + This class provides a standardized interface for different communication + protocols, such as gRPC, Connect, and gRPC-Web. By subclassing + `ProtocolHandler`, developers can integrate custom or standard protocols + into the server framework. + + Subclasses must implement the abstract methods defined in this class to + specify the supported HTTP methods, content types, and to provide the + core logic for handling connections and processing requests. 
+ """ @property @abc.abstractmethod def methods(self) -> list[HTTPMethod]: - """Retrieve a list of HTTP methods. + """Gets the HTTP methods that this protocol supports. - Returns: - list[HTTPMethod]: A list of HTTP methods. + This is an abstract method that must be implemented by subclasses. + Returns: + A list of the supported HTTP methods. """ raise NotImplementedError() @abc.abstractmethod def content_types(self) -> list[str]: - """Handle content types. + """Gets the list of supported content types. - This method currently does nothing and is intended to be implemented - in the future to handle different content types as needed. + This is an abstract method that must be implemented by subclasses. It should + return a list of MIME types that the protocol implementation can handle. Returns: - None + list[str]: A list of supported content type strings (e.g., "application/json"). + Raises: + NotImplementedError: If the method is not overridden by a subclass. """ raise NotImplementedError() @abc.abstractmethod def can_handle_payload(self, request: Request, content_type: str) -> bool: - """Determine if the payload of the given request can be handled based on the content type. + """Checks if the protocol can handle the request payload. + + This method determines whether the current protocol implementation is capable + of processing the payload of a given request based on its content type. + Subclasses must implement this method. Args: - request (Request): The request object containing the payload. - content_type (str): The content type of the payload. + request (Request): The incoming request object. + content_type (str): The content type of the request's payload. Returns: - bool: True if the payload can be handled, False otherwise. - - Raises: - NotImplementedError: This method should be implemented by subclasses. - + bool: True if the payload can be handled by this protocol, False otherwise. 
""" raise NotImplementedError() @@ -153,65 +228,82 @@ async def conn( response_trailers: Headers, writer: ServerResponseWriter, ) -> StreamingHandlerConn | None: - """Handle a connection request. + """Initializes a streaming connection handler. + + This method is called by the server to begin handling a streaming RPC. + It sets up the necessary context and returns a handler object that will + process the incoming and outgoing messages for the stream. Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be sent in the response. - response_trailers (Headers): The trailers to be sent in the response. - writer (ServerResponseWriter): The writer used to send the response. - is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. + request: The incoming request details. + response_headers: A mutable headers object to which response headers should be written. + response_trailers: A mutable headers object to which response trailers should be written. + writer: The writer for sending response messages and trailers. Returns: - StreamingHandlerConn | None: The connection handler or None if not implemented. - - Raises: - NotImplementedError: If the method is not implemented. - + An instance of a `StreamingHandlerConn` to handle the stream, or `None` to + terminate the connection. """ raise NotImplementedError() class Protocol(abc.ABC): - """Abstract base class for defining a protocol. + """Defines the abstract interface for a communication protocol. - This class serves as a blueprint for creating protocol handlers and clients. - Subclasses must implement the following abstract methods. + This abstract base class (ABC) establishes a contract for different + communication protocols. It ensures that any protocol implementation + provides a consistent way to create both a server-side handler and a + client. 
+ Subclasses are required to implement the `handler` and `client` methods + to provide the specific logic for their respective protocol. """ @abc.abstractmethod def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler: - """Handle the protocol with the given parameters. + """Gets the appropriate protocol handler for the given parameters. + + This method is intended to be overridden by subclasses to return a specific + handler instance based on the provided parameters. Args: - params (ProtocolHandlerParams): The parameters required to handle the protocol. + params: The parameters used to determine which handler to return. Returns: - ProtocolHandler: An instance of ProtocolHandler based on the provided parameters. - + An instance of a class that implements the ProtocolHandler protocol. """ raise NotImplementedError() @abc.abstractmethod def client(self, params: ProtocolClientParams) -> ProtocolClient: - """Implement client functionality. + """Creates and returns a client instance for this protocol. + + This is an abstract method that must be implemented by subclasses. It should + take the necessary parameters and return a fully configured client object + capable of communicating using the defined protocol. - This method currently does nothing and is intended to be implemented - in the future with the necessary client-side logic. + Args: + params: The parameters required to initialize the client. + + Returns: + An instance of the protocol client. """ raise NotImplementedError() def mapped_method_handlers(handlers: list[ProtocolHandler]) -> dict[HTTPMethod, list[ProtocolHandler]]: - """Map protocol handlers to their respective HTTP methods. + """Groups protocol handlers by the HTTP methods they support. + + This function takes a flat list of protocol handlers and transforms it into a + dictionary where keys are HTTP methods and values are lists of handlers + that support that method. Args: - handlers (list[ProtocolHandler]): A list of protocol handlers. 
+ handlers: A list of ProtocolHandler instances to be mapped. Returns: - dict[HTTPMethod, list[ProtocolHandler]]: A dictionary where the keys are HTTP methods and the values are lists of protocol handlers that support those methods. - + A dictionary mapping each HTTPMethod to a list of ProtocolHandlers + that support it. """ method_handlers: dict[HTTPMethod, list[ProtocolHandler]] = {} for handler in handlers: @@ -224,21 +316,24 @@ def mapped_method_handlers(handlers: list[ProtocolHandler]) -> dict[HTTPMethod, def negotiate_compression( available: list[Compression], sent: str | None, accept: str | None ) -> tuple[Compression | None, Compression | None, ConnectError | None]: - """Negotiate the compression method to be used based on the available options. + """Negotiates the compression algorithms for the request and response. - The compression method sent by the client, and the compression methods accepted - by the server. + This function determines which compression algorithm to use for decompressing + the request body and compressing the response body based on the client's + headers and the server's available options. Args: - available (list[Compression]): A list of available compression methods. - sent (str | None): The compression method sent by the client, or None if not specified. - accept (str | None): A comma-separated string of compression methods accepted by the server, or None if not specified. - header_name_accept_encoding (str): The name of the header used to specify the accepted compression methods. + available (list[Compression]): A list of compression algorithms supported by the server. + sent (str | None): The value of the `connect-content-encoding` header from the request, + indicating the compression used for the request body. + accept (str | None): The value of the `connect-accept-encoding` header from the request, + indicating the compression algorithms the client accepts for the response. 
Returns: - tuple[Compression | None, Compression | None]: A tuple containing the selected compression method for the request - and the response. If no suitable compression method is found, None is returned for that position in the tuple. - + tuple[Compression | None, Compression | None, ConnectError | None]: A tuple containing: + - The compression algorithm for the request body (or None). + - The compression algorithm for the response body (or None). + - A ConnectError if the client sent an unsupported compression, otherwise None. """ request = None response = None @@ -271,37 +366,61 @@ def negotiate_compression( def sorted_allow_method_value(handlers: list[ProtocolHandler]) -> str: - """Sort the allowed methods for a list of protocol handlers. + """Generates a sorted, comma-separated string of method values from handlers. + + This function aggregates all unique HTTP methods from a list of protocol + handlers, sorts them alphabetically, and returns them as a single + comma-separated string. This is typically used to generate the value for + the `Allow` HTTP header. Args: - handlers (list[ProtocolHandler]): A list of protocol handlers. + handlers: A list of ProtocolHandler instances from which to extract + HTTP methods. Returns: - str: A comma-separated string of the allowed methods. - + A string containing the sorted, unique HTTP method values, + joined by ", ". For example: "GET, POST, PUT". """ methods = {method for handler in handlers for method in handler.methods} return ", ".join(sorted(method.value for method in methods)) def sorted_accept_post_value(handlers: list[ProtocolHandler]) -> str: - """Sort the allowed methods for a list of protocol handlers. + """Generates a sorted, comma-separated string of content types. + + This function takes a list of protocol handlers, collects all unique + content types they support, sorts them alphabetically, and formats them + into a single string suitable for an `Accept-Post` header. 
Args: - handlers (list[ProtocolHandler]): A list of protocol handlers. + handlers: A list of `ProtocolHandler` instances. Returns: - str: A comma-separated string of the allowed methods. - + A string containing the sorted, comma-separated list of + supported content types. """ content_types = {content_type for handler in handlers for content_type in handler.content_types()} return ", ".join(sorted(content_type for content_type in content_types)) def code_from_http_status(status: int) -> Code: - """Determine the gRPC-web error code for the given HTTP status code. + """Converts an HTTP status code to a gRPC `Code`. - See https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md. + This function implements the mapping from HTTP status codes to gRPC status codes + as specified by the gRPC-HTTP2 mapping. It handles common error codes like 400, + 401, 403, 404, 429, and 5xx. + + Note that a 200 OK status is mapped to `Code.UNKNOWN` because a successful + gRPC response over HTTP is expected to include a `grpc-status` header to + indicate the true status. + + See: https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md. + + Args: + status (int): The HTTP status code. + + Returns: + Code: The corresponding gRPC status code. """ match status: case 400: # Bad Request @@ -325,18 +444,17 @@ def code_from_http_status(status: int) -> Code: def exclude_protocol_headers(headers: Headers) -> Headers: - """Exclude protocol-specific headers from the given Headers object. + """Filters out protocol-specific headers from a Headers object. - This function filters out headers that are either standard HTTP headers - or specific to the Connect protocol, and returns a new Headers object - containing only the non-protocol headers. + This function iterates through a given set of headers and creates a new + Headers object that excludes common HTTP, Connect, and gRPC protocol + headers. The resulting object contains only application-specific headers. 
Args: - headers (Headers): The original Headers object containing all headers. + headers (Headers): The input headers to be filtered. Returns: - Headers: A new Headers object containing only the non-protocol headers. - + Headers: A new Headers object containing only non-protocol headers. """ non_protocol_headers = Headers(encoding=headers.encoding) for key, value in headers.items(): diff --git a/src/connect/protocol_connect/base64_utils.py b/src/connect/protocol_connect/base64_utils.py index f265c7a..4bf8fcf 100644 --- a/src/connect/protocol_connect/base64_utils.py +++ b/src/connect/protocol_connect/base64_utils.py @@ -15,7 +15,6 @@ def decode_base64_with_padding(value: str) -> bytes: Raises: Exception: If base64 decoding fails """ - # Add padding if needed padded_value = value + "=" * (-len(value) % 4) return base64.b64decode(padded_value.encode()) @@ -32,7 +31,6 @@ def decode_urlsafe_base64_with_padding(value: str) -> bytes: Raises: Exception: If base64 decoding fails """ - # Add padding if needed padded_value = value + "=" * (-len(value) % 4) return base64.urlsafe_b64decode(padded_value) diff --git a/src/connect/protocol_connect/connect_client.py b/src/connect/protocol_connect/connect_client.py index bf1656f..22ec274 100644 --- a/src/connect/protocol_connect/connect_client.py +++ b/src/connect/protocol_connect/connect_client.py @@ -1,4 +1,4 @@ -"""Provides a ConnectClient class for handling connections using the Connect protocol.""" +"""Connect protocol client implementation for unary and streaming RPCs.""" import asyncio import contextlib @@ -68,23 +68,30 @@ class ConnectClient(ProtocolClient): - """ConnectClient is a client for handling connections using the Connect protocol. + """ConnectClient is a client implementation for the Connect protocol, extending ProtocolClient. - Attributes: - params (ProtocolClientParams): Parameters for the protocol client. - _peer (Peer): The peer object representing the connection endpoint. 
+ This class is responsible for initializing and managing the connection parameters, peer information, + and handling the construction of request headers and client connections for both unary and streaming + communication patterns. + Attributes: + params (ProtocolClientParams): The parameters required to initialize the client, including codec, + compression settings, connection pool, and URL. + _peer (Peer): The peer instance representing the remote endpoint for the connection. """ params: ProtocolClientParams _peer: Peer def __init__(self, params: ProtocolClientParams) -> None: - """Initialize the ProtocolConnect instance with the given parameters. + """Initializes the ConnectClient with the given protocol client parameters. Args: - params (ProtocolClientParams): The parameters required to initialize the ProtocolConnect instance. + params (ProtocolClientParams): The parameters required to configure the protocol client, including URL information. + Attributes: + params (ProtocolClientParams): Stores the provided protocol client parameters. + _peer (Peer): Represents the peer connection, initialized with the host and port from the provided URL, using the CONNECT protocol. """ self.params = params self._peer = Peer( @@ -95,27 +102,29 @@ def __init__(self, params: ProtocolClientParams) -> None: @property def peer(self) -> Peer: - """Return the peer associated with this instance. + """Returns the associated Peer object for this client. - :return: The peer associated with this instance. - :rtype: Peer + Returns: + Peer: The peer instance associated with this client. """ return self._peer def write_request_headers(self, stream_type: StreamType, headers: Headers) -> None: - """Write the necessary request headers to the provided headers dictionary. - - This method ensures that the headers dictionary contains the required headers - for a request, including user agent, protocol version, content type, and - optionally, compression settings. 
+ """Sets and updates HTTP headers for a Connect protocol request based on the stream type and client parameters. Args: - stream_type (StreamType): The type of stream for the request. - headers (Headers): The dictionary of headers to be updated. - - Returns: - None - + stream_type (StreamType): The type of stream (e.g., Unary or Streaming) for the request. + headers (Headers): The dictionary of HTTP headers to be sent with the request. This dictionary is modified in-place. + + Behavior: + - Ensures the 'User-Agent' header is set to a default value if not already present. + - Sets the Connect protocol version and appropriate 'Content-Type' header based on the codec. + - Configures compression-related headers depending on the stream type and client compression settings. + - For streaming requests, sets both accepted and used compression headers if applicable. + - Updates the 'Accept-Encoding' header to list all supported compression algorithms if provided. + + Modifies: + The `headers` dictionary is updated in-place with the necessary protocol and compression headers. """ if headers.get(HEADER_USER_AGENT, None) is None: headers[HEADER_USER_AGENT] = DEFAULT_CONNECT_USER_AGENT @@ -134,15 +143,21 @@ def write_request_headers(self, stream_type: StreamType, headers: Headers) -> No headers[accept_compression_header] = ", ".join(c.name for c in self.params.compressions) def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Establish a unary client connection with the given specifications and headers. + """Creates and returns a streaming client connection based on the provided specification and headers. + + Depending on the `stream_type` in the `spec`, this method initializes either a unary or streaming client connection + with the appropriate marshaler and unmarshaler configurations. Args: - spec (Spec): The specification for the connection. - headers (Headers): The headers to be included in the request. 
+ spec (Spec): The specification for the connection, including stream type and idempotency level. + headers (Headers): The request headers to be used for the connection. Returns: - UnaryClientConn: The established unary client connection. + StreamingClientConn: An initialized client connection object, either unary or streaming. + Notes: + - For unary connections with `IdempotencyLevel.NO_SIDE_EFFECTS`, additional marshaler parameters are set. + - The connection is configured with compression, codec, and byte limit settings as specified in `self.params`. """ conn: StreamingClientConn if spec.stream_type == StreamType.Unary: @@ -195,21 +210,25 @@ def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: class ConnectUnaryClientConn(StreamingClientConn): - """A client connection for unary RPCs using the Connect protocol. + """ConnectUnaryClientConn provides a client-side connection for unary (single-request, single-response) Connect protocol calls. - Attributes: - _spec (Spec): The specification for the connection. - _peer (Peer): The peer information. - url (URL): The URL for the connection. - compressions (list[Compression]): List of supported compressions. - marshaler (ConnectUnaryRequestMarshaler): The marshaler for requests. - unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler for responses. - response_content (bytes | None): The content of the response. - _response_headers (Headers): The headers of the response. - _response_trailers (Headers): The trailers of the response. - _request_headers (Headers): The headers of the request. - _event_hooks (dict[str, list[EventHook]]): Event hooks for request and response. + This class manages the lifecycle of a unary request, including marshaling the request, sending it over HTTP (GET or POST), + handling timeouts and aborts, processing response headers and trailers, and unmarshaling the response. 
It also supports + event hooks for request and response processing, and manages compression and content-type validation. + Attributes: + pool (AsyncConnectionPool): The connection pool used for sending requests. + _spec (Spec): The protocol specification for the connection. + _peer (Peer): The peer information for the connection. + url (URL): The target URL for the connection. + compressions (list[Compression]): Supported compression methods. + marshaler (ConnectUnaryRequestMarshaler): Marshaler for encoding requests. + unmarshaler (ConnectUnaryUnmarshaler): Unmarshaler for decoding responses. + response_content (bytes | None): The raw response content, if available. + _response_headers (Headers): Headers received in the response. + _response_trailers (Headers): Trailers received after the response body. + _request_headers (Headers): Headers to send with the request. + _event_hooks (dict[str, list[EventHook]]): Registered event hooks for request and response events. """ pool: AsyncConnectionPool @@ -223,6 +242,7 @@ class ConnectUnaryClientConn(StreamingClientConn): _response_headers: Headers _response_trailers: Headers _request_headers: Headers + _event_hooks: dict[str, list[EventHook]] def __init__( self, @@ -236,22 +256,32 @@ def __init__( unmarshaler: ConnectUnaryUnmarshaler, event_hooks: None | (Mapping[str, list[EventHook]]) = None, ) -> None: - """Initialize the ConnectProtocol instance. + """Initializes a new instance of the client. Args: - pool (AsyncConnectionPool): The connection pool for the client. - spec (Spec): The specification for the connection. - peer (Peer): The peer information. - url (URL): The URL for the connection. - compressions (list[Compression]): List of compression methods. - request_headers (Headers): The headers for the request. - marshaler (ConnectUnaryRequestMarshaler): The marshaler for the request. - unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler for the response. 
- event_hooks (None | Mapping[str, list[EventHook]], optional): Event hooks for request and response. Defaults to None. - - Returns: - None - + pool (AsyncConnectionPool): The connection pool to use for managing connections. + spec (Spec): The specification object describing the protocol or service. + peer (Peer): The peer information for the connection. + url (URL): The URL endpoint for the connection. + compressions (list[Compression]): List of supported compression algorithms. + request_headers (Headers): Headers to include in outgoing requests. + marshaler (ConnectUnaryRequestMarshaler): Marshaler for serializing requests. + unmarshaler (ConnectUnaryUnmarshaler): Unmarshaler for deserializing responses. + event_hooks (None | Mapping[str, list[EventHook]], optional): Optional mapping of event hooks for "request" and "response" events. Defaults to None. + + Attributes: + pool (AsyncConnectionPool): The connection pool instance. + _spec (Spec): The protocol or service specification. + _peer (Peer): The peer information. + url (URL): The endpoint URL. + compressions (list[Compression]): Supported compression algorithms. + marshaler (ConnectUnaryRequestMarshaler): Request marshaler. + unmarshaler (ConnectUnaryUnmarshaler): Response unmarshaler. + response_content: The content of the response (initialized as None). + _response_headers (Headers): Headers from the response. + _response_trailers (Headers): Trailers from the response. + _request_headers (Headers): Headers for outgoing requests. + _event_hooks (dict): Event hooks for "request" and "response" events. """ event_hooks = {} if event_hooks is None else event_hooks @@ -273,20 +303,19 @@ def __init__( @property def spec(self) -> Spec: - """Return the specification of the protocol. + """Returns the specification object associated with this client. Returns: - Spec: The specification object of the protocol. - + Spec: The specification instance used by the client. 
""" return self._spec @property def peer(self) -> Peer: - """Return the peer object associated with this instance. + """Returns the associated Peer object for this client. - :return: The peer object. - :rtype: Peer + Returns: + Peer: The peer instance associated with this client, representing the remote endpoint. """ return self._peer @@ -298,67 +327,67 @@ async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: Yields: Any: The unmarshaled object. - """ obj = await self.unmarshaler.unmarshal(message) yield obj def receive(self, message: Any, _abort_event: asyncio.Event | None) -> AsyncIterator[Any]: - """Receives a message and returns an asynchronous iterator over the processed message. + """Receives messages asynchronously based on the provided input message. Args: - message (Any): The message to be received and processed. + message (Any): The input message or request to process. + _abort_event (asyncio.Event | None): Optional event to signal abortion of the receive operation. - Returns: - AsyncIterator[Any]: An asynchronous iterator yielding processed message(s). + Yields: + Any: Messages received from the underlying message stream. + Returns: + AsyncIterator[Any]: An asynchronous iterator yielding received messages. """ return self._receive_messages(message) @property def request_headers(self) -> Headers: - """Retrieve the request headers. + """Returns the HTTP headers to be included in the request. Returns: - Headers: A dictionary-like object containing the request headers. - + Headers: The headers to be sent with the request. """ return self._request_headers def on_request_send(self, fn: EventHook) -> None: - """Register a callback function to be called when a request is sent. + """Registers a callback function to be invoked whenever a request is sent. Args: - fn (EventHook): The callback function to be registered. This function - will be called with the request details when a request - is sent. 
+ fn (EventHook): The callback function to be added to the 'request' event hook. + Returns: + None """ self._event_hooks["request"].append(fn) async def send( self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None ) -> None: - """Send a single message asynchronously using either HTTP GET or POST, with support for timeouts and request abortion. + """Sends a single message asynchronously using either HTTP GET or POST, with optional timeout and abort support. Args: messages (AsyncIterable[Any]): An asynchronous iterable yielding the message(s) to send. Only a single message is allowed. - timeout (float | None): Optional timeout in seconds for the request. If provided, sets a read timeout for the request. - abort_event (asyncio.Event | None): Optional asyncio event that, if set, aborts the request. + timeout (float | None): Optional timeout in seconds for the request. If provided, sets the request timeout. + abort_event (asyncio.Event | None): Optional asyncio event that, when set, aborts the request. Raises: - ConnectError: If the request is aborted before or during execution, or if other connection errors occur. + ConnectError: If the marshaler URL is not set when required, if the request is aborted, or for internal errors. + Exception: Propagates exceptions raised by the underlying HTTP client. Side Effects: - - Modifies request headers for timeout and content length as needed. + - Modifies request headers based on timeout and content length. - Invokes registered request and response event hooks. - - Sets the unmarshaler's stream to the response stream for further processing. - - Validates the response after receiving it. - - Notes: - - If `marshaler.enable_get` is True, sends the request as HTTP GET; otherwise, uses HTTP POST. - - Handles cancellation and cleanup if the abort event is triggered during the request. + - Sets the unmarshaler's stream to the response stream. + - Validates the response. 
+ Returns: + None """ extensions = {} if timeout: @@ -447,23 +476,21 @@ async def send( @property def response_headers(self) -> Headers: - """Return the response headers. + """Returns the headers from the HTTP response. Returns: - Headers: A dictionary-like object containing the response headers. - + Headers: The headers of the HTTP response. """ return self._response_headers @property def response_trailers(self) -> Headers: - """Return the response trailers. + """Returns the response trailers as a Headers object. - Response trailers are additional headers sent after the response body. + Response trailers are HTTP headers sent after the response body, typically used in protocols like gRPC. Returns: - Headers: A dictionary containing the response trailers. - + Headers: The response trailers associated with the response. """ return self._response_trailers @@ -518,16 +545,11 @@ def json_ummarshal(data: bytes, _message: Any) -> Any: @property def event_hooks(self) -> dict[str, list[EventHook]]: - """Return the event hooks. - - This method returns a dictionary where the keys are strings representing - event names, and the values are lists of EventHook objects associated with - those events. + """Returns the dictionary of registered event hooks. Returns: - dict[str, list[EventHook]]: A dictionary mapping event names to lists - of EventHook objects. - + dict[str, list[EventHook]]: A dictionary where each key is a string representing the event name, + and the value is a list of EventHook instances associated with that event. """ return self._event_hooks @@ -539,20 +561,35 @@ def event_hooks(self, event_hooks: dict[str, list[EventHook]]) -> None: } async def aclose(self) -> None: - """Asynchronously closes the connection or releases any resources held by the object. - - This method should be called when the object is no longer needed to ensure proper cleanup. - Currently, this implementation does not perform any actions, but it can be overridden in subclasses. 
- - Returns: - None + """Asynchronously closes the client connection and releases any associated resources. + This method should be called when the client is no longer needed to ensure proper cleanup. + Currently, this implementation does not perform any actions, but it can be extended in the future. """ return class ConnectStreamingClientConn(StreamingClientConn): - """ConnectStreamingClientConn is a class that manages a streaming client connection using the Connect protocol.""" + """ConnectStreamingClientConn manages a streaming client connection for the Connect protocol. + + This class handles the lifecycle of a streaming RPC client connection, including sending and receiving messages, + managing request and response headers, handling compression, marshaling/unmarshaling of messages, and supporting + event hooks for request and response events. It integrates with an asynchronous connection pool and supports + abortable operations via asyncio events. + + Attributes: + _spec (Spec): The protocol specification for the connection. + _peer (Peer): The peer associated with this connection. + url (URL): The URL endpoint for the connection. + codec (Codec): Codec used for encoding and decoding messages. + compressions (list[Compression]): Supported compression methods. + marshaler (ConnectStreamingMarshaler): Marshaler for outgoing streaming messages. + unmarshaler (ConnectStreamingUnmarshaler): Unmarshaler for incoming streaming messages. + response_content (bytes | None): Raw response content, if any. + _response_headers (Headers): Headers received in the response. + _response_trailers (Headers): Trailers received after the response body. + _request_headers (Headers): Headers sent with the request. + """ _spec: Spec _peer: Peer @@ -579,23 +616,19 @@ def __init__( unmarshaler: ConnectStreamingUnmarshaler, event_hooks: None | (Mapping[str, list[EventHook]]) = None, ) -> None: - """Initialize a new instance of the class. + """Initializes a new instance of the client. 
Args: - pool (AsyncConnectionPool): The connection pool for the client. - spec (Spec): The specification object. - peer (Peer): The peer object. - url (URL): The URL for the connection. - codec (Codec): The codec to be used for encoding and decoding. - compressions (list[Compression]): List of compression methods. - request_headers (Headers): The headers for the request. - marshaler (ConnectStreamingMarshaler): The marshaler for streaming. - unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler for streaming. - event_hooks (None | Mapping[str, list[EventHook]], optional): Event hooks for request and response. Defaults to None. - - Returns: - None - + pool (AsyncConnectionPool): The asynchronous connection pool to use for network operations. + spec (Spec): The service specification or schema. + peer (Peer): The peer information for the connection. + url (URL): The URL endpoint for the connection. + codec (Codec): The codec used for encoding and decoding messages. + compressions (list[Compression]): List of supported compression algorithms. + request_headers (Headers): Headers to include in outgoing requests. + marshaler (ConnectStreamingMarshaler): Marshaler for streaming request bodies. + unmarshaler (ConnectStreamingUnmarshaler): Unmarshaler for streaming response bodies. + event_hooks (Optional[Mapping[str, list[EventHook]]]): Optional mapping of event hooks for 'request' and 'response' events. """ event_hooks = {} if event_hooks is None else event_hooks @@ -618,79 +651,78 @@ def __init__( @property def spec(self) -> Spec: - """Return the specification of the protocol. + """Returns the specification object associated with this client. Returns: - Spec: The specification object of the protocol. - + Spec: The specification instance used by the client. """ return self._spec @property def peer(self) -> Peer: - """Return the peer object associated with this instance. + """Returns the associated Peer object for this client. - :return: The peer object. 
- :rtype: Peer + Returns: + Peer: The peer instance linked to this client. """ return self._peer @property def request_headers(self) -> Headers: - """Retrieve the request headers. + """Returns the HTTP headers to be included in the request. Returns: - Headers: A dictionary-like object containing the request headers. - + Headers: The headers to be sent with the request. """ return self._request_headers @property def response_headers(self) -> Headers: - """Return the response headers. + """Returns the HTTP response headers. Returns: - Headers: A dictionary-like object containing the response headers. - + Headers: The headers received in the HTTP response. """ return self._response_headers @property def response_trailers(self) -> Headers: - """Return the response trailers. + """Returns the response trailers as a Headers object. - Response trailers are additional headers sent after the response body. + Response trailers are additional HTTP headers sent after the response body, + typically used in protocols like gRPC for sending metadata at the end of a response. Returns: - Headers: A dictionary containing the response trailers. - + Headers: The response trailers associated with the response. """ return self._response_trailers def on_request_send(self, fn: EventHook) -> None: - """Register a callback function to be called when a request is sent. + """Registers a callback function to be invoked whenever a request is sent. Args: - fn (EventHook): The callback function to be registered. This function - will be called with the request details when a request - is sent. + fn (EventHook): The callback function to be executed on request send events. + Returns: + None """ self._event_hooks["request"].append(fn) async def receive(self, message: Any, abort_event: asyncio.Event | None = None) -> AsyncIterator[Any]: - """Asynchronously receives and processes a message. + """Asynchronously receives and yields messages from the unmarshaler, handling stream control and errors. 
Args: - message (Any): The message to be processed. - abort_event (asyncio.Event | None): Event to signal abortion of the operation. + message (Any): The incoming message or stream to be unmarshaled. + abort_event (asyncio.Event | None, optional): An event to signal abortion of the receive operation. + If set, the operation is canceled and a ConnectError is raised. Yields: - Any: Objects obtained from unmarshaling the message. + Any: The next unmarshaled object from the message stream. Raises: - ConnectError: If stream is malformed or aborted. - + ConnectError: If the receive operation is aborted, if extra end stream messages are received, + if a message is received after the end of the stream, if an error is encountered in the end stream, + or if the end stream message is missing. """ end_stream_received = False @@ -727,25 +759,21 @@ async def receive(self, message: Any, abort_event: asyncio.Event | None = None) async def send( self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None ) -> None: - """Send an asynchronous HTTP POST request with the given messages and handle the response. + """Sends a stream of messages asynchronously to the server using HTTP POST. Args: messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. - timeout (float | None): Optional timeout value in seconds for the request. If provided, - it sets the read timeout for the request. - abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request. + timeout (float | None): Optional timeout in seconds for the request. If provided, sets the request timeout. + abort_event (asyncio.Event | None): Optional asyncio event to abort the request. If set and triggered, the request will be cancelled. Raises: - ConnectError: If the request is aborted or if there is an error during the request. + ConnectError: If the request is aborted via the abort_event. 
+ Exception: Propagates exceptions raised during the request or response handling. - Hooks: - - Executes hooks registered in `self._event_hooks["request"]` before sending the request. - - Executes hooks registered in `self._event_hooks["response"]` after receiving the response. - - Notes: - - If `abort_event` is provided and set during the request, the request will be canceled, - and a `ConnectError` with code `Code.CANCELED` will be raised. - - The response stream is unmarshaled and validated after the request is completed. + Side Effects: + - Invokes registered request and response event hooks. + - Sets up the response stream for unmarshaling. + - Validates the server response. """ extensions = {} @@ -853,10 +881,9 @@ async def _validate_response(self, response: httpcore.Response) -> None: self._response_headers.update(response_headers) async def aclose(self) -> None: - """Asynchronously closes the connection by invoking the `aclose` method of the unmarshaler. - - Returns: - None + """Asynchronously closes the client by closing the associated unmarshaler. + This method should be called to properly release any resources held by the unmarshaler + when the client is no longer needed. """ await self.unmarshaler.aclose() diff --git a/src/connect/protocol_connect/connect_handler.py b/src/connect/protocol_connect/connect_handler.py index 6b9e1c9..f792a48 100644 --- a/src/connect/protocol_connect/connect_handler.py +++ b/src/connect/protocol_connect/connect_handler.py @@ -1,4 +1,4 @@ -"""Provides a ConnectHandler class for handling connection protocols.""" +"""Connect protocol handler implementation for unary and streaming RPCs.""" import json from collections.abc import ( @@ -61,13 +61,12 @@ class ConnectHandler(ProtocolHandler): - """A handler for managing protocol connections. + """ConnectHandler is a protocol handler for the Connect protocol. Attributes: - params (ProtocolHandlerParams): Parameters for the protocol handler. 
- __methods (list[HTTPMethod]): List of HTTP methods supported by the handler. - accept (list[str]): List of accepted content types. - + params (ProtocolHandlerParams): The parameters for the protocol handler, including specification and compression options. + _methods (list[HTTPMethod]): The list of HTTP methods supported by this handler. + accept (list[str]): The list of accepted content types. """ params: ProtocolHandlerParams @@ -75,13 +74,15 @@ class ConnectHandler(ProtocolHandler): accept: list[str] def __init__(self, params: ProtocolHandlerParams, methods: list[HTTPMethod], accept: list[str]) -> None: - """Initialize the ProtocolConnect instance. + """Initializes the handler with the given parameters, supported HTTP methods, and accepted content types. Args: - params (ProtocolHandlerParams): The parameters for the protocol handler. - methods (list[HTTPMethod]): A list of HTTP methods. + params (ProtocolHandlerParams): The parameters required for the protocol handler. + methods (list[HTTPMethod]): A list of supported HTTP methods. accept (list[str]): A list of accepted content types. + Returns: + None """ self.params = params self._methods = methods @@ -89,25 +90,35 @@ def __init__(self, params: ProtocolHandlerParams, methods: list[HTTPMethod], acc @property def methods(self) -> list[HTTPMethod]: - """Return the list of HTTP methods. + """Returns the list of HTTP methods supported by this handler. Returns: - list[HTTPMethod]: A list of HTTP methods. - + list[HTTPMethod]: A list containing the supported HTTP methods. """ return self._methods def content_types(self) -> list[str]: - """Handle content types. - - This method currently does nothing and serves as a placeholder for future - implementation related to content types. + """Returns a list of accepted content types. + Returns: + list[str]: A list of MIME types that are accepted. 
""" return self.accept def can_handle_payload(self, request: Request, content_type: str) -> bool: - """Check if the handler can handle the payload.""" + """Determines if the handler can process the given request payload based on the content type. + + Args: + request (Request): The incoming HTTP request object. + content_type (str): The content type of the request payload. + + Returns: + bool: True if the handler can accept the payload with the specified content type, False otherwise. + + Notes: + - For GET requests, the content type may be determined from a query parameter and the stream type. + - For other HTTP methods, the provided content_type is used directly. + """ if HTTPMethod(request.method) == HTTPMethod.GET: codec_name = request.query_params.get(CONNECT_UNARY_ENCODING_QUERY_PARAMETER, "") content_type = connect_content_type_from_codec_name(self.params.spec.stream_type, codec_name) @@ -121,21 +132,24 @@ async def conn( response_trailers: Headers, writer: ServerResponseWriter, ) -> StreamingHandlerConn | None: - """Handle a connection request. + """Handles the connection for a Connect protocol request. Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be sent in the response. - response_trailers (Headers): The trailers to be sent in the response. - writer (ServerResponseWriter): The writer used to send the response. - is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. + request (Request): The incoming HTTP request object. + response_headers (Headers): Mutable headers to be sent with the response. + response_trailers (Headers): Mutable trailers to be sent with the response. + writer (ServerResponseWriter): Writer for sending responses to the client. Returns: - StreamingHandlerConn | None: The connection handler or None if not implemented. - - Raises: - ConnectError: If there is an error in negotiating compression, protocol version, or message encoding. 
- + StreamingHandlerConn | None: A connection handler for the request, or None if an error occurred. + + Workflow: + - Determines stream type (Unary or Streaming) and negotiates compression and encoding. + - Validates protocol version and required parameters. + - Parses and decodes the request message for unary GET requests. + - Sets appropriate response headers based on negotiated compression and encoding. + - Constructs and returns the appropriate connection handler (unary or streaming). + - Sends an error response and returns None if any validation or negotiation fails. """ query_params = request.query_params @@ -268,14 +282,23 @@ async def conn( class ConnectUnaryHandlerConn(StreamingHandlerConn): - """ConnectUnaryHandlerConn is a handler connection class for unary RPCs in the Connect protocol. + """Handler for unary Connect protocol requests. + + This class manages the lifecycle of a unary RPC connection, including + request parsing, response serialization, error handling, and header/trailer + management. It provides methods to receive and unmarshal incoming messages, + marshal and send responses, and handle protocol-specific metadata. Attributes: + writer (ServerResponseWriter): The writer used to send responses. request (Request): The incoming request object. - marshaler (ConnectUnaryMarshaler): An instance of ConnectUnaryMarshaler used to marshal messages. - unmarshaler (ConnectUnaryUnmarshaler): An instance of ConnectUnaryUnmarshaler used to unmarshal messages. - headers (Headers): The headers for the response. - + _peer (Peer): Information about the remote peer. + _spec (Spec): The protocol specification object. + marshaler (ConnectUnaryMarshaler): Marshaler for serializing response messages. + unmarshaler (ConnectUnaryUnmarshaler): Unmarshaler for deserializing request messages. + _request_headers (Headers): Headers from the incoming request. + _response_headers (Headers): Headers to be sent in the response. 
+ _response_trailers (Headers): Trailers to be sent in the response. """ writer: ServerResponseWriter @@ -300,18 +323,18 @@ def __init__( response_headers: Headers, response_trailers: Headers | None = None, ) -> None: - """Initialize the protocol connection. + """Initializes a new instance of the class. Args: - writer (ServerResponseWriter): The writer to send the response. + writer (ServerResponseWriter): The writer used to send responses to the client. request (Request): The incoming request object. - peer (Peer): The peer information. - spec (Spec): The specification object. - marshaler (ConnectUnaryMarshaler): The marshaler to serialize data. - unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler to deserialize data. - request_headers (Headers): The headers for the request. - response_headers (Headers): The headers for the response. - response_trailers (Headers, optional): The trailers for the response. + peer (Peer): Information about the remote peer. + spec (Spec): The specification for the current operation. + marshaler (ConnectUnaryMarshaler): The marshaler for serializing responses. + unmarshaler (ConnectUnaryUnmarshaler): The unmarshaler for deserializing requests. + request_headers (Headers): Headers from the incoming request. + response_headers (Headers): Headers to include in the response. + response_trailers (Headers | None, optional): Trailers to include in the response. Defaults to None. """ self.writer = writer @@ -325,7 +348,18 @@ def __init__( self._response_trailers = response_trailers if response_trailers is not None else Headers() def parse_timeout(self) -> float | None: - """Parse the timeout value.""" + """Parses the timeout value from the request headers. + + Retrieves the timeout value from the `CONNECT_HEADER_TIMEOUT` header in the request. + If the header is not present, returns None. If present, attempts to convert the value + to an integer (milliseconds), and returns the timeout in seconds as a float. 
+ + Raises: + ConnectError: If the timeout value cannot be converted to an integer. + + Returns: + float | None: The timeout value in seconds, or None if not specified. + """ try: timeout = self.request.headers.get(CONNECT_HEADER_TIMEOUT) if timeout is None: @@ -339,67 +373,68 @@ def parse_timeout(self) -> float | None: @property def spec(self) -> Spec: - """Return the specification object. + """Returns the specification object associated with this handler. Returns: - Spec: The specification object. - + Spec: The specification instance for this handler. """ return self._spec @property def peer(self) -> Peer: - """Return the peer associated with this instance. + """Returns the associated Peer object for this handler. - :return: The peer associated with this instance. - :rtype: Peer + Returns: + Peer: The Peer object containing information about the remote peer. """ return self._peer async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: - """Receives and unmarshals a message into an object. + """Asynchronously receives and unmarshals a message, yielding the result. Args: - message (Any): The message to be unmarshaled. + message (Any): The raw message to be unmarshaled. - Returns: - AsyncIterator[Any]: An async iterator yielding the unmarshaled object. + Yields: + Any: The unmarshaled message object. """ yield await self.unmarshaler.unmarshal(message) def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message, unmarshals it, and returns the resulting object. + """Receives a message and returns an asynchronous iterator over the processed messages. Args: - message (Any): The message to be unmarshaled. + message (Any): The input message to be processed. Returns: - AsyncIterator[Any]: An async iterator yielding the unmarshaled object. - + AsyncIterator[Any]: An asynchronous iterator yielding processed messages. 
""" return self._receive_messages(message) @property def request_headers(self) -> Headers: - """Retrieve the headers from the request. + """Returns the HTTP headers associated with the current request. Returns: - Mapping[str, str]: A dictionary-like object containing the request headers. - + Headers: The headers of the request. """ return self._request_headers async def send(self, messages: AsyncIterable[Any]) -> None: - """Send message(s) by marshaling them into bytes. + """Sends a single message over the connection. - Args: - messages (AsyncIterable[Any]): The message(s) to be sent. For unary operations, - this should be an iterable with a single item. + This asynchronous method expects an asynchronous iterable of messages, + ensures that only a single message is present, marshals it, and writes + the response using the provided writer. Response trailers are merged + before sending the message. - Returns: - None + Args: + messages (AsyncIterable[Any]): An asynchronous iterable containing the message to send. + Raises: + ValueError: If the iterable contains zero or more than one message. + Exception: Propagates exceptions raised during marshaling or writing. """ self.merge_response_trailers() @@ -410,49 +445,46 @@ async def send(self, messages: AsyncIterable[Any]) -> None: @property def response_headers(self) -> Headers: - """Retrieve the response headers. + """Returns the HTTP response headers. Returns: - Any: The response headers. - + Headers: The headers of the HTTP response. """ return self._response_headers @property def response_trailers(self) -> Headers: - """Handle response trailers. + """Returns the HTTP response trailers as a Headers object. - This method is intended to be overridden in subclasses to provide - specific functionality for processing response trailers. 
+ Response trailers are additional HTTP headers sent after the response body, + typically used in protocols like gRPC or HTTP/2 for metadata that is only + available once the response body has been generated. Returns: - Any: The processed response trailer data. - + Headers: The response trailers associated with the HTTP response. """ return self._response_trailers def get_http_method(self) -> HTTPMethod: - """Retrieve the HTTP method from the request. + """Returns the HTTP method of the current request as an `HTTPMethod` enum. Returns: - HTTPMethod: The HTTP method from the request. - + HTTPMethod: The HTTP method (e.g., GET, POST) of the request. """ return HTTPMethod(self.request.method) async def send_error(self, error: ConnectError) -> None: - """Send an error response. - - This method updates the response headers with the error metadata, - sets the response trailers, converts the error code to an HTTP status code, - serializes the error to JSON, and writes the response. + """Sends an error response to the client in the Connect protocol format. Args: - error (ConnectError): The error to be sent in the response. - - Returns: - None - + error (ConnectError): The error object containing error details, code, and metadata. + + Behavior: + - Updates response headers with error metadata, excluding protocol-specific headers if `wire_error` is False. + - Merges response trailers into the headers. + - Sets the appropriate HTTP status code based on the Connect error code. + - Sets the response content type to Connect JSON. + - Serializes the error to JSON bytes and writes the response to the client. 
""" if not error.wire_error: self.response_headers.update(exclude_protocol_headers(error.metadata)) @@ -467,34 +499,36 @@ async def send_error(self, error: ConnectError) -> None: await self.writer.write(Response(content=body, headers=self.response_headers, status_code=status_code)) def merge_response_trailers(self) -> None: - """Merge response trailers into the response headers. + """Merges the response trailers into the response headers by prefixing each trailer key with CONNECT_UNARY_TRAILER_PREFIX and adding it to the response headers dictionary. - This method iterates through the `_response_trailers` dictionary and adds - each trailer key-value pair to the `_response_headers` dictionary, - prefixing the trailer keys with `CONNECT_UNARY_TRAILER_PREFIX`. + This is typically used to ensure that trailer metadata is included in the headers + for protocols or transports that do not natively support trailers. Returns: None - """ for key, value in self._response_trailers.items(): self._response_headers[CONNECT_UNARY_TRAILER_PREFIX + key] = value class ConnectStreamingHandlerConn(StreamingHandlerConn): - """ConnectStreamingHandlerConn is a class that handles streaming connections for the Connect protocol. + """ConnectStreamingHandlerConn manages the lifecycle and data flow of a streaming connection using the Connect protocol. + + It handles marshaling and unmarshaling of streaming messages, manages request and response + headers/trailers, and provides methods for sending and receiving messages asynchronously. + This class is designed to work with a server response writer and encapsulates protocol-specific + logic for error handling and timeout parsing. Attributes: - writer (ServerResponseWriter): The writer used to send responses. + writer (ServerResponseWriter): The writer used to send responses to the client. request (Request): The incoming request object. - _peer (Peer): The peer associated with this connection. - _spec (Spec): The specification object. 
- marshaler (ConnectStreamingMarshaler): The marshaler used to serialize messages. - unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler used to deserialize messages. - _request_headers (Headers): The headers from the request. - _response_headers (Headers): The headers for the response. - _response_trailers (Headers): The trailers for the response. - + _peer (Peer): Information about the remote peer. + _spec (Spec): The protocol specification details. + marshaler (ConnectStreamingMarshaler): Marshals outgoing streaming messages. + unmarshaler (ConnectStreamingUnmarshaler): Unmarshals incoming streaming messages. + _request_headers (Headers): Headers from the incoming request. + _response_headers (Headers): Headers to be sent in the response. + _response_trailers (Headers): Trailers to be sent at the end of the response stream. """ writer: ServerResponseWriter @@ -519,18 +553,18 @@ def __init__( response_headers: Headers, response_trailers: Headers | None = None, ) -> None: - """Initialize the protocol connection. + """Initializes the ConnectHandler with the provided writer, request, peer, specification, marshaler, unmarshaler, and headers. Args: - writer (ServerResponseWriter): The writer for server responses. - request (Request): The request object. - peer (Peer): The peer information. - spec (Spec): The specification details. - marshaler (ConnectStreamingMarshaler): The marshaler for streaming. - unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler for streaming. - request_headers (Headers): The headers for the request. - response_headers (Headers): The headers for the response. - response_trailers (Headers, optional): The trailers for the response. Defaults to None. + writer (ServerResponseWriter): The writer used to send responses to the client. + request (Request): The incoming request object. + peer (Peer): The peer information for the connection. + spec (Spec): The specification for the connection. 
+ marshaler (ConnectStreamingMarshaler): The marshaler for streaming responses. + unmarshaler (ConnectStreamingUnmarshaler): The unmarshaler for streaming requests. + request_headers (Headers): Headers from the incoming request. + response_headers (Headers): Headers to include in the response. + response_trailers (Headers | None, optional): Trailing headers to include in the response. Defaults to None. """ self.writer = writer @@ -544,7 +578,19 @@ def __init__( self._response_trailers = response_trailers if response_trailers is not None else Headers() def parse_timeout(self) -> float | None: - """Parse the timeout value.""" + """Parses the timeout value from the request headers. + + Retrieves the timeout value specified in the CONNECT_HEADER_TIMEOUT header, + converts it from milliseconds to seconds, and returns it as a float. + If the header is not present, returns None. + Raises a ConnectError with Code.INVALID_ARGUMENT if the header value is not a valid integer. + + Returns: + float | None: The timeout value in seconds, or None if not specified. + + Raises: + ConnectError: If the timeout value cannot be converted to an integer. + """ try: timeout = self.request.headers.get(CONNECT_HEADER_TIMEOUT) if timeout is None: @@ -558,80 +604,73 @@ def parse_timeout(self) -> float | None: @property def spec(self) -> Spec: - """Return the specification object. + """Returns the specification object associated with this handler. Returns: - Spec: The specification object. - + Spec: The specification instance for this handler. """ return self._spec @property def peer(self) -> Peer: - """Return the peer associated with this instance. + """Returns the associated Peer object for this handler. - :return: The peer associated with this instance. - :rtype: Peer + Returns: + Peer: The Peer object containing information about the remote peer. 
""" return self._peer async def _receive_messages(self, message: Any) -> AsyncIterator[Any]: - """Asynchronously receives a message and yields unmarshaled objects. - - This method unmarshals the received message and yields each - unmarshaled object one by one as an asynchronous iterator. + """Asynchronously receives and yields unmarshaled message objects. Args: - message (Any): The message to unmarshal. - - Returns: - AsyncIterator[Any]: An asynchronous iterator yielding unmarshaled objects. + message (Any): The incoming message to be unmarshaled. Yields: - Any: Each unmarshaled object from the message. + Any: Each unmarshaled object extracted from the message. + Raises: + Any exceptions raised by the unmarshaler during processing. """ async for obj, _ in self.unmarshaler.unmarshal(message): yield obj def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message and returns an asynchronous content stream. - - This method processes the incoming message through the receive_message method - and wraps the result in an AsyncContentStream with the appropriate stream type. + """Receives a message and returns an asynchronous iterator over the processed messages. Args: - message (Any): The message to be processed. + message (Any): The message to be received and processed. Returns: - AsyncContentStream[Any]: An asynchronous stream of content based on the - processed message, configured with the specification's stream type. - + AsyncIterator[Any]: An asynchronous iterator yielding processed messages. """ return self._receive_messages(message) @property def request_headers(self) -> Headers: - """Retrieve the headers from the request. + """Returns the HTTP headers associated with the current request. Returns: - Mapping[str, str]: A dictionary-like object containing the request headers. - + Headers: The headers of the request. 
""" return self._request_headers async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: - """Create an async iterator that marshals messages with error handling. + """Asynchronously sends marshaled messages and yields them as byte streams. - Args: - messages (AsyncIterable[Any]): Messages to marshal + Iterates over the provided asynchronous iterable of messages, marshals each message, + and yields the resulting bytes. If an exception occurs during marshaling, it captures + the error and ensures that an end-of-stream message is marshaled and yielded with + appropriate error information and response trailers. - Returns: - AsyncIterator[bytes]: Marshaled bytes with end stream message + Args: + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent. Yields: - bytes: Each marshaled message followed by an end stream message + bytes: Marshaled message bytes, including a final end-of-stream message. + Raises: + ConnectError: If an internal error occurs during marshaling. """ error: ConnectError | None = None try: @@ -644,21 +683,16 @@ async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[by yield body async def send(self, messages: AsyncIterable[Any]) -> None: - """Send a stream of messages asynchronously. - - This method marshals the provided messages and sends them using the writer. - If an error occurs during the marshaling process, it captures the error, - converts it to a JSON object, and sends it as the final message in the stream. + """Asynchronously sends a stream of messages to the client using a streaming HTTP response. Args: - messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent. + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent to the client. Returns: None Raises: - ConnectError: If an error occurs during the marshaling process. - + Any exceptions raised by the writer or during message streaming will propagate. 
""" await self.writer.write( StreamingResponse( @@ -670,36 +704,35 @@ async def send(self, messages: AsyncIterable[Any]) -> None: @property def response_headers(self) -> Headers: - """Retrieve the response headers. + """Returns the HTTP response headers. Returns: - Any: The response headers. - + Headers: The headers included in the HTTP response. """ return self._response_headers @property def response_trailers(self) -> Headers: - """Handle response trailers. + """Returns the HTTP response trailers. - This method is intended to be overridden in subclasses to provide - specific functionality for processing response trailers. + Response trailers are additional headers sent after the response body, typically used in protocols like HTTP/2 or gRPC to provide metadata that is only available once the response body has been generated. Returns: - Any: The processed response trailer data. - + Headers: The response trailers as a Headers object. """ return self._response_trailers async def send_error(self, error: ConnectError) -> None: - """Send an error response in the form of a JSON object. + """Sends an error response to the client using the provided ConnectError. + + This method marshals the error and response trailers into a response body, + then writes a streaming HTTP response with the appropriate headers and a status code of 200. Args: - error (ConnectError): The error object to be sent. + error (ConnectError): The error to be sent to the client. Returns: None - """ body = self.marshaler.marshal_end_stream(error, self.response_trailers) @@ -709,17 +742,19 @@ async def send_error(self, error: ConnectError) -> None: def connect_check_protocol_version(request: Request, required: bool) -> ConnectError | None: - """Check the protocol version in the request headers for POST requests. + """Validates the protocol version in a Connect request based on the HTTP method. - Args: - request (Request): The incoming HTTP request. 
- required (bool): Flag indicating whether the protocol version is required. + For GET requests, checks the presence and value of a specific query parameter. + For POST requests, checks the presence and value of a specific header. + Returns a ConnectError if the required protocol version is missing or incorrect, + or if the HTTP method is unsupported. Returns None if the protocol version is valid. - Raises: - ValueError: If the protocol version is required but not present in the headers. - ValueError: If the protocol version is present but unsupported. - ValueError: If the HTTP method is unsupported. + Args: + request (Request): The incoming HTTP request to validate. + required (bool): Whether the protocol version is required. + Returns: + ConnectError | None: A ConnectError describing the validation failure, or None if valid. """ match HTTPMethod(request.method): case HTTPMethod.GET: @@ -751,18 +786,16 @@ def connect_check_protocol_version(request: Request, required: bool) -> ConnectE def error_to_json_bytes(error: ConnectError) -> bytes: - """Serialize a ConnectError object to a JSON-encoded byte string. + """Serializes a ConnectError object to a JSON-formatted bytes object. Args: - error (ConnectError): The ConnectError object to serialize. + error (ConnectError): The ConnectError instance to serialize. Returns: - bytes: The JSON-encoded byte string representation of the error. + bytes: The JSON representation of the error, encoded as UTF-8 bytes. Raises: - ConnectError: If serialization fails, a ConnectError is raised with an - appropriate error message and code. - + ConnectError: If serialization fails, raises a new ConnectError with an INTERNAL code. 
""" try: json_obj = error_to_json(error) diff --git a/src/connect/protocol_connect/connect_protocol.py b/src/connect/protocol_connect/connect_protocol.py index 0aaefec..91a334b 100644 --- a/src/connect/protocol_connect/connect_protocol.py +++ b/src/connect/protocol_connect/connect_protocol.py @@ -21,17 +21,30 @@ class ProtocolConnect(Protocol): - """ProtocolConnect is a class that implements the Protocol interface for handling connection protocols.""" + """ProtocolConnect is a protocol handler for the Connect protocol, responsible for creating handler and client instances based on provided parameters. + + Methods: + handler(params: ProtocolHandlerParams) -> ConnectHandler: + Handles the creation of a ConnectHandler instance, configuring supported HTTP methods and accepted content types + based on the stream type and idempotency level specified in the parameters. + + client(params: ProtocolClientParams) -> ProtocolClient: + Creates and returns a ConnectClient instance initialized with the provided parameters. + """ def handler(self, params: ProtocolHandlerParams) -> ConnectHandler: - """Handle the creation of a ConnectHandler based on the provided ProtocolHandlerParams. + """Creates and returns a ConnectHandler instance configured with appropriate HTTP methods and accepted content types based on the provided ProtocolHandlerParams. Args: - params (ProtocolHandlerParams): The parameters required to create the ConnectHandler. + params (ProtocolHandlerParams): The parameters specifying the protocol handler configuration, including stream type, idempotency level, and codecs. Returns: - ConnectHandler: An instance of ConnectHandler configured with the appropriate methods and content types. + ConnectHandler: An instance of ConnectHandler configured with the determined HTTP methods and accepted content types. + Behavior: + - Allows POST requests by default. + - Adds GET as an allowed method if the stream type is Unary and the idempotency level is NO_SIDE_EFFECTS. 
+ - Constructs a list of accepted content types based on the stream type and available codec names. """ methods = [HTTPMethod.POST] @@ -49,13 +62,12 @@ def handler(self, params: ProtocolHandlerParams) -> ConnectHandler: return ConnectHandler(params, methods=methods, accept=content_types) def client(self, params: ProtocolClientParams) -> ProtocolClient: - """Create and returns a ConnectClient instance. + """Creates and returns a new instance of `ConnectClient` using the provided parameters. Args: - params (ProtocolClientParams): The parameters required to initialize the client. + params (ProtocolClientParams): The parameters required to initialize the protocol client. Returns: - ProtocolClient: An instance of ConnectClient. - + ProtocolClient: An instance of `ConnectClient` initialized with the given parameters. """ return ConnectClient(params) diff --git a/src/connect/protocol_connect/constants.py b/src/connect/protocol_connect/constants.py index 0d01589..56d2a8b 100644 --- a/src/connect/protocol_connect/constants.py +++ b/src/connect/protocol_connect/constants.py @@ -1,4 +1,4 @@ -"""Constants used in the Connect protocol implementation for Python.""" +"""Constants for Connect protocol headers, content types, and user agent.""" import sys diff --git a/src/connect/protocol_connect/content_type.py b/src/connect/protocol_connect/content_type.py index c3141e0..bc14d07 100644 --- a/src/connect/protocol_connect/content_type.py +++ b/src/connect/protocol_connect/content_type.py @@ -1,4 +1,4 @@ -"""Utilities for handling Connect protocol content types.""" +"""Helpers for Connect protocol content type handling.""" from http import HTTPStatus @@ -18,15 +18,22 @@ def connect_codec_from_content_type(stream_type: StreamType, content_type: str) -> str: - """Extract the codec from the content type based on the stream type. + """Extracts the codec name from a given content type string based on the stream type. 
Args: - stream_type (StreamType): The type of stream (Unary or Streaming). - content_type (str): The content type string from which to extract the codec. + stream_type (StreamType): The type of the stream (e.g., Unary or Streaming). + content_type (str): The full content type string, which includes a prefix and the codec name. Returns: - str: The extracted codec from the content type. + str: The codec name extracted from the content type. + + Raises: + IndexError: If the content_type string is shorter than the expected prefix length. + Note: + The function assumes that the content_type string starts with either + CONNECT_UNARY_CONTENT_TYPE_PREFIX or CONNECT_STREAMING_CONTENT_TYPE_PREFIX, + depending on the stream_type. """ if stream_type == StreamType.Unary: return content_type[len(CONNECT_UNARY_CONTENT_TYPE_PREFIX) :] @@ -35,15 +42,18 @@ def connect_codec_from_content_type(stream_type: StreamType, content_type: str) def connect_content_type_from_codec_name(stream_type: StreamType, codec_name: str) -> str: - """Generate the content type string for a given stream type and codec name. + """Generates a Connect protocol content type string based on the stream type and codec name. Args: - stream_type (StreamType): The type of the stream (e.g., Unary or Streaming). - codec_name (str): The name of the codec. + stream_type (StreamType): The type of stream (e.g., Unary or Streaming). + codec_name (str): The name of the codec (e.g., "proto", "json"). Returns: - str: The content type string constructed from the stream type and codec name. + str: The content type string for the Connect protocol, combining the appropriate prefix and codec name. 
+ Example: + connect_content_type_from_codec_name(StreamType.Unary, "proto") + # Returns: "application/connect+proto" """ if stream_type == StreamType.Unary: return CONNECT_UNARY_CONTENT_TYPE_PREFIX + codec_name @@ -56,17 +66,24 @@ def connect_validate_unary_response_content_type( status_code: int, response_content_type: str, ) -> ConnectError | None: - """Validate the content type of a unary response based on the HTTP status code and method. + """Validates the content type of a unary response in the Connect protocol. Args: - request_codec_name (str): The name of the codec used for the request. - http_method (HTTPMethod): The HTTP method used for the request. + request_codec_name (str): The codec name used in the request (e.g., "json", "json; charset=utf-8"). status_code (int): The HTTP status code of the response. response_content_type (str): The content type of the response. + Returns: + ConnectError | None: Returns a ConnectError if the response content type is invalid or does not match + the expected codec, otherwise returns None. + Raises: - ConnectError: If the status code is not OK and the response content type is not valid. + ConnectError: If the response content type is invalid or does not match the expected format. + Behavior: + - For non-OK HTTP status codes, ensures the response is JSON-encoded. + - For OK responses, checks that the content type starts with the expected prefix and matches the request codec. + - Allows for compatibility between "json" and "json; charset=utf-8" codecs. """ if status_code != HTTPStatus.OK: # Error response must be JSON-encoded. 
diff --git a/src/connect/protocol_connect/end_stream.py b/src/connect/protocol_connect/end_stream.py index 7b75010..09a25e5 100644 --- a/src/connect/protocol_connect/end_stream.py +++ b/src/connect/protocol_connect/end_stream.py @@ -1,4 +1,4 @@ -"""Module for handling end-of-stream JSON serialization and deserialization for Connect protocol.""" +"""Helpers for serializing and deserializing Connect end-of-stream messages.""" import json from typing import Any @@ -10,15 +10,14 @@ def end_stream_to_json(error: ConnectError | None, trailers: Headers) -> dict[str, Any]: - """Convert the end of a stream to a JSON-serializable dictionary. + """Converts the end-of-stream state, including an optional error and trailers, into a JSON-serializable dictionary. Args: - error (ConnectError | None): An optional error object that may contain metadata. - trailers (Headers): Headers object containing metadata. + error (ConnectError | None): An optional error object representing the stream error, if any. + trailers (Headers): The headers (trailers) to include as metadata in the JSON output. Returns: - dict[str, Any]: A dictionary containing the error and metadata information in JSON-serializable format. - + dict[str, Any]: A dictionary containing the serialized error (if present) and metadata extracted from the trailers. """ json_obj = {} @@ -34,18 +33,16 @@ def end_stream_to_json(error: ConnectError | None, trailers: Headers) -> dict[st def end_stream_from_bytes(data: bytes) -> tuple[ConnectError | None, Headers]: - """Parse a byte stream to extract metadata and error information. + """Parses a byte string representing an end stream message and returns a tuple containing a possible ConnectError and Headers. Args: - data (bytes): The byte stream to be parsed. + data (bytes): The byte string to parse, expected to be a JSON-encoded object. Returns: - tuple[ConnectError | None, Headers]: A tuple containing an optional ConnectError - and a Headers object with the parsed metadata. 
+ tuple[ConnectError | None, Headers]: A tuple where the first element is a ConnectError if an error is present in the input, or None otherwise; the second element is a Headers object containing parsed metadata. Raises: - ConnectError: If the byte stream is invalid or the metadata format is incorrect. - + ConnectError: If the input data is not valid JSON, or if the metadata format is invalid. """ parse_error = ConnectError("invalid end stream", Code.UNKNOWN) try: diff --git a/src/connect/protocol_connect/error_code.py b/src/connect/protocol_connect/error_code.py index 349a9a1..06ec115 100644 --- a/src/connect/protocol_connect/error_code.py +++ b/src/connect/protocol_connect/error_code.py @@ -1,4 +1,4 @@ -"""Module for mapping Connect error codes to HTTP status codes.""" +"""HTTP status code mapping for Connect protocol error codes.""" from connect.code import Code diff --git a/src/connect/protocol_connect/error_json.py b/src/connect/protocol_connect/error_json.py index c45a336..6c939f1 100644 --- a/src/connect/protocol_connect/error_json.py +++ b/src/connect/protocol_connect/error_json.py @@ -1,4 +1,4 @@ -"""Module for serializing and deserializing ConnectError objects to and from JSON.""" +"""Utilities for serializing and deserializing ConnectError objects to and from JSON.""" import contextlib import json @@ -17,18 +17,16 @@ def code_to_string(value: Code) -> str: - """Convert a Code object to its string representation. + """Converts a Code enum value to its string representation. - If the Code object has a 'name' attribute and it is not None, the method returns - the lowercase version of the 'name'. Otherwise, it returns the string representation - of the 'value' attribute. + If the value has a 'name' attribute and it is not None, returns the lowercase name. + Otherwise, returns the string representation of the value's 'value' attribute. Args: - value (Code): The Code object to be converted to a string. + value (Code): The enum value to convert. 
Returns: - str: The string representation of the Code object. - + str: The string representation of the enum value. """ if not hasattr(value, "name") or value.name is None: return str(value.value) @@ -37,20 +35,17 @@ def code_to_string(value: Code) -> str: def code_from_string(value: str) -> Code | None: - """Convert a string representation of a code to its corresponding Code enum value. - - This function uses a global dictionary to cache the mapping from string to Code enum values. - If the cache is not initialized, it populates the cache by iterating over all Code enum values - and mapping their string representations to the corresponding Code enum. + """Converts a string representation of a code to its corresponding `Code` enum value. - This function is thread-safe and ensures the global cache is initialized only once. + This function uses a thread-safe, lazily-initialized mapping to efficiently look up + the `Code` enum associated with the given string. If the mapping is not yet initialized, + it will be created in a thread-safe manner using double-checked locking. Args: value (str): The string representation of the code. Returns: - Code | None: The corresponding Code enum value if found, otherwise None. - + Code | None: The corresponding `Code` enum value if found, otherwise `None`. """ global _string_to_code @@ -68,19 +63,21 @@ def code_from_string(value: str) -> Code | None: def error_from_json(obj: dict[str, Any], fallback: ConnectError) -> ConnectError: - """Convert a JSON-serializable dictionary to a ConnectError object. + """Deserializes a JSON object into a ConnectError instance. Args: - obj (dict[str, Any]): The dictionary representing the error in JSON format. - fallback (ConnectError): A fallback ConnectError object to use in case of missing or invalid fields. + obj (dict[str, Any]): The JSON object representing the error, expected to contain + at least a "message" field, and optionally "code" and "details". 
+ fallback (ConnectError): A fallback ConnectError instance to use for default values + or to raise in case of malformed input. Returns: - ConnectError: The ConnectError object converted from the dictionary. + ConnectError: The deserialized ConnectError instance with populated message, code, + and details. Raises: - ConnectError: If the dictionary is missing required fields or contains invalid values, - a ConnectError is raised with an appropriate error message and code. - + ConnectError: If required fields in the details are missing or if base64 decoding fails, + the fallback error is raised. """ code = fallback.code if "code" in obj: @@ -115,21 +112,21 @@ def error_from_json(obj: dict[str, Any], fallback: ConnectError) -> ConnectError def error_to_json(error: ConnectError) -> dict[str, Any]: - """Convert a ConnectError object to a JSON-serializable dictionary. + """Converts a ConnectError object into a JSON-serializable dictionary. Args: error (ConnectError): The error object to convert. Returns: - dict[str, Any]: A dictionary representing the error in JSON format. - - "code" (str): The error code as a string. - - "message" (str, optional): The raw error message, if available. - - "details" (list[dict[str, Any]], optional): A list of dictionaries containing error details, if available. - Each detail dictionary contains: - - "type" (str): The type name of the detail. - - "value" (str): The base64-encoded value of the detail. - - "debug" (str, optional): The JSON-encoded debug information, if available. - + dict[str, Any]: A dictionary representing the error, including its code, message, and details if present. + + The returned dictionary contains: + - "code": The string representation of the error code. + - "message": The raw error message, if available. + - "details": A list of dictionaries for each error detail, each containing: + - "type": The type name of the detail. + - "value": The base64-encoded value of the detail. 
+ - "debug": (optional) A dictionary representation of the inner message, if available. """ obj: dict[str, Any] = {"code": error.code.string()} diff --git a/src/connect/protocol_connect/marshaler.py b/src/connect/protocol_connect/marshaler.py index b6ec197..501f18b 100644 --- a/src/connect/protocol_connect/marshaler.py +++ b/src/connect/protocol_connect/marshaler.py @@ -1,4 +1,4 @@ -"""Provides marshaling utilities for the Connect protocol.""" +"""Marshaling utilities for Connect protocol unary and streaming messages.""" import base64 import contextlib @@ -32,15 +32,16 @@ class ConnectUnaryMarshaler: - """ConnectUnaryMarshaler is responsible for serializing and optionally compressing messages. + """ConnectUnaryMarshaler is responsible for marshaling unary messages in the Connect protocol. - Attributes: - codec (Codec): The codec used for serializing messages. - compression (Compression | None): The compression method used for compressing messages, if any. - compress_min_bytes (int): The minimum size in bytes for a message to be compressed. - send_max_bytes (int): The maximum allowed size in bytes for a message to be sent. - headers (Headers | Headers): The headers to be included in the message. + This class handles the encoding and optional compression of messages before they are sent over the network. + Attributes: + codec (Codec | None): Codec used for encoding/decoding messages. + compression (Compression | None): Compression algorithm to use, or None for no compression. + compress_min_bytes (int): Minimum message size (in bytes) before compression is applied. + send_max_bytes (int): Maximum allowed size (in bytes) for a message to be sent. + headers (Headers): Headers to include in the connection. """ codec: Codec | None @@ -57,18 +58,14 @@ def __init__( send_max_bytes: int, headers: Headers, ) -> None: - """Initialize the protocol connection. + """Initializes the object with the specified codec, compression settings, and headers. 
Args: - codec (Codec): The codec to be used for encoding/decoding. - compression (Compression | None): The compression method to be used, or None if no compression. + codec (Codec | None): The codec to use for encoding/decoding, or None if not specified. + compression (Compression | None): The compression algorithm to use, or None for no compression. compress_min_bytes (int): The minimum number of bytes before compression is applied. - send_max_bytes (int): The maximum number of bytes to send in a single message. - headers (Headers): The headers to be included in the connection. - - Returns: - None - + send_max_bytes (int): The maximum number of bytes allowed to send in a single message. + headers (Headers): The headers to include with each message. """ self.codec = codec self.compression = compression @@ -77,17 +74,24 @@ def __init__( self.headers = headers def marshal(self, message: Any) -> bytes: - """Marshals a message into bytes, optionally compressing it if it exceeds a certain size. + """Serializes and optionally compresses a message object into bytes. Args: - message (Any): The message to be marshaled. + message (Any): The message object to be marshaled. Returns: - bytes: The marshaled (and possibly compressed) message. + bytes: The serialized (and possibly compressed) message. Raises: - ConnectError: If there is an error during marshaling or if the message size exceeds the allowed limit. - + ConnectError: If the codec is not set, if marshaling fails, or if the (compressed or uncompressed) + message size exceeds the configured send_max_bytes limit. + + Process: + - Uses the configured codec to serialize the message. + - If the serialized data is smaller than `compress_min_bytes` or compression is not set, + returns the data as-is (after checking size limits). + - Otherwise, compresses the data, checks the size limit again, and sets the appropriate + compression header before returning the compressed data. 
""" if self.codec is None: raise ConnectError("codec is not set", Code.INTERNAL) @@ -119,18 +123,17 @@ def marshal(self, message: Any) -> bytes: class ConnectUnaryRequestMarshaler(ConnectUnaryMarshaler): - """ConnectUnaryRequestMarshaler is responsible for marshaling unary request messages for the Connect protocol, with support for GET requests and stable codecs. + """ConnectUnaryRequestMarshaler is a specialized marshaler for unary requests in the Connect protocol. - This class extends ConnectUnaryMarshaler to provide additional functionality for handling GET requests, - including marshaling messages using a stable codec, enforcing message size limits, and optionally compressing - messages when necessary. It also manages the construction of GET URLs with appropriate query parameters and - headers for the Connect protocol. + This class extends ConnectUnaryMarshaler and adds the ability to marshal messages for GET requests, + optionally using a stable codec for deterministic serialization. It manages request headers, compression, + and enforces message size limits. If GET requests are enabled, it ensures that a stable codec is available + and handles URL construction for GET requests, including optional compression and base64 encoding. Attributes: - enable_get (bool): Flag indicating whether GET requests are enabled. - stable_codec (StableCodec | None): The codec used for stable marshaling, if available. - url (URL | None): The URL to use for the request. - + enable_get (bool): Whether to enable GET requests for marshaling. + stable_codec (StableCodec | None): Optional stable codec for deterministic message serialization. + url (URL | None): The URL endpoint for the connection. """ enable_get: bool @@ -148,21 +151,17 @@ def __init__( stable_codec: StableCodec | None = None, url: URL | None = None, ) -> None: - """Initialize the protocol connection with the specified configuration. 
+ """Initializes the object with the specified codec, compression, compression threshold, maximum send bytes, headers, and optional parameters. Args: - codec (Codec | None): The codec to use for encoding/decoding messages, or None. - compression (Compression | None): The compression algorithm to use, or None. + codec (Codec | None): The codec to use for serialization, or None. + compression (Compression | None): The compression method to use, or None. compress_min_bytes (int): Minimum number of bytes before compression is applied. - send_max_bytes (int): Maximum number of bytes allowed per send operation. - headers (Headers): Headers to include in each request. + send_max_bytes (int): Maximum number of bytes allowed to send. + headers (Headers): Headers to include in the protocol. enable_get (bool, optional): Whether to enable GET requests. Defaults to False. - stable_codec (StableCodec | None, optional): An optional stable codec for message encoding/decoding. Defaults to None. - url (URL | None, optional): The URL endpoint for the connection. Defaults to None. - - Returns: - None - + stable_codec (StableCodec | None, optional): An optional stable codec for serialization. Defaults to None. + url (URL | None, optional): An optional URL associated with the protocol. Defaults to None. """ super().__init__(codec, compression, compress_min_bytes, send_max_bytes, headers) self.enable_get = enable_get @@ -170,24 +169,20 @@ def __init__( self.url = url def marshal(self, message: Any) -> bytes: - """Marshal a message into bytes. - - If `enable_get` is True and `stable_codec` is None, raises a `ConnectError` - indicating that the codec does not support stable marshal and cannot use get. - Otherwise, if `enable_get` is True and `stable_codec` is not None, marshals - the message using the `marshal_with_get` method. + """Serializes the given message into bytes using the configured codec. - If `enable_get` is False, marshals the message using the `. 
+ If `enable_get` is True, attempts to use a stable codec for marshaling. + Raises a ConnectError if the codec is not set or if the codec does not support stable marshaling. + Otherwise, delegates marshaling to the superclass implementation. Args: - message (Any): The message to be marshaled. + message (Any): The message object to be serialized. Returns: - bytes: The marshaled message in bytes. + bytes: The serialized message. Raises: - ConnectError: If `enable_get` is True and `stable_codec` is None. - + ConnectError: If the codec is not set or does not support stable marshaling when required. """ if self.enable_get: if self.codec is None: @@ -204,29 +199,22 @@ def marshal(self, message: Any) -> bytes: return super().marshal(message) def marshal_with_get(self, message: Any) -> bytes: - """Marshals the given message and sends it using a GET request. - - This method first marshals the message using the stable codec. If the marshaled - data exceeds the maximum allowed size (`send_max_bytes`) and compression is not - enabled, it raises a `ConnectError`. If the data size is within the limit, it - builds the GET URL and sends the data. - - If the data size exceeds the limit and compression is enabled, it compresses - the data and checks the size again. If the compressed data still exceeds the - limit, it raises a `ConnectError`. Otherwise, it builds the GET URL with the - compressed data and sends it. + """Marshals a message and sends it using a GET request, applying compression if necessary. Args: - message (Any): The message to be marshaled and sent. + message (Any): The message object to be marshaled and sent. Returns: - bytes: The marshaled (and possibly compressed) data. + bytes: The marshaled (and possibly compressed) message data. Raises: - ConnectError: If the data size exceeds the maximum allowed size and compression - is not enabled, or if the compressed data size still exceeds the - limit. + ConnectError: If the stable codec is not set. 
+ ConnectError: If the marshaled message size exceeds `send_max_bytes` and compression is not enabled. + ConnectError: If the compressed message size still exceeds `send_max_bytes`. + Notes: + - If the marshaled message size exceeds `send_max_bytes` and compression is enabled, the message will be compressed before sending. + - The method builds the appropriate GET URL based on whether compression was applied. """ if self.stable_codec is None: raise ConnectError("stable_codec is not set", Code.INTERNAL) @@ -305,12 +293,13 @@ def _write_with_get(self, url: URL) -> None: class ConnectStreamingMarshaler(EnvelopeWriter): - """A class responsible for marshaling messages with optional compression. + """ConnectStreamingMarshaler is responsible for marshaling streaming messages in the Connect protocol. Attributes: - codec (Codec): The codec used for marshaling messages. - compression (Compression | None): The compression method used for compressing messages, if any. - + codec (Codec | None): The codec used for encoding and decoding messages. + compress_min_bytes (int): The minimum payload size (in bytes) before compression is applied. + send_max_bytes (int): The maximum allowed size (in bytes) for a single message to be sent. + compression (Compression | None): The compression algorithm to use, or None for no compression. """ codec: Codec | None @@ -321,14 +310,13 @@ class ConnectStreamingMarshaler(EnvelopeWriter): def __init__( self, codec: Codec | None, compression: Compression | None, compress_min_bytes: int, send_max_bytes: int ) -> None: - """Initialize the ProtocolConnect instance. + """Initializes the marshaler with the specified codec, compression settings, and byte limits. Args: - codec (Codec): The codec to be used for encoding and decoding. - compression (Compression | None): The compression method to be used, or None if no compression is to be applied. + codec (Codec | None): The codec to use for encoding/decoding, or None if not specified. 
+ compression (Compression | None): The compression algorithm to use, or None if not specified. compress_min_bytes (int): The minimum number of bytes before compression is applied. - send_max_bytes (int): The maximum number of bytes that can be sent in a single message. - + send_max_bytes (int): The maximum number of bytes allowed to send in a single message. """ self.codec = codec self.compress_min_bytes = compress_min_bytes @@ -336,15 +324,18 @@ def __init__( self.compression = compression def marshal_end_stream(self, error: ConnectError | None, response_trailers: Headers) -> bytes: - """Serialize the end-of-stream message with optional error and response trailers into a bytes envelope. + """Serializes the end-of-stream message for a Connect protocol response. + + This method converts the provided error (if any) and response trailers into a JSON object, + encodes it, wraps it in an envelope with the end_stream flag, and returns the final bytes + to be sent over the wire. Args: - error (ConnectError | None): An optional error object to include in the end-of-stream message. - response_trailers (Headers): Headers to include as response trailers. + error (ConnectError | None): The error to include in the end-of-stream message, or None if no error occurred. + response_trailers (Headers): The response trailers to include in the end-of-stream message. Returns: - bytes: The serialized envelope containing the end-of-stream message. - + bytes: The serialized and enveloped end-of-stream message. 
""" json_obj = end_stream_to_json(error, response_trailers) json_str = json.dumps(json_obj) diff --git a/src/connect/protocol_connect/unmarshaler.py b/src/connect/protocol_connect/unmarshaler.py index 64fbd5d..aef0e15 100644 --- a/src/connect/protocol_connect/unmarshaler.py +++ b/src/connect/protocol_connect/unmarshaler.py @@ -1,4 +1,4 @@ -"""Module providing classes for unmarshaling unary and streaming Connect protocol messages.""" +"""Connect protocol message unmarshaling utilities.""" from collections.abc import ( AsyncIterable, @@ -18,14 +18,17 @@ class ConnectUnaryUnmarshaler: - """A class to handle the unmarshaling of data using a specified codec. + """ConnectUnaryUnmarshaler is responsible for asynchronously reading and unmarshaling messages from an async byte stream. + + This class manages the process of reading data in chunks from an asynchronous stream, enforcing a maximum read size, + optionally decompressing the data, and then decoding (unmarshaling) the message using a provided codec. It also ensures + proper cleanup of the stream resource. Attributes: - codec (Codec): The codec used for unmarshaling the data. - body (bytes): The raw data to be unmarshaled. - read_max_bytes (int): The maximum number of bytes to read. + codec (Codec | None): The codec used for encoding/decoding the message. + read_max_bytes (int): The maximum number of bytes to read from the stream. compression (Compression | None): The compression method to use, if any. - + stream (AsyncIterable[bytes] | None): The asynchronous stream of bytes to be unmarshaled. """ codec: Codec | None @@ -40,14 +43,13 @@ def __init__( compression: Compression | None = None, stream: AsyncIterable[bytes] | None = None, ) -> None: - """Initialize the ProtocolConnect object. + """Initializes the object with the specified codec, maximum read bytes, compression method, and optional asynchronous byte stream. Args: - stream (AsyncIterable[bytes] | None): The stream of bytes to be unmarshaled. 
- codec (Codec): The codec used for encoding/decoding the message. - read_max_bytes (int): The maximum number of bytes to read. - compression (Compression | None): The compression method to use, if any. - + codec (Codec | None): The codec to use for decoding, or None if not specified. + read_max_bytes (int): The maximum number of bytes to read at once. + compression (Compression | None, optional): The compression method to use, or None for no compression. Defaults to None. + stream (AsyncIterable[bytes] | None, optional): An optional asynchronous iterable byte stream. Defaults to None. """ self.codec = codec self.read_max_bytes = read_max_bytes @@ -55,14 +57,16 @@ def __init__( self.stream = stream async def unmarshal(self, message: Any) -> Any: - """Asynchronously unmarshals a given message using the provided unmarshal function and codec. + """Asynchronously unmarshals a given message using the configured codec. Args: message (Any): The message to be unmarshaled. Returns: - Any: The result of the unmarshaling process. + Any: The unmarshaled message. + Raises: + ConnectError: If the codec is not set. """ if self.codec is None: raise ConnectError("codec is not set", Code.INTERNAL) @@ -70,25 +74,21 @@ async def unmarshal(self, message: Any) -> Any: return await self.unmarshal_func(message, self.codec.unmarshal) async def unmarshal_func(self, message: Any, func: Callable[[bytes, Any], Any]) -> Any: - """Asynchronously unmarshals a message using the provided function. - - This function reads data from the stream in chunks, checks if the total - bytes read exceed the maximum allowed bytes, and optionally decompresses - the data. It then uses the provided function to unmarshal the data into - the desired format. + """Asynchronously reads data from the stream, optionally decompresses it, and applies a given function to unmarshal the data. Args: - message (Any): The message to be unmarshaled. 
- func (Callable[[bytes, Any], Any]): A function that takes the raw bytes - and the message, and returns the unmarshaled object. + message (Any): The message context or object to be passed to the unmarshal function. + func (Callable[[bytes, Any], Any]): A callable that takes the raw bytes and the message, and returns the unmarshaled object. Returns: - Any: The unmarshaled object. + Any: The result of the unmarshal function applied to the data and message. Raises: - ConnectError: If the stream is not set, if the message size exceeds the - maximum allowed bytes, or if there is an error during unmarshaling. + ConnectError: If the stream is not set, if the message size exceeds the configured maximum, or if unmarshaling fails. + Notes: + - The stream is closed after processing, regardless of success or failure. + - If compression is enabled, the data is decompressed before unmarshaling. """ if self.stream is None: raise ConnectError("stream is not set", Code.INTERNAL) @@ -125,11 +125,14 @@ async def unmarshal_func(self, message: Any, func: Callable[[bytes, Any], Any]) return obj async def aclose(self) -> None: - """Asynchronously close the stream if it is set. + """Asynchronously closes the underlying stream if it supports asynchronous closing. - This method is intended to be called when the stream is no longer needed - to release any associated resources. + This method attempts to retrieve an asynchronous close method (`aclose`) from the + `stream` attribute using the `get_acallable_attribute` utility. If such a method exists, + it is awaited to properly close the stream and release any associated resources. + Raises: + Any exception raised by the underlying stream's `aclose` method. """ aclose = get_acallable_attribute(self.stream, "aclose") if aclose: @@ -137,14 +140,15 @@ async def aclose(self) -> None: class ConnectStreamingUnmarshaler(EnvelopeReader): - """A class to handle the unmarshaling of streaming data. 
+ """ConnectStreamingUnmarshaler is an asynchronous envelope reader for streaming Connect protocol messages. - Attributes: - codec (Codec): The codec used for unmarshaling data. - compression (Compression | None): The compression method used, if any. - stream (AsyncIterable[bytes] | None): The asynchronous byte stream to read from. - buffer (bytes): The buffer to store incoming data chunks. + This class is responsible for reading, decoding, and handling streamed messages from an asynchronous byte stream, + optionally applying compression and decoding using a specified codec. It also manages end-of-stream errors and + trailer headers, which are additional headers sent after the message body. + Attributes: + _end_stream_error (ConnectError | None): Stores any error that occurred at the end of the stream. + _trailers (Headers): Stores the trailers headers received at the end of the stream. """ _end_stream_error: ConnectError | None @@ -157,32 +161,32 @@ def __init__( stream: AsyncIterable[bytes] | None = None, compression: Compression | None = None, ) -> None: - """Initialize the protocol connection. + """Initializes the object with the specified codec, maximum read bytes, optional stream, and optional compression. Args: - codec (Codec): The codec to use for encoding and decoding data. - read_max_bytes (int): The maximum number of bytes to read from the stream. - stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read from. Defaults to None. - compression (Compression | None, optional): The compression method to use. Defaults to None. - + codec (Codec | None): The codec to use for decoding, or None. + read_max_bytes (int): The maximum number of bytes to read. + stream (AsyncIterable[bytes] | None, optional): An optional asynchronous byte stream. Defaults to None. + compression (Compression | None, optional): The compression method to use, or None. Defaults to None. 
""" super().__init__(codec, read_max_bytes, stream, compression) self._end_stream_error = None self._trailers = Headers() async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: - """Asynchronously unmarshals messages from the stream. + """Asynchronously unmarshals a message, yielding objects and end-of-stream flags. + + Iterates over the result of the superclass's `unmarshal` method, yielding each + object and a boolean indicating if it is the end of the stream. If `self.last` + is set, extracts error and trailer information from its data and stores them + in instance variables. Args: - message (Any): The message type to unmarshal. + message (Any): The message to be unmarshaled. Yields: - Any: The unmarshaled message object. - - Raises: - ConnectError: If the stream is not set, if there is an error in the - unmarshaling process, or if there is a protocol error. - + tuple[Any, bool]: A tuple containing the unmarshaled object and a boolean + indicating if it is the end of the stream. """ async for obj, end in super().unmarshal(message): if self.last: @@ -194,23 +198,20 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: @property def trailers(self) -> Headers: - """Return the trailers headers. + """Returns the trailers associated with the response. - Trailers are additional headers sent after the body of the message. + Trailers are additional headers sent after the response body, typically used in protocols like HTTP/2 or gRPC to provide metadata at the end of a message. Returns: - Headers: The trailers headers. - + Headers: The trailers as a Headers object. """ return self._trailers @property def end_stream_error(self) -> ConnectError | None: - """Return the error that occurred at the end of the stream, if any. + """Returns the error that occurred at the end of the stream, if any. Returns: - ConnectError | None: The error that occurred at the end of the stream, - or None if no error occurred. 
- + ConnectError | None: The error encountered at the end of the stream, or None if no error occurred. """ return self._end_stream_error diff --git a/src/connect/protocol_grpc/constants.py b/src/connect/protocol_grpc/constants.py index 25bd698..cea8759 100644 --- a/src/connect/protocol_grpc/constants.py +++ b/src/connect/protocol_grpc/constants.py @@ -1,4 +1,4 @@ -"""Constants for gRPC protocol implementation in connect-python.""" +"""Constants and settings for gRPC protocol support in connect-python.""" import re import sys diff --git a/src/connect/protocol_grpc/content_type.py b/src/connect/protocol_grpc/content_type.py index 33c8289..6cfc4ff 100644 --- a/src/connect/protocol_grpc/content_type.py +++ b/src/connect/protocol_grpc/content_type.py @@ -12,15 +12,19 @@ def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: - """Return the appropriate gRPC content type string based on the given codec name and whether the request is for gRPC-Web. + """Returns the appropriate gRPC content type string based on the codec name and whether the request is for gRPC-Web. Args: web (bool): Indicates if the content type is for gRPC-Web (True) or standard gRPC (False). codec_name (str): The name of the codec (e.g., "proto", "json"). Returns: - str: The corresponding gRPC content type string. + str: The constructed gRPC content type string. + Notes: + - If `web` is True, returns the gRPC-Web content type prefix concatenated with the codec name. + - If `codec_name` is `CodecNameType.PROTO` and `web` is False, returns the default gRPC content type. + - Otherwise, returns the standard gRPC content type prefix concatenated with the codec name. """ if web: return GRPC_WEB_CONTENT_TYPE_PREFIX + codec_name @@ -32,17 +36,16 @@ def grpc_content_type_from_codec_name(web: bool, codec_name: str) -> str: def grpc_codec_from_content_type(web: bool, content_type: str) -> str: - """Determine the gRPC codec name from the given content type string. 
+ """Determines the gRPC codec name from the given content type string. Args: - web (bool): Indicates whether the request is a gRPC-web request. + web (bool): Indicates whether the context is gRPC-Web (True) or standard gRPC (False). content_type (str): The content type string to parse. Returns: - str: The codec name extracted from the content type. If the content type matches the default gRPC or gRPC-web content type, - returns the default codec name. Otherwise, extracts and returns the codec name from the content type prefix, or returns - the original content type if no known prefix is found. - + str: The codec name extracted from the content type. If the content type matches the default + for the given context, returns the default codec name. Otherwise, returns the codec name + parsed from the content type prefix or the original content type if no prefix is found. """ if (not web and content_type == GRPC_CONTENT_TYPE_DEFAULT) or ( web and content_type == GRPC_WEB_CONTENT_TYPE_DEFAULT @@ -58,16 +61,15 @@ def grpc_codec_from_content_type(web: bool, content_type: str) -> str: def grpc_validate_response_content_type(web: bool, request_codec_name: str, response_content_type: str) -> None: - """Validate that the gRPC response content type matches the expected value based on the request codec and whether gRPC-Web is used. + """Validates the gRPC response content type against the expected content type based on the request codec and context. Args: - web (bool): Indicates if gRPC-Web is being used. - request_codec_name (str): The name of the codec used in the request (e.g., "proto", "json"). - response_content_type (str): The content type returned in the response. + web (bool): Indicates if the request is a gRPC-web request. + request_codec_name (str): The codec name used in the request (e.g., "proto", "json"). + response_content_type (str): The content type received in the response. 
Raises: - ConnectError: If the response content type does not match the expected value, with an appropriate error code. - + ConnectError: If the response content type does not match the expected content type. """ bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX if web: diff --git a/src/connect/protocol_grpc/error_trailer.py b/src/connect/protocol_grpc/error_trailer.py index c6ab0b8..9f2c9bb 100644 --- a/src/connect/protocol_grpc/error_trailer.py +++ b/src/connect/protocol_grpc/error_trailer.py @@ -1,4 +1,4 @@ -"""Provides functions to convert between ConnectError and gRPC trailer headers.""" +"""Helpers for encoding and decoding gRPC error trailers in Connect protocol.""" import base64 from urllib.parse import quote, unquote @@ -18,20 +18,17 @@ def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: - """Convert a ConnectError to gRPC trailer headers. + """Converts a ConnectError into gRPC trailer headers. - Args: - trailer (Headers): The trailer headers dictionary to update with gRPC error information. - error (ConnectError | None): The error to convert. If None, indicates success. - - Side Effects: - Modifies the `trailer` dictionary in-place to include gRPC status, message, and optional details. - - Notes: - - If `error` is None, sets the gRPC status header to "0" (OK). - - If `ConnectError.wire_error` is False, updates the trailer with error metadata excluding protocol headers. - - Serializes error details using protobuf if present, encoding them in base64 for the trailer. + This function populates the provided trailer headers with gRPC status information + based on the given error. If no error is provided, it sets the status to "0" (OK). + If the error is not a wire error, it updates the trailer with the error's metadata, + excluding protocol headers. It then serializes the error status and attaches the + status code, message, and (if present) base64-encoded details to the trailer. 
+ Args: + trailer (Headers): The trailer headers dictionary to be updated. + error (ConnectError | None): The error to convert into gRPC trailer headers. """ if error is None: trailer[GRPC_HEADER_STATUS] = "0" @@ -59,23 +56,22 @@ def grpc_error_to_trailer(trailer: Headers, error: ConnectError | None) -> None: def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: - """Parse gRPC error information from response trailers and constructs a ConnectError if present. + """Parses gRPC error information from response trailers and constructs a ConnectError if present. Args: trailers (Headers): The gRPC response trailers containing error information. Returns: - ConnectError | None: Returns a ConnectError instance if an error is found in the trailers, + ConnectError | None: Returns a ConnectError instance if an error is present in the trailers, or None if the status code indicates success. Raises: - ConnectError: If the grpc-status-details-bin trailer or protobuf error details are invalid. - - The function extracts the gRPC status code, error message, and optional error details from the trailers. - If the status code is missing or invalid, it returns a ConnectError with an appropriate message. - If the status code indicates success ("0"), it returns None. - If error details are present and valid, they are attached to the ConnectError. + ConnectError: If the trailers contain invalid or malformed error details or protobuf data. + The function extracts the gRPC status code, error message, and optional error details from the + trailers. If the status code indicates an error, it constructs and returns a ConnectError with + the relevant information. If the status code is missing or invalid, or if error details are + malformed, a ConnectError is raised with an appropriate message. 
""" code_header = trailers.get(GRPC_HEADER_STATUS) if code_header is None: @@ -141,10 +137,10 @@ def grpc_error_from_trailer(trailers: Headers) -> ConnectError | None: def decode_binary_header(data: str) -> bytes: - """Decode a base64-encoded string representing a binary header. + """Decodes a base64-encoded string representing a binary header. - If the input string's length is not a multiple of 4, it pads the string with '=' characters - to make it valid base64 before decoding. + If the input string's length is not a multiple of 4, it is padded with '=' characters + to make it valid for base64 decoding. Args: data (str): The base64-encoded string to decode. @@ -152,6 +148,8 @@ def decode_binary_header(data: str) -> bytes: Returns: bytes: The decoded binary data. + Raises: + binascii.Error: If the input is not correctly base64-encoded. """ if len(data) % 4: data += "=" * (-len(data) % 4) diff --git a/src/connect/protocol_grpc/grpc_client.py b/src/connect/protocol_grpc/grpc_client.py index 829b246..6cb9115 100644 --- a/src/connect/protocol_grpc/grpc_client.py +++ b/src/connect/protocol_grpc/grpc_client.py @@ -1,4 +1,4 @@ -"""gRPC client implementation for Connect-Python, supporting async streaming and HTTP/2 communication.""" +"""gRPC protocol client implementation for asynchronous Python clients.""" import asyncio import contextlib @@ -49,13 +49,25 @@ class GRPCClient(ProtocolClient): - """GRPCClient is a protocol client implementation for gRPC communication, supporting both standard and web environments. + """GRPCClient is a gRPC protocol client implementation that manages connection parameters, peer association, and request header configuration for gRPC or HTTP/2 communication. It supports both standard and web environments, handling codec selection, compression negotiation, and header mutation for outgoing requests. Attributes: - params (ProtocolClientParams): Configuration parameters for the protocol client, including codec, compression, pool, and URL. 
+ params (ProtocolClientParams): Configuration parameters for the protocol client, including codec, compression, and connection pool. _peer (Peer): The peer instance associated with this client, representing the remote endpoint. - web (bool): Indicates whether the client is running in a web environment, affecting header and content-type handling. + web (bool): Indicates if the client operates in a web environment, affecting header and content-type handling. + Methods: + __init__(params: ProtocolClientParams, peer: Peer, web: bool) -> None: + Initializes the GRPCClient with the provided parameters, peer, and environment flag. + + peer -> Peer: + Returns the associated Peer object. + + write_request_headers(_: StreamType, headers: Headers) -> None: + Modifies the provided headers dictionary in place to ensure compliance with gRPC protocol requirements, including user agent, content type, compression, and environment-specific headers. + + conn(spec: Spec, headers: Headers) -> StreamingClientConn: + Creates and returns a configured GRPCClientConn instance for the specified protocol/service specification and request headers, initializing marshaling and unmarshaling logic with appropriate codecs and compression settings. """ params: ProtocolClientParams @@ -63,13 +75,12 @@ class GRPCClient(ProtocolClient): web: bool def __init__(self, params: ProtocolClientParams, peer: Peer, web: bool) -> None: - """Initialize the ProtocolClient with the given parameters. + """Initializes the gRPC client with the given parameters. Args: params (ProtocolClientParams): The parameters for the protocol client. - peer (Peer): The peer instance to be used. + peer (Peer): The peer instance to connect to. web (bool): Indicates whether the client is running in a web environment. - """ self.params = params self._peer = peer @@ -77,33 +88,34 @@ def __init__(self, params: ProtocolClientParams, peer: Peer, web: bool) -> None: @property def peer(self) -> Peer: - """Returns the associated Peer object. 
+ """Returns the associated Peer object for this client. Returns: - Peer: The peer instance associated with this object. - + Peer: The peer instance representing the remote endpoint. """ return self._peer def write_request_headers(self, _: StreamType, headers: Headers) -> None: - """Set and modifies HTTP/2 or gRPC request headers based on the stream type, connection parameters, and environment. + """Sets and modifies HTTP/2 or gRPC-specific headers for an outgoing request. + + This method ensures that required headers such as `User-Agent`, `Content-Type`, and compression-related headers + are present and correctly set based on the client's configuration and the request context. Args: - stream_type (StreamType): The type of stream for which headers are being written. - headers (Headers): The dictionary of headers to be modified or populated. + _ (StreamType): The stream associated with the request (unused in this method). + headers (Headers): The dictionary of HTTP headers to be modified in-place. Behavior: - - Ensures the 'User-Agent' header is set to the default gRPC user agent if not already present. - - If running in a web environment, also sets the 'X-User-Agent' header. - - Sets the 'Content-Type' header according to the codec name and environment. - - Sets the 'Accept-Encoding' header to indicate supported compression. - - If a specific compression is configured and is not the identity, sets the gRPC compression header. - - If multiple compressions are supported, sets the gRPC accept compression header with the supported values. - - For non-web environments, adds the 'Te: trailers' header required for gRPC. - - Note: - This method mutates the provided headers dictionary in place. + - Sets a default `User-Agent` header if not already present. + - For web clients, sets an additional `X-User-Agent` header if not present. + - Sets the `Content-Type` header based on the codec and web mode. + - Sets the `Accept-Encoding` header to indicate supported compression. 
+ - If a specific compression is configured, sets the corresponding gRPC compression header. + - If multiple compressions are supported, sets the gRPC accept-compression header with the supported values. + - For non-web clients, adds the `Te: trailers` header as required by gRPC. + Returns: + None """ if headers.get(HEADER_USER_AGENT, None) is None: headers[HEADER_USER_AGENT] = DEFAULT_GRPC_USER_AGENT @@ -124,19 +136,14 @@ def write_request_headers(self, _: StreamType, headers: Headers) -> None: headers["Te"] = "trailers" def conn(self, spec: Spec, headers: Headers) -> StreamingClientConn: - """Create and returns a GRPCClientConn instance configured with the provided specification and headers. + """Creates and returns a GRPCClientConn instance configured with the provided specification and headers. Args: - spec (Spec): The specification object defining the protocol or service interface. - headers (Headers): The request headers to include in the connection. + spec (Spec): The specification object defining the gRPC method and message types. + headers (Headers): The request headers to include in the gRPC call. Returns: - StreamingClientConn: An initialized gRPC streaming client connection. - - Details: - - Configures the connection with parameters such as pool, peer, URL, codec, and compression settings. - - Initializes GRPCMarshaler and GRPCUnmarshaler with appropriate codecs and limits. - - Compression is determined using the provided compression name and available compressions. + StreamingClientConn: An instance of GRPCClientConn configured for streaming communication. """ return GRPCClientConn(
- This class manages the lifecycle of a gRPC client connection, including marshaling and unmarshaling messages, handling request and response headers/trailers, managing compression, and supporting event hooks for request/response events. It integrates with an asynchronous HTTP client connection pool and supports cancellation via asyncio events. + This class is responsible for sending and receiving gRPC messages asynchronously, handling request/response headers and trailers, managing connection pooling, and supporting event hooks for request and response lifecycle events. It abstracts the details of HTTP/2 transport and gRPC protocol compliance, including marshaling/unmarshaling messages, applying compression, and error handling. Attributes: - pool (AsyncConnectionPool): The asynchronous connection pool for managing HTTP/2 connections. - _spec (Spec): The protocol or API specification. - _peer (Peer): Information about the remote peer. - url (URL): The endpoint URL for the connection. - codec (Codec | None): Codec for encoding/decoding messages. - compressions (list[Compression]): Supported compression algorithms. - marshaler (GRPCMarshaler): Marshaler for serializing messages. - unmarshaler (GRPCUnmarshaler): Unmarshaler for deserializing messages. - _response_headers (Headers): HTTP response headers. - _response_trailers (Headers): HTTP response trailers. - _request_headers (Headers): HTTP request headers. + web (bool): Indicates if the connection is for a web environment. + pool (AsyncConnectionPool): The connection pool for managing HTTP/2 connections. + _spec (Spec): The specification object describing the protocol or API. + _peer (Peer): The peer information for the connection. + url (URL): The URL endpoint for the connection. + codec (Codec | None): The codec to use for encoding/decoding messages, or None. + compressions (list[Compression]): List of supported compression algorithms. + marshaler (GRPCMarshaler): The marshaler for serializing messages. 
+ unmarshaler (GRPCUnmarshaler): The unmarshaler for deserializing messages. + _response_headers (Headers): Stores response headers. + _response_trailers (Headers): Stores response trailers. + _request_headers (Headers): Stores request headers. receive_trailers (Callable[[], None] | None): Callback to receive trailers after response. - """ web: bool @@ -211,21 +218,20 @@ def __init__( unmarshaler: GRPCUnmarshaler, event_hooks: None | (Mapping[str, list[EventHook]]) = None, ) -> None: - """Initialize a new instance of the class. + """Initializes a new instance of the gRPC client. Args: - web (bool): Indicates if the connection is for a web environment. - pool (AsyncConnectionPool): The connection pool for managing HTTP/2 connections. - spec (Spec): The specification object describing the protocol or API. + web (bool): Indicates if the client is running in a web environment. + pool (AsyncConnectionPool): The asynchronous connection pool to use for connections. + spec (Spec): The service specification. peer (Peer): The peer information for the connection. - url (URL): The URL endpoint for the connection. - codec (Codec | None): The codec to use for encoding/decoding messages, or None. + url (URL): The URL of the gRPC server. + codec (Codec | None): The codec to use for message serialization, or None. compressions (list[Compression]): List of supported compression algorithms. - request_headers (Headers): Headers to include in outgoing requests. - marshaler (GRPCMarshaler): The marshaler for serializing messages. - unmarshaler (GRPCUnmarshaler): The unmarshaler for deserializing messages. - event_hooks (None | Mapping[str, list[EventHook]], optional): Optional mapping of event hooks for "request" and "response" events. Defaults to None. - + request_headers (Headers): Headers to include in each request. + marshaler (GRPCMarshaler): The marshaler for serializing requests. + unmarshaler (GRPCUnmarshaler): The unmarshaler for deserializing responses. 
+ event_hooks (None | Mapping[str, list[EventHook]], optional): Optional mapping of event hooks for "request" and "response" events. """ event_hooks = {} if event_hooks is None else event_hooks @@ -249,12 +255,20 @@ def __init__( @property def spec(self) -> Spec: - """Return the specification details.""" + """Returns the specification object associated with this client. + + Returns: + Spec: The specification instance used by the client. + """ return self._spec @property def peer(self) -> Peer: - """Return the peer information.""" + """Returns the current Peer instance associated with this client. + + Returns: + Peer: The peer instance representing the remote endpoint for this client. + """ return self._peer async def receive(self, message: Any, abort_event: asyncio.Event | None) -> AsyncIterator[Any]: @@ -314,26 +328,31 @@ def _receive_trailers(self, response: httpcore.Response) -> None: @property def request_headers(self) -> Headers: - """Return the request headers.""" + """Returns the headers to be included in the gRPC request. + + Returns: + Headers: The headers used for the gRPC request. + """ return self._request_headers async def send( self, messages: AsyncIterable[Any], timeout: float | None, abort_event: asyncio.Event | None ) -> None: - """Send a gRPC request asynchronously using HTTP/2 via httpcore, handling streaming messages, timeouts, and abort events. + """Sends a gRPC request asynchronously using HTTP/2 via httpcore. Args: messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent as the request body. timeout (float | None): Optional timeout in seconds for the request. If provided, sets the gRPC timeout header. - abort_event (asyncio.Event | None): Optional asyncio event that, if set, will abort the request and raise a cancellation error. + abort_event (asyncio.Event | None): Optional asyncio event that, when set, aborts the request. 
Raises: - ConnectError: If the request is aborted before or during execution, or if an error occurs during the HTTP request. + ConnectError: If the request is aborted via the abort_event. + Exception: Propagates exceptions raised by the underlying HTTP client or marshaling/unmarshaling logic. Side Effects: - Invokes registered request and response event hooks. - - Sets up the response stream and trailers for further processing. - - Validates the HTTP response. + - Sets up the response stream for unmarshaling. + - Validates the response after receiving it. """ extensions = {} @@ -421,54 +440,59 @@ async def _validate_response(self, response: httpcore.Response) -> None: @property def response_headers(self) -> Headers: - """Return the response headers.""" + """Returns the headers received in the response. + + Returns: + Headers: The response headers. + """ return self._response_headers @property def response_trailers(self) -> Headers: - """Return response trailers.""" + """Returns the response trailers as headers. + + Returns: + Headers: The response trailers received from the gRPC call. + """ return self._response_trailers def on_request_send(self, fn: EventHook) -> None: - """Register a callback function to be invoked when a request is sent. + """Registers a callback function to be invoked when a request is sent. Args: fn (EventHook): The callback function to be added to the "request" event hook. Returns: None - """ self._event_hooks["request"].append(fn) async def aclose(self) -> None: - """Asynchronously closes the underlying unmarshaler resource. + """Asynchronously closes the resources associated with the client. - This method should be called to properly release any resources held by the unmarshaler, - such as open network connections or file handles, when they are no longer needed. + This method ensures that any resources held by the `unmarshaler` are properly released. + It should be called when the client is no longer needed to avoid resource leaks. 
""" await self.unmarshaler.aclose() def grpc_encode_timeout(timeout: float) -> str: - """Encode a timeout value (in seconds) into the gRPC timeout format string. - - The gRPC timeout format is a decimal number with a time unit suffix, where the unit can be: - - 'H' for hours - - 'M' for minutes - - 'S' for seconds - - 'm' for milliseconds - - 'u' for microseconds - - 'n' for nanoseconds + """Encodes a timeout value (in seconds) into a gRPC-compatible timeout string. - If the timeout is less than or equal to zero, returns "0n". + The gRPC protocol requires timeout values to be specified as a string with a numeric value and a unit suffix. + This function converts a floating-point timeout (in seconds) into the appropriate string format, choosing the + largest unit possible without exceeding the maximum value allowed for that unit. Args: timeout (float): The timeout value in seconds. Returns: - str: The timeout encoded as a gRPC timeout string. + str: The timeout encoded as a gRPC timeout string (e.g., "10S", "500m", "0n"). + Notes: + - If the timeout is less than or equal to zero, "0n" is returned. + - The function uses predefined unit-to-seconds mappings and a maximum value per unit. + - If the timeout exceeds all unit ranges, it is encoded in hours ("H"). """ if timeout <= 0: return "0n" diff --git a/src/connect/protocol_grpc/grpc_handler.py b/src/connect/protocol_grpc/grpc_handler.py index 5d25b6b..dd82a87 100644 --- a/src/connect/protocol_grpc/grpc_handler.py +++ b/src/connect/protocol_grpc/grpc_handler.py @@ -1,4 +1,4 @@ -"""gRPC and gRPC-Web protocol handler implementation for Connect Python server.""" +"""gRPC and gRPC-web protocol handler and connection classes for Connect Python framework.""" from collections.abc import AsyncIterable, AsyncIterator from http import HTTPMethod @@ -42,19 +42,18 @@ class GRPCHandler(ProtocolHandler): - """GRPCHandler is a protocol handler for gRPC and gRPC-Web requests. 
+ """GRPCHandler is a protocol handler for managing gRPC and gRPC-web requests. - This class implements the ProtocolHandler interface to handle gRPC protocol requests, - including negotiation of compression, codec selection, and connection management for - both standard gRPC and gRPC-Web. It supports content type negotiation, payload handling, - and manages the lifecycle of a gRPC connection, including streaming and non-streaming - requests. + This class is responsible for handling incoming gRPC protocol requests, negotiating compression, + selecting codecs, and managing the connection lifecycle for both standard gRPC and gRPC-web protocols. + It provides methods to determine accepted HTTP methods, supported content types, and whether a given + payload can be handled. The main entry point for handling a connection is the asynchronous `conn` method, + which negotiates protocol details and returns a connection handler for streaming communication. Attributes: - params (ProtocolHandlerParams): Configuration parameters for the handler, including codecs and compressions. - web (bool): Indicates if the handler is for gRPC-Web. - accept (list[str]): List of accepted content types. - + params (ProtocolHandlerParams): Configuration parameters for the protocol handler, including codecs and compressions. + web (bool): Indicates if the handler is for gRPC-web protocol. + accept (list[str]): List of accepted MIME content types. """ params: ProtocolHandlerParams @@ -62,16 +61,12 @@ class GRPCHandler(ProtocolHandler): accept: list[str] def __init__(self, params: ProtocolHandlerParams, web: bool, accept: list[str]) -> None: - """Initialize the ProtocolHandler with the given parameters. + """Initializes the handler with the given parameters. Args: - params (ProtocolHandlerParams): The parameters required for the protocol handler. + params (ProtocolHandlerParams): The parameters for the protocol handler. web (bool): Indicates whether the handler is for web usage. 
- accept (list[str]): A list of accepted content types. - - Returns: - None - + accept (list[str]): List of accepted content types. """ self.params = params self.web = web @@ -79,33 +74,30 @@ def __init__(self, params: ProtocolHandlerParams, web: bool, accept: list[str]) @property def methods(self) -> list[HTTPMethod]: - """Returns a list of allowed HTTP methods for gRPC protocol. + """Returns a list of allowed HTTP methods for gRPC handlers. Returns: - list[HTTPMethod]: A list containing the HTTP methods permitted for gRPC communication. - + list[HTTPMethod]: The list of HTTP methods permitted for gRPC endpoints. """ return GRPC_ALLOWED_METHODS def content_types(self) -> list[str]: - """Return a list of accepted content types. + """Returns a list of accepted content types. Returns: list[str]: A list of MIME types that are accepted. - """ return self.accept def can_handle_payload(self, _: Request, content_type: str) -> bool: - """Determine if the given content type is supported by this handler. + """Determines if the handler can process a request with the specified content type. Args: - _ (Request): The request object (unused). - content_type (str): The MIME type of the payload to check. + _ (Request): The incoming request object (unused). + content_type (str): The MIME type of the request payload. Returns: - bool: True if the content type is accepted, False otherwise. - + bool: True if the content type is accepted by the handler, False otherwise. """ return content_type in self.accept @@ -116,18 +108,22 @@ async def conn( response_trailers: Headers, writer: ServerResponseWriter, ) -> StreamingHandlerConn | None: - """Handle a connection request. + """Handles the setup of a gRPC streaming connection, negotiating compression, codecs, and protocol details. Args: - request (Request): The incoming request object. - response_headers (Headers): The headers to be sent in the response. - response_trailers (Headers): The trailers to be sent in the response. 
- writer (ServerResponseWriter): The writer used to send the response. - is_streaming (bool, optional): Whether this is a streaming connection. Defaults to False. + request (Request): The incoming gRPC request object containing headers and client information. + response_headers (Headers): Headers to be sent in the response. + response_trailers (Headers): Trailers to be sent at the end of the response. + writer (ServerResponseWriter): The writer used to send responses to the client. Returns: - StreamingHandlerConn | None: The connection handler or None if not implemented. + StreamingHandlerConn | None: Returns a configured GRPCHandlerConn instance for handling the connection, + or None if an error occurred during negotiation (in which case an error is sent to the client). + Side Effects: + - Negotiates compression and codec based on request and server capabilities. + - Sets appropriate response headers for gRPC protocol. + - Sends an error to the client and returns None if negotiation fails. """ content_encoding = request.headers.get(GRPC_HEADER_COMPRESSION) accept_encoding = request.headers.get(GRPC_HEADER_ACCEPT_COMPRESSION) @@ -182,19 +178,22 @@ async def conn( class GRPCHandlerConn(StreamingHandlerConn): - """GRPCHandlerConn is a handler class for managing gRPC protocol connections within a streaming server context. + """GRPCHandlerConn is a connection handler for gRPC protocol requests, supporting both standard and web environments. - This class encapsulates the logic for handling gRPC requests and responses, including marshaling and unmarshaling messages, - managing request and response headers/trailers, handling timeouts, and enforcing protocol-specific constraints for unary and streaming operations. + This class manages the lifecycle of a gRPC connection, including parsing request headers, handling message + marshaling/unmarshaling, managing response headers and trailers, and sending responses or errors to the client. 
+ It supports both streaming and unary operations, and can adapt its behavior for web-based gRPC requests. Attributes: - _spec (Spec): The specification object describing the protocol or service. - _peer (Peer): The peer information for the current connection. - _request_headers (Headers): The headers received with the request. - _response_headers (Headers): The headers to include in the response. - _response_trailers (Headers): The trailers to include in the response. - _is_streaming (bool): Indicates if the connection is streaming. - + web (bool): Indicates if the connection is for a web environment. + writer (ServerResponseWriter): The writer used to send responses to the client. + marshaler (GRPCMarshaler): Marshals response messages into bytes. + unmarshaler (GRPCUnmarshaler): Unmarshals request messages from bytes. + _spec (Spec): The protocol or service specification. + _peer (Peer): Information about the remote peer. + _request_headers (Headers): Headers received with the request. + _response_headers (Headers): Headers to include in the response. + _response_trailers (Headers): Trailers to include in the response. """ web: bool @@ -219,19 +218,18 @@ def __init__( response_headers: Headers, response_trailers: Headers | None = None, ) -> None: - """Initialize a new instance of the class. + """Initializes a new instance of the class. Args: - web (bool): Indicates if the connection is for a web environment. - writer (ServerResponseWriter): The writer used to send responses to the client. - spec (Spec): The specification object describing the protocol or service. - peer (Peer): The peer information for the current connection. - marshaler (GRPCMarshaler): The marshaler used to serialize response messages. - unmarshaler (GRPCUnmarshaler): The unmarshaler used to deserialize request messages. - request_headers (Headers): The headers received with the request. - response_headers (Headers): The headers to include in the response. 
- response_trailers (Headers | None, optional): The trailers to include in the response. Defaults to None. - is_streaming (bool, optional): Indicates if the connection is streaming. Defaults to False. + web (bool): Indicates if the handler is for a web context. + writer (ServerResponseWriter): The response writer for sending server responses. + spec (Spec): The specification object for the gRPC protocol. + peer (Peer): The peer information for the connection. + marshaler (GRPCMarshaler): The marshaler for serializing responses. + unmarshaler (GRPCUnmarshaler): The unmarshaler for deserializing requests. + request_headers (Headers): The headers received in the request. + response_headers (Headers): The headers to be sent in the response. + response_trailers (Headers | None, optional): The trailers to be sent in the response. Defaults to None. """ self.web = web @@ -245,18 +243,17 @@ def __init__( self._response_trailers = response_trailers if response_trailers is not None else Headers() def parse_timeout(self) -> float | None: - """Parse the gRPC timeout value from the request headers and returns it as seconds. + """Parses the gRPC timeout value from the request headers and returns it as seconds. Returns: float | None: The timeout value in seconds if present and valid, otherwise None. Raises: - ConnectError: If the timeout value is present but invalid or too long. + ConnectError: If the timeout value is present but invalid or exceeds the maximum allowed duration. Notes: - - The timeout is extracted from the gRPC header and must match the expected format. - - If the timeout unit is hours and exceeds the maximum allowed, None is returned. - + - If the timeout unit is hours and exceeds the maximum allowed hours, None is returned. + - The timeout is extracted from the request headers using the GRPC_HEADER_TIMEOUT key. 
""" timeout = self._request_headers.get(GRPC_HEADER_TIMEOUT) if not timeout: @@ -279,34 +276,30 @@ def parse_timeout(self) -> float | None: @property def spec(self) -> Spec: - """Returns the specification object associated with this instance. + """Returns the specification object associated with this handler. Returns: - Spec: The specification object. - + Spec: The specification instance for this handler. """ return self._spec @property def peer(self) -> Peer: - """Returns the associated Peer object. - - Returns: - Peer: The peer instance associated with this object. + """Returns the current Peer instance associated with this handler. + :returns: The Peer object representing the current peer. + :rtype: Peer """ return self._peer async def receive(self, message: Any) -> AsyncIterator[Any]: - """Receives a message and processes it. + """Asynchronously receives and yields deserialized objects from the given message. Args: - message (Any): The message to be received and processed. - - Returns: - AsyncIterator[Any]: An async iterator yielding message(s). For non-streaming operations, - this will yield exactly one item. + message (Any): The incoming message to be unmarshaled. + Yields: + Any: Each deserialized object extracted from the message. """ async for obj, _ in self.unmarshaler.unmarshal(message): yield obj @@ -316,21 +309,23 @@ def request_headers(self) -> Headers: """Returns the headers associated with the current request. Returns: - Headers: The headers of the request. - + Headers: The headers of the current request. """ return self._request_headers async def send(self, messages: AsyncIterable[Any]) -> None: - """Send message(s) by marshaling them into bytes. + """Asynchronously sends messages to the client using a streaming response. + + Depending on the `web` attribute, constructs and writes a `StreamingResponse` with appropriate headers and optional trailers. Args: - messages (AsyncIterable[Any]): The message(s) to be sent. 
For unary operations, - this should be an iterable with a single item. + messages (AsyncIterable[Any]): An asynchronous iterable of messages to be sent to the client. Returns: None + Raises: + Any exceptions raised by the underlying writer or StreamingResponse. """ if self.web: await self.writer.write( @@ -352,43 +347,41 @@ async def send(self, messages: AsyncIterable[Any]) -> None: @property def response_headers(self) -> Headers: - """Returns the response headers associated with the current request. + """Returns the response headers associated with the current gRPC call. Returns: - Headers: The headers returned in the response. - + Headers: The headers sent in the gRPC response. """ return self._response_headers @property def response_trailers(self) -> Headers: - """Returns the response trailers as headers. + """Returns the response trailers as a Headers object. - Response trailers are additional metadata sent by the server after the response body, - typically used in gRPC and HTTP/2 protocols. + Response trailers are additional HTTP headers sent after the response body in gRPC communication. + They may contain metadata or status information relevant to the response. Returns: - Headers: The response trailers associated with the current response. - + Headers: The response trailers associated with the gRPC response. """ return self._response_trailers async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[bytes]: - """Asynchronously sends marshaled messages and yields them as byte streams. + """Asynchronously sends marshaled messages and yields them as bytes. + + Iterates over the provided asynchronous iterable of messages, marshals each message, + and yields the resulting bytes. Handles exceptions by converting them to a ConnectError + if necessary, and appends error information to the response trailers. If running in a web + context, marshals and yields the response trailers as the final message. 
Args: messages (AsyncIterable[Any]): An asynchronous iterable of messages to be marshaled and sent. Yields: - bytes: Marshaled message bytes, and optionally marshaled web trailers if in web mode. + bytes: The marshaled message bytes, and optionally marshaled web trailers if in a web context. Raises: - ConnectError: If an error occurs during marshaling or sending messages, a ConnectError is set and handled. - - Notes: - - Errors encountered during message marshaling are converted to ConnectError and added to response trailers. - - If running in web mode (`self.web` is True), marshaled web trailers are yielded at the end. - + ConnectError: If an internal error occurs during marshaling or sending. """ error: ConnectError | None = None try: @@ -404,17 +397,17 @@ async def _send_messages(self, messages: AsyncIterable[Any]) -> AsyncIterator[by yield body async def send_error(self, error: ConnectError) -> None: - """Send an error response over gRPC by converting the provided ConnectError into gRPC trailers. + """Sends an error response using gRPC error trailers. + + Depending on the context (web or non-web), this method serializes and writes the error information + to the response stream. For web clients, it marshals the trailers and writes them as the response body. + For non-web clients, it sends the trailers directly. Args: - error (ConnectError): The error to be sent as a gRPC trailer. + error (ConnectError): The error to be sent in the response. Returns: None - - This method updates the response trailers with the error information and writes a streaming response - with the appropriate headers and trailers to the client. 
- """ grpc_error_to_trailer(self.response_trailers, error) if self.web: diff --git a/src/connect/protocol_grpc/grpc_protocol.py b/src/connect/protocol_grpc/grpc_protocol.py index 56eff49..955e87c 100644 --- a/src/connect/protocol_grpc/grpc_protocol.py +++ b/src/connect/protocol_grpc/grpc_protocol.py @@ -1,4 +1,4 @@ -"""Protocol implementation for handling gRPC and gRPC-Web requests.""" +"""gRPC and gRPC-Web protocol implementation for Connect framework.""" from connect.codec import CodecNameType from connect.connect import ( @@ -25,24 +25,37 @@ class ProtocolGRPC(Protocol): - """ProtocolGRPC is a protocol implementation for handling gRPC and gRPC-Web requests. + """ProtocolGRPC is a protocol implementation for handling gRPC and gRPC-Web communication. + + This class provides methods to create protocol handlers and clients that are configured + for either standard gRPC or gRPC-Web, depending on the `web` flag provided at initialization. Attributes: - web (bool): Indicates whether to use gRPC-Web (True) or standard gRPC (False). + web (bool): Indicates whether the protocol instance is configured for gRPC-Web. + + Methods: + __init__(web: bool) -> None: + Initializes the ProtocolGRPC instance, setting the mode to gRPC or gRPC-Web. + + handler(params: ProtocolHandlerParams) -> ProtocolHandler: + Creates and returns a GRPCHandler instance with content types determined by the codecs + and the protocol mode (gRPC or gRPC-Web). + client(params: ProtocolClientParams) -> ProtocolClient: + Creates and returns a GRPCClient instance, configuring the peer protocol and address + based on the provided parameters and the protocol mode. """ def __init__(self, web: bool) -> None: - """Initialize the instance. + """Initializes the instance with the specified web mode. Args: - web (bool): Indicates whether the instance is for web usage. - + web (bool): Indicates whether to use web mode. 
""" self.web = web def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler: - """Create and returns a GRPCHandler instance configured with appropriate content types based on the provided parameters. + """Creates and returns a GRPCHandler instance configured with appropriate content types based on the provided parameters. Args: params (ProtocolHandlerParams): The parameters containing codec information and other handler configuration. @@ -53,9 +66,8 @@ def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler: Behavior: - Determines the default and prefix content types based on whether gRPC-Web is enabled. - Constructs a list of supported content types from the available codecs. - - Adds the bare content type if the PROTO codec is present. + - Adds the bare content type if the 'proto' codec is present. - Returns a GRPCHandler with the computed content types. - """ bare, prefix = GRPC_CONTENT_TYPE_DEFAULT, GRPC_CONTENT_TYPE_PREFIX if self.web: @@ -71,14 +83,17 @@ def handler(self, params: ProtocolHandlerParams) -> ProtocolHandler: return GRPCHandler(params, self.web, content_types) def client(self, params: ProtocolClientParams) -> ProtocolClient: - """Create and return a GRPCClient instance. + """Creates and returns a ProtocolClient instance configured for gRPC or gRPC-Web communication. Args: - params (ProtocolClientParams): The parameters required to initialize the client. + params (ProtocolClientParams): Parameters required to configure the protocol client, including the target URL. Returns: - ProtocolClient: An instance of GRPCClient. + ProtocolClient: An instance of GRPCClient initialized with the provided parameters and peer configuration. + Notes: + - If the instance is configured for web usage (`self.web` is True), the protocol is set to gRPC-Web. + - The peer's address is constructed from the host and port in `params.url`, defaulting to an empty host and port 80 if not specified. 
""" peer = Peer( address=Address(host=params.url.host or "", port=params.url.port or 80), diff --git a/src/connect/protocol_grpc/marshaler.py b/src/connect/protocol_grpc/marshaler.py index 3eb9740..59e6354 100644 --- a/src/connect/protocol_grpc/marshaler.py +++ b/src/connect/protocol_grpc/marshaler.py @@ -1,4 +1,4 @@ -"""Marshaler for encoding messages into the gRPC wire format, including gRPC-Web trailer support.""" +"""gRPC-Web marshaler for serializing messages and trailers with optional compression.""" from connect.codec import Codec from connect.compression import Compression @@ -7,19 +7,19 @@ class GRPCMarshaler(EnvelopeWriter): - """GRPCMarshaler is responsible for marshaling messages into the gRPC wire format. + """GRPCMarshaler is responsible for serializing and deserializing messages and trailers according to the gRPC-Web protocol. - Args: - codec (Codec | None): The codec used for encoding/decoding messages. - compression (Compression | None): The compression algorithm to use, if any. - compress_min_bytes (int): Minimum message size in bytes before compression is applied. - send_max_bytes (int): Maximum allowed size of a message to send. + This class extends EnvelopeWriter to provide gRPC-Web specific marshaling logic, including support for optional compression, configurable minimum compression thresholds, and maximum send size enforcement. It also provides utilities for encoding HTTP trailer headers into the gRPC-Web trailer envelope format. - Methods: - marshal(messages: AsyncIterable[bytes]) -> AsyncIterator[bytes]: - Asynchronously marshals a stream of message bytes into the gRPC wire format. - Yields marshaled message bytes ready for transmission. + Attributes: + codec (Codec | None): The codec used for encoding and decoding messages. + compression (Compression | None): The compression algorithm used for message payloads. + compress_min_bytes (int): The minimum payload size (in bytes) before compression is applied. 
+ send_max_bytes (int): The maximum allowed size (in bytes) for a single outgoing message. + Methods: + marshal_web_trailers(trailers: Headers) -> bytes: + Serializes HTTP trailer headers into a gRPC-Web trailer envelope. """ def __init__( @@ -29,28 +29,27 @@ def __init__( compress_min_bytes: int, send_max_bytes: int, ) -> None: - """Initialize the protocol with the specified configuration. + """Initializes the object with the specified codec, compression settings, minimum bytes for compression, and maximum bytes to send. Args: - codec (Codec | None): The codec to use for encoding/decoding messages, or None for default. + codec (Codec | None): The codec to use for encoding/decoding, or None for default. compression (Compression | None): The compression algorithm to use, or None for no compression. compress_min_bytes (int): The minimum number of bytes before compression is applied. - send_max_bytes (int): The maximum number of bytes allowed to send in a single message. + send_max_bytes (int): The maximum number of bytes allowed to send. Returns: None - """ super().__init__(codec, compression, compress_min_bytes, send_max_bytes) async def marshal_web_trailers(self, trailers: Headers) -> bytes: - """Serialize HTTP trailer headers into a gRPC-Web trailer envelope. + """Serializes HTTP trailer headers into a gRPC-Web envelope. Args: trailers (Headers): A dictionary-like object containing HTTP trailer headers. Returns: - bytes: The serialized gRPC-Web trailer envelope containing the trailer headers. + bytes: The gRPC-Web envelope containing the serialized trailer headers. 
""" lines = [] diff --git a/src/connect/protocol_grpc/unmarshaler.py b/src/connect/protocol_grpc/unmarshaler.py index 222a02d..f92bf7b 100644 --- a/src/connect/protocol_grpc/unmarshaler.py +++ b/src/connect/protocol_grpc/unmarshaler.py @@ -1,4 +1,4 @@ -"""Module for gRPC message unmarshaling using EnvelopeReader and related utilities.""" +"""GRPCUnmarshaler provides async gRPC message unmarshaling with web trailer support.""" from collections.abc import AsyncIterable, AsyncIterator from copy import copy @@ -12,19 +12,25 @@ class GRPCUnmarshaler(EnvelopeReader): - """GRPCUnmarshaler is a specialized EnvelopeReader for handling gRPC message unmarshaling. + """GRPCUnmarshaler is a specialized EnvelopeReader for handling gRPC protocol messages. - Args: - codec (Codec | None): The codec used for decoding messages. - read_max_bytes (int): The maximum number of bytes to read from the stream. - stream (AsyncIterable[bytes] | None, optional): The asynchronous byte stream to read messages from. - compression (Compression | None, optional): Compression algorithm to use for decompressing messages. + With support for both standard and web environments, it provides asynchronous + unmarshaling of messages, extracting and storing HTTP/2 trailers when operating in + web mode. + + Attributes: + _web_trailers (Headers | None): Stores the trailers received in the last envelope, if any. Methods: - async unmarshal(message: Any) -> AsyncIterator[Any]: - Asynchronously unmarshals the given message, yielding each decoded object. - Iterates over the results of the internal _unmarshal method, yielding only the object part of each tuple. + __init__(web, codec, read_max_bytes, stream=None, compression=None): + Initializes the GRPCUnmarshaler with the specified parameters. + + async unmarshal(message): + Asynchronously unmarshals a given message, yielding each resulting object and + handling trailers in web mode. 
+ web_trailers: + Returns the trailers received in the last envelope, or None if no trailers were received. """ web: bool @@ -38,29 +44,37 @@ def __init__( stream: AsyncIterable[bytes] | None = None, compression: Compression | None = None, ) -> None: - """Initialize the protocol gRPC handler. + """Initializes the object with the given parameters. Args: - web (bool): Indicates if the connection is for a web environment. - codec (Codec | None): The codec to use for encoding/decoding messages. Can be None. - read_max_bytes (int): The maximum number of bytes to read from the stream. + web (bool): Indicates whether the web mode is enabled. + codec (Codec | None): The codec to use for decoding, or None. + read_max_bytes (int): The maximum number of bytes to read. stream (AsyncIterable[bytes] | None, optional): An asynchronous iterable stream of bytes. Defaults to None. - compression (Compression | None, optional): The compression method to use. Defaults to None. + compression (Compression | None, optional): The compression method to use, or None. Defaults to None. + Returns: + None """ super().__init__(codec, read_max_bytes, stream, compression) self.web = web self._web_trailers = None async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: - """Asynchronously unmarshals a given message and yields each resulting object. + """Asynchronously unmarshals a message and yields objects along with an end flag. + + Iterates over the result of the superclass's `unmarshal` method, processing each object and its corresponding end flag. + When the end flag is True, validates and parses the envelope's data as HTTP/2 trailers, storing them in the instance. + Raises a ConnectError if the envelope is empty or has invalid flags. Args: message (Any): The message to be unmarshaled. Yields: - Any: Each object obtained from unmarshaling the message. + tuple[Any, bool]: A tuple containing the unmarshaled object and a boolean indicating if it is the end of the stream. 
+ Raises: + ConnectError: If the envelope is empty or has invalid flags. """ async for obj, end in super().unmarshal(message): if end: @@ -96,10 +110,9 @@ async def unmarshal(self, message: Any) -> AsyncIterator[tuple[Any, bool]]: @property def web_trailers(self) -> Headers | None: - """Return the trailers received in the last envelope. + """Returns the HTTP trailers associated with the web response, if any. Returns: - Headers | None: The trailers received in the last envelope, or None if no trailers were received. - + Headers | None: The HTTP trailers as a Headers object if present, otherwise None. """ return self._web_trailers diff --git a/src/connect/request.py b/src/connect/request.py index 924bc1b..c193cb1 100644 --- a/src/connect/request.py +++ b/src/connect/request.py @@ -1,3 +1,3 @@ -"""Request module for handling requests.""" +"""HTTP request handling module using Starlette's Request class.""" from starlette.requests import Request as Request diff --git a/src/connect/response.py b/src/connect/response.py index 978b2e3..fc69fbc 100644 --- a/src/connect/response.py +++ b/src/connect/response.py @@ -1,4 +1,4 @@ -"""Response module for the connect package.""" +"""Streaming HTTP response implementation with async content delivery and trailer support.""" import typing from functools import partial @@ -17,19 +17,28 @@ class StreamingResponse(Response): - """A streaming HTTP response class that supports HTTP trailers. + """A streaming HTTP response class that supports asynchronous content delivery with optional trailers. - This class extends the standard response to allow sending HTTP trailers - at the end of a streamed response body, if supported by the ASGI server. + This class extends the base Response class to handle streaming content delivery, + allowing for efficient transmission of large or dynamically generated content. + It supports both synchronous and asynchronous iterables as content sources, + HTTP trailers, and background tasks. 
Attributes: - body_iterator (AsyncContentStream): An asynchronous iterator over the response body content. - status_code (int): HTTP status code for the response. - media_type (str | None): The media type of the response. - background (BackgroundTask | None): Optional background task to run after response is sent. - headers (Mapping[str, str]): HTTP headers for the response. - _trailers (Mapping[str, str] | None): HTTP trailers to send after the response body. + body_iterator (AsyncContentStream): An async iterator over the response body content. + Features: + - Automatic conversion of sync iterables to async using thread pools + - HTTP trailer support with proper header advertisement + - Client disconnect detection and handling + - Background task execution after response completion + - ASGI spec version compatibility (2.0+ with enhanced features for 2.4+) + + + Note: + For ASGI spec versions < 2.4, client disconnect detection is handled through + concurrent task monitoring. For versions >= 2.4, native disconnect detection + via OSError is used for better performance. """ body_iterator: AsyncContentStream @@ -44,20 +53,18 @@ def __init__( media_type: str | None = None, background: BackgroundTask | None = None, ) -> None: - """Initialize a response object with optional HTTP trailers. + """Initialize a streaming response. Args: - content (ContentStream): The response body content, which can be an async iterable or a regular iterable. - status_code (int, optional): HTTP status code for the response. Defaults to 200. - headers (typing.Mapping[str, str] | None, optional): HTTP headers to include in the response. Defaults to None. - trailers (typing.Mapping[str, str] | None, optional): HTTP trailers to include in the response. Defaults to None. - media_type (str | None, optional): The media type of the response. If None, uses the default media type. Defaults to None. - background (BackgroundTask | None, optional): A background task to run after the response is sent. 
Defaults to None. - - Notes: - - If `content` is not an async iterable, it will be wrapped to run in a thread pool. - - If trailers are provided, their names will be added to the "Trailer" header. - + content: The content to stream, either an async iterable or content that will be iterated in a thread pool. + status_code: HTTP status code for the response. Defaults to 200. + headers: Optional mapping of HTTP headers to include in the response. + trailers: Optional mapping of HTTP trailers to include in the response. + media_type: Optional media type for the response. If None, uses the existing media_type. + background: Optional background task to run after the response is sent. + + Returns: + None """ if isinstance(content, typing.AsyncIterable): self.body_iterator = content @@ -99,23 +106,27 @@ async def _stream_response(self, send: Send, trailers_supported: bool) -> None: }) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Handle the ASGI call interface for streaming HTTP responses with optional support for HTTP trailers. - - This method determines the ASGI spec version and whether HTTP response trailers are supported. - For ASGI spec version >= 2.4, it streams the response and handles client disconnects. - For earlier versions, it concurrently streams the response and listens for client disconnects, - cancelling the response stream if a disconnect is detected. + """ASGI application callable that handles HTTP response streaming with disconnect detection. - After sending the response, if a background task is provided, it is awaited. + This method implements the ASGI application interface, handling different ASGI spec versions + and managing client disconnections during response streaming. Args: - scope (Scope): The ASGI connection scope. - receive (Receive): Awaitable callable to receive ASGI messages. - send (Send): Awaitable callable to send ASGI messages. 
+ scope (Scope): ASGI scope dictionary containing request information and server capabilities + receive (Receive): ASGI receive callable for receiving messages from the client + send (Send): ASGI send callable for sending messages to the client + + Returns: + None Raises: - ClientDisconnect: If the client disconnects during response streaming. + ClientDisconnect: When an OSError occurs during response streaming (client disconnected) + Notes: + - For ASGI spec version 2.4+: Uses simple streaming with OSError handling + - For older versions: Uses task group with concurrent disconnect listening + - Supports HTTP trailers when available in the server extensions + - Executes background tasks after response completion if configured """ spec_version = tuple(map(int, scope.get("asgi", {}).get("spec_version", "2.0").split("."))) trailers_supported = "http.response.trailers" in scope.get("extensions", {}) diff --git a/src/connect/response_writer.py b/src/connect/response_writer.py index e942145..6c56691 100644 --- a/src/connect/response_writer.py +++ b/src/connect/response_writer.py @@ -1,4 +1,4 @@ -"""Module containing the ServerResponseWriter class.""" +"""Single-use asynchronous response writer for server communication with thread-safe queue mechanism.""" import asyncio @@ -6,30 +6,49 @@ class ServerResponseWriter: - """A writer class for handling server responses asynchronously using an asyncio.Queue. + """A single-use asynchronous response writer for server communication. + + This class provides a thread-safe mechanism for writing and receiving responses + using an internal asyncio queue with a maximum size of 1. The writer is designed + for single-use scenarios where only one response can be written and received + before the writer becomes closed. Attributes: - queue (asyncio.Queue[Response]): The queue used to store a single response. - is_closed (bool): Indicates whether the writer has been closed. 
+ queue (asyncio.Queue[Response]): Internal queue for storing responses with maxsize=1. + is_closed (bool): Flag indicating whether the writer has been closed. + Note: + The response writer automatically closes after receiving a response, + making it unsuitable for multiple read/write operations. """ queue: asyncio.Queue[Response] is_closed: bool = False def __init__(self) -> None: - """Initialize the instance with an asyncio queue of maximum size 1.""" + """Initialize the ResponseWriter with an async queue. + + Creates an asyncio Queue with a maximum size of 1 to handle response writing + in an asynchronous manner. The queue acts as a buffer for managing responses + that need to be written. + """ self.queue = asyncio.Queue(maxsize=1) async def write(self, response: Response) -> None: - """Asynchronously writes a response to the internal queue. + """Write a response to the queue for processing. + + This method adds a response to the internal queue for asynchronous processing. + The response writer must not be closed when calling this method. Args: - response (Response): The response object to be written. + response (Response): The response object to be written to the queue. Raises: - RuntimeError: If the response writer is already closed. + RuntimeError: If the response writer has been closed and cannot accept + new responses. + Returns: + None: This method does not return a value. """ if self.is_closed: raise RuntimeError("Cannot write to a closed response writer.") @@ -37,17 +56,22 @@ async def write(self, response: Response) -> None: await self.queue.put(response) async def receive(self) -> Response: - """Asynchronously retrieves a response from the internal queue. + """Asynchronously receive a response from the response writer's queue. - Raises: - RuntimeError: If the response writer is already closed. + This method retrieves the next response from the internal queue and marks + the response writer as closed after receiving the response. 
Returns: - Response: The next response item from the queue. + Response: The response object retrieved from the queue. - Side Effects: - Marks the response writer as closed after receiving a response. + Raises: + RuntimeError: If the response writer is already closed when attempting + to receive a response. + Note: + This method can only be called once per response writer instance. + After calling this method, the response writer will be marked as closed + and subsequent calls will raise a RuntimeError. """ if self.is_closed: raise RuntimeError("Cannot receive from a closed response writer.") diff --git a/src/connect/utils.py b/src/connect/utils.py index a0f86e8..829c8d4 100644 --- a/src/connect/utils.py +++ b/src/connect/utils.py @@ -1,4 +1,4 @@ -"""Provides utility functions for asynchronous programming.""" +"""Utility functions for async operations, HTTP exception handling, and callable attribute inspection.""" import asyncio import contextlib @@ -26,19 +26,17 @@ def is_async_callable(obj: typing.Any) -> typing.TypeGuard[AwaitableCallable[typ def is_async_callable(obj: typing.Any) -> typing.Any: - """Check if the given object is an asynchronous callable. + """Check if an object is an async callable (coroutine function or callable with async __call__). - This function unwraps functools.partial objects to check if the underlying - function is an asynchronous coroutine function. It returns True if the object - is an async coroutine function or if it is a callable object whose __call__ method - is an async coroutine function. + This function handles partial functions by unwrapping them to check the underlying + function. It returns True if the object is a coroutine function or if it's a + callable object with an async __call__ method. Args: - obj (typing.Any): The object to check. + obj (typing.Any): The object to check for async callability. Returns: - bool: True if the object is an asynchronous callable, False otherwise. 
- + typing.Any: True if the object is async callable, False otherwise. """ while isinstance(obj, functools.partial): obj = obj.func @@ -47,40 +45,38 @@ def is_async_callable(obj: typing.Any) -> typing.Any: async def run_in_threadpool[T, **P](func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: - """Run a function in a thread pool and return the result. + """Execute a synchronous function in a thread pool asynchronously. - This function is useful for running synchronous code in an asynchronous context - by offloading the execution to a thread pool. + This function takes a synchronous callable and runs it in a separate thread + using anyio's thread pool, allowing async code to call blocking functions + without blocking the event loop. Args: - func (typing.Callable[P, T]): The function to run in the thread pool. - *args (P.args): Positional arguments to pass to the function. - **kwargs (P.kwargs): Keyword arguments to pass to the function. + func: A callable function to execute in the thread pool + *args: Positional arguments to pass to the function + **kwargs: Keyword arguments to pass to the function Returns: - T: The result of the function execution. - - Raises: - Exception: Any exception raised by the function will be propagated. - - Example: - result = await run_in_threadpool(some_sync_function, arg1, arg2, kwarg1=value1) - + The return value of the executed function """ func = functools.partial(func, *args, **kwargs) return await anyio.to_thread.run_sync(func) def get_callable_attribute(obj: object, attr: str) -> typing.Callable[..., typing.Any] | None: - """Retrieve a callable attribute from an object if it exists and is callable. + """Get a callable attribute from an object. + + This function attempts to retrieve an attribute from an object and returns it + only if the attribute exists and is callable. If the attribute doesn't exist + or is not callable, returns None. Args: obj (object): The object from which to retrieve the attribute. 
attr (str): The name of the attribute to retrieve. Returns: - typing.Callable[..., typing.Any] | None: The callable attribute if it exists and is callable, otherwise None. - + typing.Callable[..., typing.Any] | None: The callable attribute if it exists + and is callable, otherwise None. """ try: attr_value = getattr(obj, attr) @@ -101,7 +97,6 @@ def get_acallable_attribute(obj: object, attr: str) -> typing.Callable[..., typi Returns: typing.Callable[..., typing.Awaitable[typing.Any]] | None: The attribute if it is callable and asynchronous, otherwise None. - """ try: attr_value = getattr(obj, attr) @@ -113,14 +108,16 @@ def get_acallable_attribute(obj: object, attr: str) -> typing.Callable[..., typi async def aiterate[T](iterable: typing.Iterable[T]) -> typing.AsyncIterator[T]: - """Turn a plain iterable into an async iterator. + """Convert a regular iterable to an async iterator. + + This function takes a synchronous iterable and yields each item asynchronously, + allowing it to be used in async contexts with `async for` loops. Args: - iterable (typing.Iterable[T]): The iterable to convert. + iterable: A synchronous iterable of type T. Yields: - typing.AsyncIterator[T]: An async iterator over the elements of the input iterable. - + T: Each item from the input iterable, yielded asynchronously. """ for i in iterable: yield i @@ -150,21 +147,19 @@ def _load_httpcore_exceptions() -> dict[type[Exception], Code]: @contextlib.contextmanager def map_httpcore_exceptions() -> Iterator[None]: - """Map exceptions raised by the HTTP core to custom exceptions. + """Context manager that maps httpcore exceptions to ConnectError exceptions. - This function uses a global exception map `HTTPCORE_EXC_MAP` to translate exceptions - raised within its context. If the map is empty, it loads the exceptions using the - `_load_httpcore_exceptions` function. 
When an exception is caught, it checks if the - exception matches any in the map and raises a `ConnectError` with the corresponding - error code. If no match is found, the original exception is re-raised. + This function lazily loads a mapping of httpcore exceptions to Connect error codes + and converts any httpcore exceptions that occur within its context to ConnectError + instances with appropriate error codes. Yields: - None: This function is a generator used as a context manager. + None: This is a context manager that yields control to the calling code. Raises: - ConnectError: If the caught exception matches an entry in `HTTPCORE_EXC_MAP`. - Exception: If no match is found in `HTTPCORE_EXC_MAP`, the original exception is re-raised. - + ConnectError: If an httpcore exception occurs that has a mapping defined, + it will be converted to a ConnectError with the appropriate code. + Exception: Any other exceptions are re-raised unchanged. """ global HTTPCORE_EXC_MAP if len(HTTPCORE_EXC_MAP) == 0: diff --git a/tests/conftest.py b/tests/conftest.py index 1cd19dc..577d824 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ from connect.envelope import Envelope, EnvelopeFlags from connect.middleware import ConnectMiddleware -from connect.options import ConnectOptions +from connect.options import HandlerOptions from tests.testdata.ping.v1.ping_pb2 import PingResponse from tests.testdata.ping.v1.v1connect.ping_connect import ( PingService_service_descriptor, @@ -240,14 +240,14 @@ class AsyncClient: service: PingServiceHandler client: AsyncTestClient - def __init__(self, service: PingServiceHandler, options: ConnectOptions | None = None) -> None: + def __init__(self, service: PingServiceHandler, options: HandlerOptions | None = None) -> None: self.service = service self.options = options async def __aenter__(self) -> AsyncTestClient: assert isinstance(self.service, PingServiceHandler) - options = self.options or ConnectOptions() + options = self.options 
or HandlerOptions() options.descriptor = PingService_service_descriptor middleware = [ diff --git a/tests/test_streaming_connect_server.py b/tests/test_streaming_connect_server.py index 0d42d76..5bf1909 100644 --- a/tests/test_streaming_connect_server.py +++ b/tests/test_streaming_connect_server.py @@ -12,7 +12,7 @@ from connect.handler_context import HandlerContext from connect.handler_interceptor import HandlerInterceptor, StreamFunc from connect.headers import Headers -from connect.options import ConnectOptions +from connect.options import HandlerOptions from tests.conftest import AsyncClient from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse from tests.testdata.ping.v1.v1connect.ping_connect import ( @@ -381,7 +381,7 @@ async def _wrapped(request: StreamRequest[Any], context: HandlerContext) -> Stre return _wrapped async with AsyncClient( - PingService(), ConnectOptions(interceptors=[FileInterceptor1(), FileInterceptor2()]) + PingService(), HandlerOptions(interceptors=[FileInterceptor1(), FileInterceptor2()]) ) as client: response = await client.post( path="/tests.testdata.ping.v1.PingService/PingServerStream", @@ -517,7 +517,7 @@ async def _wrapped(request: StreamRequest[Any], context: HandlerContext) -> Stre return _wrapped async with AsyncClient( - PingService(), ConnectOptions(interceptors=[FileInterceptor1(), FileInterceptor2()]) + PingService(), HandlerOptions(interceptors=[FileInterceptor1(), FileInterceptor2()]) ) as client: response = await client.post( path="/tests.testdata.ping.v1.PingService/PingClientStream", diff --git a/tests/test_unary_connect_server.py b/tests/test_unary_connect_server.py index 8ba7ec9..05419ec 100644 --- a/tests/test_unary_connect_server.py +++ b/tests/test_unary_connect_server.py @@ -10,7 +10,7 @@ from connect.connect import UnaryRequest, UnaryResponse from connect.handler_context import HandlerContext from connect.idempotency_level import IdempotencyLevel -from connect.options import ConnectOptions +from 
connect.options import HandlerOptions from tests.conftest import AsyncClient from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse from tests.testdata.ping.v1.v1connect.ping_connect import PingServiceHandler @@ -140,7 +140,7 @@ async def Ping( async with AsyncClient( PingService(), - options=ConnectOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS), + options=HandlerOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS), ) as client: encoded_message = json.dumps({"name": "test"}).encode() response = await client.get( @@ -171,7 +171,7 @@ async def Ping( async with AsyncClient( PingService(), - options=ConnectOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS), + options=HandlerOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS), ) as client: encoded_message = base64.b64encode(json.dumps({"name": "test"}).encode()).decode() response = await client.get( @@ -203,7 +203,7 @@ async def Ping( async with AsyncClient( PingService(), - options=ConnectOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS), + options=HandlerOptions(idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS), ) as client: compressor = zlib.compressobj(9, zlib.DEFLATED, -zlib.MAX_WBITS) content = PingRequest(name="test").SerializeToString() diff --git a/tests/testdata/ping/v1/v1connect/ping_connect.py b/tests/testdata/ping/v1/v1connect/ping_connect.py index 545bfa4..39e3ae3 100644 --- a/tests/testdata/ping/v1/v1connect/ping_connect.py +++ b/tests/testdata/ping/v1/v1connect/ping_connect.py @@ -8,7 +8,7 @@ from connect.connect import StreamRequest, StreamResponse, UnaryRequest, UnaryResponse from connect.handler import ClientStreamHandler, Handler, ServerStreamHandler, UnaryHandler from connect.handler_context import HandlerContext -from connect.options import ConnectOptions +from connect.options import HandlerOptions from tests.testdata.ping.v1 import ping_pb2 from tests.testdata.ping.v1.ping_pb2 import PingRequest, PingResponse @@ -42,7 +42,7 @@ async 
def PingServerStream(self, request: StreamRequest[PingRequest], context: H async def PingClientStream(self, request: StreamRequest[PingRequest], context: HandlerContext) -> StreamResponse[PingResponse]: ... -def create_PingService_handlers(service: PingServiceHandler, options: ConnectOptions | None = None) -> list[Handler]: +def create_PingService_handlers(service: PingServiceHandler, options: HandlerOptions | None = None) -> list[Handler]: handlers: list[Handler] = [ UnaryHandler( procedure=PingServiceProcedures.Ping.value,