Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions src/realtime/src/realtime/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
from ..exceptions import NotConnectedError
from ..message import Message, ServerMessageAdapter
from ..transformers import http_endpoint_url
from ..serializer import Serializer
from ..types import (
DEFAULT_HEARTBEAT_INTERVAL,
DEFAULT_TIMEOUT,
DEFAULT_VSN,
PHOENIX_CHANNEL,
VSN,
VSN_1_0_0,
VSN_2_0_0,
ChannelEvents,
)
from ..utils import is_ws_url
Expand All @@ -38,6 +42,7 @@ def wrapper(*args, **kwargs):


class AsyncRealtimeClient:
serializer: Optional[Serializer]
def __init__(
self,
url: str,
Expand All @@ -48,6 +53,8 @@ def __init__(
max_retries: int = 5,
initial_backoff: float = 1.0,
timeout: int = DEFAULT_TIMEOUT,
vsn: str = DEFAULT_VSN,
allowed_metadata_keys: Optional[List[str]] = None,
) -> None:
"""
Initialize a RealtimeClient instance for WebSocket communication.
Expand All @@ -61,6 +68,9 @@ def __init__(
:param max_retries: Maximum number of reconnection attempts. Defaults to 5.
:param initial_backoff: Initial backoff time (in seconds) for reconnection attempts. Defaults to 1.0.
:param timeout: Connection timeout in seconds. Defaults to DEFAULT_TIMEOUT.
:param vsn: Serializer version to use. Defaults to "1.0.0". Use "2.0.0" for binary support.
:param allowed_metadata_keys: List of metadata keys allowed in user broadcast push messages.
Only used with VSN 2.0.0. Defaults to None.
"""
if not is_ws_url(url):
raise ValueError("url must be a valid WebSocket URL or HTTP URL string")
Expand All @@ -80,8 +90,17 @@ def __init__(
self.max_retries = max_retries
self.initial_backoff = initial_backoff
self.timeout = timeout
self.vsn = vsn
self._listen_task: Optional[asyncio.Task] = None
self._heartbeat_task: Optional[asyncio.Task] = None

# Initialize serializer based on version
if vsn == VSN_2_0_0:
self.serializer = Serializer(allowed_metadata_keys=allowed_metadata_keys)
elif vsn == VSN_1_0_0:
self.serializer = None # V1 uses JSON directly
else:
raise ValueError(f"Unsupported serializer version: {vsn}")

@property
def is_connected(self) -> bool:
Expand All @@ -101,7 +120,19 @@ async def _listen(self) -> None:
logger.info(f"receive: {msg!r}")

try:
message = ServerMessageAdapter.validate_json(msg)
# Handle binary messages for V2
if isinstance(msg, bytes) and self.vsn == VSN_2_0_0 and self.serializer:
decoded = self.serializer.decode(msg)
# Convert decoded message to JSON string for validation
msg_json = json.dumps({
"event": decoded.get("event"),
"topic": decoded.get("topic"),
"payload": decoded.get("payload"),
"ref": decoded.get("ref"),
})
message = ServerMessageAdapter.validate_json(msg_json)
else:
message = ServerMessageAdapter.validate_json(msg)
except ValidationError as e:
logger.error(f"Unrecognized message format {msg!r}\n{e}")
continue
Expand Down Expand Up @@ -343,15 +374,36 @@ async def send(self, message: Union[Message, Dict[str, Any]]) -> None:
"Warning: calling AsyncRealtimeClient.send with a dictionary is deprecated. Please call it with a Message object instead. This will be a hard error in the future."
)
msg = Message(**message)
message_str = msg.model_dump_json()
logger.info(f"send: {message_str}")

# Encode message based on serializer version
message_data: Union[str, bytes]
if self.vsn == VSN_2_0_0 and self.serializer:
# Convert Message to dict for serializer
msg_dict = {
"join_ref": msg.join_ref,
"ref": msg.ref,
"topic": msg.topic,
"event": msg.event,
"payload": msg.payload,
}
encoded = self.serializer.encode(msg_dict)
if isinstance(encoded, bytes):
message_data = encoded
logger.info(f"send (binary): {len(message_data)} bytes")
else:
message_data = encoded
logger.info(f"send: {message_data}")
else:
# V1: JSON encoding
message_data = msg.model_dump_json()
logger.info(f"send: {message_data}")

async def send_message():
if not self._ws_connection:
raise NotConnectedError("_send")

try:
await self._ws_connection.send(message_str)
await self._ws_connection.send(message_data)
except websockets.exceptions.ConnectionClosedError as e:
await self._on_connect_error(e)
except websockets.exceptions.ConnectionClosedOK:
Expand All @@ -374,7 +426,7 @@ async def _leave_open_topic(self, topic: str):

def endpoint_url(self) -> str:
parsed_url = urlparse(self.url)
query = urlencode({**self.params, "vsn": VSN}, doseq=True)
query = urlencode({**self.params, "vsn": self.vsn}, doseq=True)
return urlunparse(
(
parsed_url.scheme,
Expand Down
Loading