diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index e4441486..0b360ffc 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -364,7 +364,7 @@ def frame_received(self, data: bytes) -> None: try: self._protocol(data) except Exception: - LOGGER.warning("Failed to parse frame, ignoring") + LOGGER.warning("Failed to parse frame. This is a bug!", exc_info=True) async def get_board_info( self, diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 30006dc4..0ae93b8e 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -5,13 +5,16 @@ from asyncio import timeout as asyncio_timeout import binascii from collections.abc import AsyncGenerator, Callable, Iterable +from dataclasses import dataclass import functools import logging import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Final from zigpy.datastructures import PriorityDynamicBoundedSemaphore +from zigpy.event.event_base import EventBase import zigpy.state +import zigpy.types from bellows.config import CONF_EZSP_POLICIES from bellows.exception import InvalidCommandError @@ -27,13 +30,62 @@ MAX_COMMAND_CONCURRENCY = 1 -class ProtocolHandler(abc.ABC): +@dataclass(frozen=True, kw_only=True) +class MessageSentEvent: + event_type: Final[str] = "message_sent" + + status: t.sl_Status + message_type: t.EmberOutgoingMessageType + destination: t.uint16_t + aps_frame: t.EmberApsFrame + message_tag: t.uint8_t + message_contents: t.LVBytes + + +@dataclass(frozen=True, kw_only=True) +class PacketReceivedEvent: + event_type: Final[str] = "packet_received" + + packet: zigpy.types.ZigbeePacket + + +@dataclass(frozen=True, kw_only=True) +class TrustCenterJoinEvent: + event_type: Final[str] = "trust_center_join" + + nwk: t.EmberNodeId + ieee: t.EUI64 + device_update_status: t.EmberDeviceUpdate + decision: t.EmberJoinDecision + parent_nwk: t.EmberNodeId + + +@dataclass(frozen=True, kw_only=True) +class RouteRecordEvent: + event_type: Final[str] = "route_record" + + nwk: t.EmberNodeId + ieee: t.EUI64 + lqi: t.uint8_t + rssi: t.int8s + relays: t.LVList[t.EmberNodeId] + + +@dataclass(frozen=True, kw_only=True) +class IdConflictEvent: + event_type: Final[str] = "id_conflict" + + nwk: t.EmberNodeId + + +class ProtocolHandler(EventBase, abc.ABC): """EZSP protocol specific handler.""" COMMANDS = {} VERSION = None def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: + super().__init__() self._handle_callback = cb_handler self._awaiting = {} self._gw = gateway @@ -179,52 +231,6 @@ def __call__(self, data: bytes) -> None: if data: LOGGER.debug("Frame contains trailing data: %s", data) - if ( - frame_name == "incomingMessageHandler" - and result[1].options & t.EmberApsOption.APS_OPTION_FRAGMENT - ): - # Extract received APS frame and sender - aps_frame = result[1] - sender = result[4] - - # The fragment count and index are encoded in the groupId field - fragment_count = (aps_frame.groupId >> 8) & 0xFF - fragment_index = aps_frame.groupId & 0xFF - - ( - complete, - reassembled, - frag_count, - frag_index, - ) = self._fragment_manager.handle_incoming_fragment( - sender_nwk=sender, - aps_sequence=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - fragment_count=fragment_count, - fragment_index=fragment_index, - payload=result[7], - ) - - ack_task = asyncio.create_task( - self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) - ) # APS Ack - - self._fragment_ack_tasks.add(ack_task) - ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) - - if not complete: - # Do not pass partial data up the stack - LOGGER.debug("Fragment reassembly not complete. waiting for more data.") - return - - # Replace partial data with fully reassembled data - result[7] = reassembled - - LOGGER.debug( - "Reassembled fragmented message. Proceeding with normal handling." - ) - if sequence in self._awaiting: expected_id, schema, future = self._awaiting.pop(sequence) try: @@ -246,8 +252,20 @@ def __call__(self, data: bytes) -> None: sequence, self.COMMANDS_BY_ID.get(expected_id, [expected_id])[0], ) - else: - self._handle_callback(frame_name, result) + + return + + self.handle_parsed_callback(frame_name, result) + + # Legacy callback system for CLI tools + self._handle_callback(frame_name, result) + + def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: + """Dispatch a callback frame to the appropriate handler method.""" + handler = getattr(self, f"_handle_{frame_name}", None) + + if handler is not None: + handler(*args) async def _send_fragment_ack( self, @@ -275,6 +293,135 @@ async def _send_fragment_ack( status = await self.sendReply(sender, ackFrame, b"") return status[0] + def _handle_incoming_message( + self, + message_type: t.EmberIncomingMessageType, + aps_frame: t.EmberApsFrame, + sender: zigpy.types.NWK, + eui64: zigpy.types.EUI64 | None, + binding_index: t.uint8_t, + address_index: t.uint8_t, + lqi: t.uint8_t, + rssi: t.int8s, + timestamp: t.uint32_t | None, + message: t.LVBytes, + ) -> None: + """Handle incomingMessageHandler callback and maybe return a packet.""" + + if aps_frame.options & t.EmberApsOption.APS_OPTION_FRAGMENT: + fragment_count = (aps_frame.groupId >> 8) & 0xFF + fragment_index = aps_frame.groupId & 0xFF + + ( + complete, + reassembled, + frag_count, + frag_index, + ) = self._fragment_manager.handle_incoming_fragment( + sender_nwk=sender, + aps_sequence=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + fragment_count=fragment_count, + fragment_index=fragment_index, + payload=message, + ) + + ack_task = asyncio.create_task( + self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) + ) + self._fragment_ack_tasks.add(ack_task) + ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) + + if not complete: + LOGGER.debug("Fragment reassembly not complete, waiting for more data") + return + + LOGGER.debug("Reassembled fragmented message, proceeding with handling") + message = reassembled + + # Determine destination address based on message type + if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, + address=aps_frame.groupId, + ) + elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: + dst = None # We don't know the current NWK here + else: + LOGGER.debug("Ignoring message type: %r", message_type) + return + + self.emit( + PacketReceivedEvent.event_type, + PacketReceivedEvent( + packet=zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(sender), + ), + src_ep=aps_frame.sourceEndpoint, + dst=dst, + dst_ep=aps_frame.destinationEndpoint, + tsn=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + data=zigpy.types.SerializableBytes(message), + lqi=lqi, + rssi=rssi, + ) + ), + ) + + def _handle_trustCenterJoinHandler( + self, + nwk: t.EmberNodeId, + ieee: t.EUI64, + device_update_status: t.EmberDeviceUpdate, + decision: t.EmberJoinDecision, + parent_nwk: t.EmberNodeId, + ) -> None: + self.emit( + TrustCenterJoinEvent.event_type, + TrustCenterJoinEvent( + nwk=nwk, + ieee=ieee, + device_update_status=device_update_status, + decision=decision, + parent_nwk=parent_nwk, + ), + ) + + def _handle_incomingRouteRecordHandler( + self, + nwk: t.EmberNodeId, + ieee: t.EUI64, + lqi: t.uint8_t, + rssi: t.int8s, + relays: t.LVList[t.EmberNodeId], + ) -> None: + self.emit( + RouteRecordEvent.event_type, + RouteRecordEvent( + nwk=nwk, + ieee=ieee, + lqi=lqi, + rssi=rssi, + relays=relays, + ), + ) + + def _handle_idConflictHandler(self, nwk: t.EmberNodeId) -> None: + self.emit( + IdConflictEvent.event_type, + IdConflictEvent(nwk=nwk), + ) + def __getattr__(self, name: str) -> Callable: if name not in self.COMMANDS: raise AttributeError(f"{name} not found in COMMANDS") diff --git a/bellows/ezsp/v14/__init__.py b/bellows/ezsp/v14/__init__.py index 16dbeec7..dfba42e0 100644 --- a/bellows/ezsp/v14/__init__.py +++ b/bellows/ezsp/v14/__init__.py @@ -2,17 +2,22 @@ from __future__ import annotations from collections.abc import AsyncGenerator +import logging import voluptuous as vol from zigpy.exceptions import NetworkNotFormed import zigpy.state +import zigpy.types import bellows.config import bellows.types as t from . import commands, config +from ..protocol import MessageSentEvent from ..v13 import EZSPv13 +LOGGER = logging.getLogger(__name__) + class EZSPv14(EZSPv13): """EZSP Version 14 Protocol version handler.""" @@ -144,3 +149,50 @@ async def send_broadcast( ) return status, sequence + + def _handle_incomingMessageHandler( + self, + message_type: t.EmberIncomingMessageType, + aps_frame: t.EmberApsFrame, + sender: t.EmberNodeId, + eui64: t.EUI64, + binding_index: t.uint8_t, + address_index: t.uint8_t, + lqi: t.uint8_t, + rssi: t.int8s, + timestamp: t.uint32_t, + message: t.LVBytes, + ) -> None: + self._handle_incoming_message( + message_type=message_type, + aps_frame=aps_frame, + sender=sender, + eui64=None, + binding_index=binding_index, + address_index=address_index, + lqi=lqi, + rssi=rssi, + timestamp=None, + message=message, + ) + + def _handle_messageSentHandler( + self, + status: t.sl_Status, + message_type: t.EmberOutgoingMessageType, + destination: t.EmberNodeId, + aps_frame: t.EmberApsFrame, + message_tag: t.uint8_t, + message: t.LVBytes, + ) -> None: + self.emit( + MessageSentEvent.event_type, + MessageSentEvent( + status=status, + message_type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + message_contents=message, + ), + ) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index 3b454ecd..e2946818 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -7,6 +7,7 @@ import voluptuous as vol import zigpy.state +import zigpy.types import bellows.config import bellows.types as t @@ -14,6 +15,7 @@ from . import commands, config from .. import protocol +from ..protocol import MessageSentEvent LOGGER = logging.getLogger(__name__) @@ -235,3 +237,48 @@ async def set_extended_timeout( newId=nwk, newExtendedTimeout=extended_timeout, ) + + def _handle_incomingMessageHandler( + self, + message_type: t.EmberIncomingMessageType, + aps_frame: t.EmberApsFrame, + lqi: t.uint8_t, + rssi: t.int8s, + sender: t.EmberNodeId, + binding_index: t.uint8_t, + address_index: t.uint8_t, + message: t.LVBytes, + ) -> None: + self._handle_incoming_message( + message_type=message_type, + aps_frame=aps_frame, + sender=sender, + eui64=None, + binding_index=binding_index, + address_index=address_index, + lqi=lqi, + rssi=rssi, + timestamp=None, + message=message, + ) + + def _handle_messageSentHandler( + self, + message_type: t.EmberOutgoingMessageType, + destination: t.EmberNodeId, + aps_frame: t.EmberApsFrame, + message_tag: t.uint8_t, + status: t.EmberStatus, + message: t.LVBytes, + ) -> None: + self.emit( + MessageSentEvent.event_type, + MessageSentEvent( + status=t.sl_Status.from_ember_status(status), + message_type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + message_contents=message, + ), + ) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 1bd0271c..51db4585 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -2,7 +2,7 @@ import asyncio from asyncio import timeout as asyncio_timeout -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from datetime import UTC, datetime import importlib.metadata import logging @@ -39,6 +39,13 @@ StackAlreadyRunning, ) import bellows.ezsp +from bellows.ezsp.protocol import ( + IdConflictEvent, + MessageSentEvent, + PacketReceivedEvent, + RouteRecordEvent, + TrustCenterJoinEvent, +) from bellows.ezsp.xncp import FirmwareFeatures import bellows.multicast import bellows.types as t @@ -97,6 +104,7 @@ def __init__(self, config: dict) -> None: self._multicast = None self._mfg_id_task: asyncio.Task | None = None self._pending_requests = {} + self._protocol_on_remove_callbacks: list[Callable[[], None]] = [] self._watchdog_failures = 0 self._watchdog_feed_counter = 0 @@ -240,7 +248,8 @@ async def start_network(self): for cnt_group in self.state.counters: cnt_group.reset() - ezsp.add_callback(self.ezsp_callback_handler) + self._subscribe_to_protocol_events() + self.controller_event.set() group_membership = {} @@ -602,14 +611,52 @@ async def reset_network_info(self): else: await self._ezsp.leaveNetwork() + def _unsubscribe_from_protocol_events(self) -> None: + """Unsubscribe from protocol events.""" + for callback in self._protocol_on_remove_callbacks: + callback() + + self._protocol_on_remove_callbacks.clear() + async def _reset(self): + self._unsubscribe_from_protocol_events() self._ezsp.stop_ezsp() await self._ezsp.startup_reset() await self._ezsp.write_config(self.config[CONF_EZSP_CONFIG]) + self._subscribe_to_protocol_events() + + def _subscribe_to_protocol_events(self) -> None: + """Subscribe to protocol-level events.""" + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event( + PacketReceivedEvent.event_type, self._on_packet_received + ) + ) + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event( + MessageSentEvent.event_type, self._on_message_sent + ) + ) + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event( + TrustCenterJoinEvent.event_type, self._on_trust_center_join + ) + ) + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event( + RouteRecordEvent.event_type, self._on_route_record + ) + ) + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event( + IdConflictEvent.event_type, self._on_id_conflict + ) + ) async def disconnect(self): # TODO: how do you shut down the stack? self.controller_event.clear() + self._unsubscribe_from_protocol_events() if self._ezsp is not None: await self._ezsp.disconnect() self._ezsp = None @@ -619,172 +666,60 @@ async def force_remove(self, dev): # of the device itself. await self._ezsp.removeDevice(dev.nwk, dev.ieee, dev.ieee) - def ezsp_callback_handler(self, frame_name, args): - LOGGER.debug("Received %s frame with %s", frame_name, args) - if frame_name == "incomingMessageHandler": - if self._ezsp.ezsp_version >= 14: - ( - message_type, - aps_frame, - nwk, - _eui64, - binding_index, - address_index, - lqi, - rssi, - _timestamp, - message, - ) = args - else: - ( - message_type, - aps_frame, - lqi, - rssi, - nwk, - binding_index, - address_index, - message, - ) = args - - self._handle_frame( - message_type=message_type, - aps_frame=aps_frame, - lqi=lqi, - rssi=rssi, - sender=nwk, - binding_index=binding_index, - address_index=address_index, - message=message, - ) - elif frame_name == "messageSentHandler": - if self._ezsp.ezsp_version >= 14: - ( - status, - message_type, - destination, - aps_frame, - message_tag, - message, - ) = args - else: - ( - message_type, - destination, - aps_frame, - message_tag, - status, - message, - ) = args - status = t.sl_Status.from_ember_status(status) - - self._handle_frame_sent( - message_type=message_type, - destination=destination, - aps_frame=aps_frame, - message_tag=message_tag, - status=status, - message=message, - ) - elif frame_name == "trustCenterJoinHandler": - self._handle_tc_join_handler(*args) - elif frame_name == "incomingRouteRecordHandler": - self.handle_route_record(*args) - elif frame_name == "incomingRouteErrorHandler": - status, nwk = args - status = t.sl_Status.from_ember_status(status) - self.handle_route_error(status, nwk) - elif frame_name == "idConflictHandler": - self._handle_id_conflict(*args) - - def _handle_frame( - self, - message_type: t.EmberIncomingMessageType, - aps_frame: t.EmberApsFrame, - lqi: t.uint8_t, - rssi: t.int8s, - sender: t.EmberNodeId, - binding_index: t.uint8_t, - address_index: t.uint8_t, - message: bytes, - ) -> None: - if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Broadcast, - address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + def _on_packet_received(self, message: PacketReceivedEvent) -> None: + """Handle packet_received event from protocol handler.""" + packet = message.packet + + # The protocol handler doesn't know our current NWK address + if packet.dst is None: + packet = packet.replace( + dst=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=self.state.node_info.nwk, + ) ) + + if packet.dst.addr_mode == zigpy.types.AddrMode.NWK: + self.state.counters[COUNTERS_CTRL][COUNTER_RX_UNICAST].increment() + elif packet.dst.addr_mode == zigpy.types.AddrMode.Broadcast: self.state.counters[COUNTERS_CTRL][COUNTER_RX_BCAST].increment() - elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Group, address=aps_frame.groupId - ) + elif packet.dst.addr_mode == zigpy.types.AddrMode.Group: self.state.counters[COUNTERS_CTRL][COUNTER_RX_MCAST].increment() - elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, address=self.state.node_info.nwk - ) - self.state.counters[COUNTERS_CTRL][COUNTER_RX_UNICAST].increment() - else: - LOGGER.debug("Ignoring message type: %r", message_type) - return - self.packet_received( - zigpy.types.ZigbeePacket( - src=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=sender, - ), - src_ep=aps_frame.sourceEndpoint, - dst=dst, - dst_ep=aps_frame.destinationEndpoint, - tsn=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - data=zigpy.types.SerializableBytes(message), - lqi=lqi, - rssi=rssi, - ) - ) + self.packet_received(packet) - def _handle_frame_sent( - self, - message_type: t.EmberIncomingMessageType, - destination: t.EmberNodeId, - aps_frame: t.EmberApsFrame, - message_tag: int, - status: t.sl_Status, - message: bytes, - ): - if status == t.sl_Status.OK: + def _on_message_sent(self, event: MessageSentEvent) -> None: + """Handle message_sent event from protocol handler.""" + if event.status == t.sl_Status.OK: msg = "success" else: msg = "failure" - if message_type in ( + if event.message_type in ( t.EmberOutgoingMessageType.OUTGOING_BROADCAST, t.EmberOutgoingMessageType.OUTGOING_BROADCAST_WITH_ALIAS, ): cnt_name = f"broadcast_tx_{msg}" - elif message_type in ( + elif event.message_type in ( t.EmberOutgoingMessageType.OUTGOING_MULTICAST, t.EmberOutgoingMessageType.OUTGOING_MULTICAST_WITH_ALIAS, ): cnt_name = f"multicast_tx_{msg}" - elif message_type in ( + elif event.message_type in ( t.EmberOutgoingMessageType.OUTGOING_DIRECT, t.EmberOutgoingMessageType.OUTGOING_VIA_ADDRESS_TABLE, ): cnt_name = f"unicast_tx_{msg}" - elif message_type == t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING: + elif event.message_type == t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING: cnt_name = f"via_binding_tx_{msg}" else: cnt_name = f"unknown_msg_type_{msg}" - pending_tag = (destination, message_tag) + pending_tag = (event.destination, event.message_tag) try: future = self._pending_requests[pending_tag] - future.set_result((status, f"message send {msg}")) + future.set_result((event.status, f"message send {msg}")) self.state.counters[COUNTERS_CTRL][cnt_name].increment() except KeyError: self.state.counters[COUNTERS_CTRL][f"{cnt_name}_unexpected"].increment() @@ -800,44 +735,31 @@ def _handle_frame_sent( exc, ) - async def _handle_no_such_device(self, sender: int) -> None: - """Try to match unknown device by its EUI64 address.""" - status, ieee = await self._ezsp.lookupEui64ByNodeId(nodeId=sender) - status = t.sl_Status.from_ember_status(status) - - if status == t.sl_Status.OK: - LOGGER.debug("Found %s ieee for %s sender", ieee, sender) - self.handle_join(sender, ieee, 0) - return - LOGGER.debug("Couldn't look up ieee for %s", sender) - - def _handle_tc_join_handler( - self, - nwk: t.EmberNodeId, - ieee: t.EUI64, - device_update_status: t.EmberDeviceUpdate, - decision: t.EmberJoinDecision, - parent_nwk: t.EmberNodeId, - ) -> None: - """Trust Center Join handler.""" - if device_update_status == t.EmberDeviceUpdate.DEVICE_LEFT: - self.handle_leave(nwk, ieee) + def _on_trust_center_join(self, event: TrustCenterJoinEvent) -> None: + """Handle trust_center_join event from protocol handler.""" + if event.device_update_status == t.EmberDeviceUpdate.DEVICE_LEFT: + self.handle_leave(event.nwk, event.ieee) return - if device_update_status == t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN: - self.create_task(self.cleanup_tc_link_key(ieee), "cleanup_tc_link_key") + if ( + event.device_update_status + == t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN + ): + self.create_task( + self.cleanup_tc_link_key(event.ieee), "cleanup_tc_link_key" + ) - if decision == t.EmberJoinDecision.DENY_JOIN: + if event.decision == t.EmberJoinDecision.DENY_JOIN: # no point in handling the join if it was denied return - mfg_id = IEEE_PREFIX_MFG_ID.get(str(ieee)[:8].upper()) + mfg_id = IEEE_PREFIX_MFG_ID.get(str(event.ieee)[:8].upper()) if mfg_id is not None: if self._mfg_id_task and not self._mfg_id_task.done(): self._mfg_id_task.cancel() self._mfg_id_task = asyncio.create_task(self._reset_mfg_id(mfg_id)) - self.handle_join(nwk, ieee, parent_nwk) + self.handle_join(event.nwk, event.ieee, event.parent_nwk) async def _reset_mfg_id(self, mfg_id: int) -> None: """Resets manufacturer id if was temporary overridden by a joining device.""" @@ -1131,20 +1053,21 @@ async def permit_with_link_key( return await super().permit(time_s) - def _handle_id_conflict(self, nwk: t.EmberNodeId) -> None: - LOGGER.warning("NWK conflict is reported for 0x%04x", nwk) + def _on_id_conflict(self, event: IdConflictEvent) -> None: + """Handle id_conflict event from protocol handler.""" + LOGGER.warning("NWK conflict is reported for 0x%04x", event.nwk) self.state.counters[COUNTERS_CTRL][COUNTER_NWK_CONFLICTS].increment() for device in self.devices.values(): - if device.nwk != nwk: + if device.nwk != event.nwk: continue LOGGER.warning( "Found %s device for 0x%04x NWK conflict: %s %s", device.ieee, - nwk, + event.nwk, device.manufacturer, device.model, ) - self.handle_leave(nwk, device.ieee) + self.handle_leave(event.nwk, device.ieee) async def _watchdog_loop(self): self._watchdog_failures = 0 @@ -1205,18 +1128,10 @@ async def _get_free_buffers(self) -> int | None: LOGGER.debug("Free buffers status %s, value: %s", status, buffers) return buffers - def handle_route_record( - self, - nwk: t.EmberNodeId, - ieee: t.EUI64, - lqi: t.uint8_t, - rssi: t.int8s, - relays: t.LVList[t.EmberNodeId], - ) -> None: + def _on_route_record(self, event: RouteRecordEvent) -> None: + """Handle route_record event from protocol handler.""" LOGGER.debug( - "Processing route record request: %s", (nwk, ieee, lqi, rssi, relays) + "Processing route record request: %s", + (event.nwk, event.ieee, event.lqi, event.rssi, event.relays), ) - self.handle_relays(nwk=nwk, relays=relays) - - def handle_route_error(self, status: t.sl_Status, nwk: t.EmberNodeId) -> None: - LOGGER.debug("Processing route error: status=%s, nwk=%s", status, nwk) + self.handle_relays(nwk=event.nwk, relays=event.relays) diff --git a/tests/test_application.py b/tests/test_application.py index 018dd632..f1fa27da 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -16,6 +16,13 @@ import bellows.config as config from bellows.exception import ControllerError, EzspError, InvalidCommandError import bellows.ezsp as ezsp +from bellows.ezsp.protocol import ( + IdConflictEvent, + MessageSentEvent, + PacketReceivedEvent, + RouteRecordEvent, + TrustCenterJoinEvent, +) from bellows.ezsp.v9.commands import GetTokenDataRsp from bellows.ezsp.xncp import ( FirmwareFeatures, @@ -71,6 +78,9 @@ def inner(config, send_timeout: float = 0.05, **kwargs): app.handle_message = MagicMock() app.packet_received = MagicMock() + # Set up event subscriptions normally done in start_network() + app._subscribe_to_protocol_events() + return app return inner @@ -417,215 +427,18 @@ async def test_startup_no_board_info(app, ieee, caplog): assert "EZSP Radio does not support getMfgToken command" in caplog.text -@pytest.fixture -def aps_frame(): - return t.EmberApsFrame( - profileId=0x1234, - clusterId=0x5678, - sourceEndpoint=0x9A, - destinationEndpoint=0xBC, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x0000, - sequence=0xDE, - ) - - -def _handle_incoming_aps_frame(app, aps_frame, type): - app.ezsp_callback_handler( - "incomingMessageHandler", - list( - dict( - type=type, - apsFrame=aps_frame, - lastHopLqi=123, - lastHopRssi=-45, - sender=0xABCD, - bindingIndex=56, - addressIndex=78, - message=b"test message", - ).values() - ), - ) - - -def test_frame_handler_unicast(app, aps_frame): - _handle_incoming_aps_frame( - app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_UNICAST - ) - assert app.packet_received.call_count == 1 - - packet = app.packet_received.mock_calls[0].args[0] - assert packet.profile_id == 0x1234 - assert packet.cluster_id == 0x5678 - assert packet.src_ep == 0x9A - assert packet.dst_ep == 0xBC - assert packet.tsn == 0xDE - assert packet.src.addr_mode == zigpy_t.AddrMode.NWK - assert packet.src.address == 0xABCD - assert packet.dst.addr_mode == zigpy_t.AddrMode.NWK - assert packet.dst.address == app.state.node_info.nwk - assert packet.data.serialize() == b"test message" - assert packet.lqi == 123 - assert packet.rssi == -45 - - assert ( - app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ - bellows.zigbee.application.COUNTER_RX_UNICAST - ] - == 1 - ) - - -def test_frame_handler_broadcast(app, aps_frame): - _handle_incoming_aps_frame( - app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_BROADCAST - ) - assert app.packet_received.call_count == 1 - - packet = app.packet_received.mock_calls[0].args[0] - assert packet.profile_id == 0x1234 - assert packet.cluster_id == 0x5678 - assert packet.src_ep == 0x9A - assert packet.dst_ep == 0xBC - assert packet.tsn == 0xDE - assert packet.src.addr_mode == zigpy_t.AddrMode.NWK - assert packet.src.address == 0xABCD - assert packet.dst.addr_mode == zigpy_t.AddrMode.Broadcast - assert packet.dst.address == zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR - assert packet.data.serialize() == b"test message" - assert packet.lqi == 123 - assert packet.rssi == -45 - - assert ( - app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ - bellows.zigbee.application.COUNTER_RX_BCAST - ] - == 1 - ) - - -def test_frame_handler_multicast(app, aps_frame): - aps_frame.groupId = 0xEF12 - _handle_incoming_aps_frame( - app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_MULTICAST - ) - - assert app.packet_received.call_count == 1 - - packet = app.packet_received.mock_calls[0].args[0] - assert packet.profile_id == 0x1234 - assert packet.cluster_id == 0x5678 - assert packet.src_ep == 0x9A - assert packet.dst_ep == 0xBC - assert packet.tsn == 0xDE - assert packet.src.addr_mode == zigpy_t.AddrMode.NWK - assert packet.src.address == 0xABCD - assert packet.dst.addr_mode == zigpy_t.AddrMode.Group - assert packet.dst.address == 0xEF12 - assert packet.data.serialize() == b"test message" - assert packet.lqi == 123 - assert packet.rssi == -45 - - assert ( - app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ - bellows.zigbee.application.COUNTER_RX_MCAST - ] - == 1 - ) - - -def test_frame_handler_ignored(app, aps_frame): - _handle_incoming_aps_frame( - app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK - ) - assert app.packet_received.call_count == 0 - - -@pytest.mark.parametrize( - "msg_type", - ( - t.EmberIncomingMessageType.INCOMING_BROADCAST, - t.EmberIncomingMessageType.INCOMING_MULTICAST, - t.EmberIncomingMessageType.INCOMING_UNICAST, - 0xFF, - ), -) -async def test_send_failure(app, aps, ieee, msg_type): - fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() - app.ezsp_callback_handler( - "messageSentHandler", [msg_type, 0xBEED, aps, 254, t.EmberStatus.SUCCESS, b""] - ) - assert fut.result() == (t.sl_Status.OK, "message send success") - - -async def test_dup_send_failure(app, aps, ieee): - fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() - fut.set_result("Already set") - - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_UNICAST, - 0xBEED, - aps, - 254, - sentinel.status, - b"", - ], - ) - - -def test_send_failure_unexpected(app, aps, ieee): - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK, - 0xBEED, - aps, - 257, - 1, - b"", - ], - ) - - -async def test_send_success(app, aps, ieee): - fut = app._pending_requests[(0xBEED, 253)] = asyncio.Future() - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_MULTICAST_LOOPBACK, - 0xBEED, - aps, - 253, - t.EmberStatus.SUCCESS, - b"", - ], - ) - - assert fut.result() == (t.sl_Status.OK, "message send success") - - -def test_unexpected_send_success(app, aps, ieee): - app.ezsp_callback_handler( - "messageSentHandler", - [t.EmberIncomingMessageType.INCOMING_MULTICAST, 0xBEED, aps, 253, 0, b""], - ) - - async def test_join_handler(app, ieee): # Calls device.initialize, leaks a task app.handle_join = MagicMock() app.cleanup_tc_link_key = AsyncMock() - app.ezsp_callback_handler( - "trustCenterJoinHandler", - [ - 1, - ieee, - t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - t.EmberJoinDecision.NO_ACTION, - sentinel.parent, - ], + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=ieee, + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=sentinel.parent, + ) ) await asyncio.sleep(0) assert ieee not in app.devices @@ -639,15 +452,14 @@ async def test_join_handler(app, ieee): # cleanup TCLK, but no join handling app.handle_join.reset_mock() app.cleanup_tc_link_key.reset_mock() - app.ezsp_callback_handler( - "trustCenterJoinHandler", - [ - 1, - ieee, - t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - t.EmberJoinDecision.DENY_JOIN, - sentinel.parent, - ], + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=ieee, + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.DENY_JOIN, + parent_nwk=sentinel.parent, + ) ) await asyncio.sleep(0) assert app.cleanup_tc_link_key.await_count == 1 @@ -658,8 +470,14 @@ async def test_join_handler(app, ieee): def test_leave_handler(app, ieee): app.handle_join = MagicMock() app.devices[ieee] = MagicMock() - app.ezsp_callback_handler( - "trustCenterJoinHandler", [1, ieee, t.EmberDeviceUpdate.DEVICE_LEFT, None, None] + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=ieee, + device_update_status=t.EmberDeviceUpdate.DEVICE_LEFT, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=t.EmberNodeId(0x0000), + ) ) assert ieee in app.devices assert app.handle_join.call_count == 0 @@ -739,7 +557,7 @@ async def test_request_concurrency_duplicate_failure( ) -> None: def send_unicast(aps_frame, data, message_tag, nwk): asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -791,7 +609,7 @@ async def _test_send_packet_unicast( def send_unicast(*args, **kwargs): asyncio.get_running_loop().call_later( 0.01, - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -923,15 +741,14 @@ async def test_send_packet_unicast_extended_timeout_with_acks(app, ieee, packet) asyncio.get_running_loop().call_later( 0.1, - app.ezsp_callback_handler, - "incomingRouteRecordHandler", - { - "source": packet.dst.address, - "sourceEui": ieee, - "lastHopLqi": 123, - "lastHopRssi": -60, - "relayList": [0x1234], - }.values(), + app._on_route_record, + RouteRecordEvent( + nwk=packet.dst.address, + ieee=ieee, + lqi=123, + rssi=-60, + relays=[0x1234], + ), ) await _test_send_packet_unicast( @@ -953,15 +770,14 @@ async def test_send_packet_unicast_extended_timeout_without_acks(app, ieee, pack asyncio.get_running_loop().call_later( 0.1, - app.ezsp_callback_handler, - "incomingRouteRecordHandler", - { - "source": packet.dst.address, - "sourceEui": ieee, - "lastHopLqi": 123, - "lastHopRssi": -60, - "relayList": [0x1234], - }.values(), + app._on_route_record, + RouteRecordEvent( + nwk=packet.dst.address, + ieee=ieee, + lqi=123, + rssi=-60, + relays=[0x1234], + ), ) await _test_send_packet_unicast( @@ -1045,7 +861,7 @@ async def send_message_sent_reply( await asyncio.sleep(0.01) - app.ezsp_callback_handler( + app._ezsp._protocol.handle_parsed_callback( "messageSentHandler", list( dict( @@ -1102,7 +918,7 @@ async def test_send_packet_broadcast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -1148,7 +964,7 @@ async def test_send_packet_broadcast_ignored_delivery_failure(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -1201,7 +1017,7 @@ async def test_send_packet_multicast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -1570,20 +1386,18 @@ def test_coordinator_model_manuf(coordinator): def test_handle_route_record(app): """Test route record handling for an existing device.""" app.handle_relays = MagicMock(spec_set=app.handle_relays) - app.ezsp_callback_handler( - "incomingRouteRecordHandler", - [sentinel.nwk, sentinel.ieee, sentinel.lqi, sentinel.rssi, sentinel.relays], - ) - app.handle_relays.assert_called_once_with(nwk=sentinel.nwk, relays=sentinel.relays) - - -def test_handle_route_error(app): - """Test route error handler.""" - app.handle_relays = MagicMock(spec_set=app.handle_relays) - app.ezsp_callback_handler( - "incomingRouteErrorHandler", [sentinel.status, sentinel.nwk] + app._on_route_record( + RouteRecordEvent( + nwk=sentinel.nwk, + ieee=sentinel.ieee, + lqi=sentinel.lqi, + rssi=sentinel.rssi, + relays=sentinel.relays, + ) ) - app.handle_relays.assert_not_called() + assert app.handle_relays.mock_calls == [ + call(nwk=sentinel.nwk, relays=sentinel.relays) + ] def test_handle_id_conflict(app, ieee): @@ -1592,43 +1406,14 @@ def test_handle_id_conflict(app, ieee): app.add_device(ieee, nwk) app.handle_leave = MagicMock() - app.ezsp_callback_handler("idConflictHandler", [nwk + 1]) + app._on_id_conflict(IdConflictEvent(nwk=nwk + 1)) assert app.handle_leave.call_count == 0 - app.ezsp_callback_handler("idConflictHandler", [nwk]) + app._on_id_conflict(IdConflictEvent(nwk=nwk)) assert app.handle_leave.call_count == 1 assert app.handle_leave.call_args[0][0] == nwk -async def test_handle_no_such_device(app, ieee): - """Test handling of an unknown device IEEE lookup.""" - - app._ezsp.lookupEui64ByNodeId = AsyncMock() - - p1 = patch.object( - app._ezsp, - "lookupEui64ByNodeId", - AsyncMock(return_value=(t.EmberStatus.ERR_FATAL, ieee)), - ) - p2 = patch.object(app, "handle_join") - with p1 as lookup_mock, p2 as handle_join_mock: - await app._handle_no_such_device(sentinel.nwk) - assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] - assert handle_join_mock.call_count == 0 - - p1 = patch.object( - app._ezsp, - "lookupEui64ByNodeId", - AsyncMock(return_value=(t.EmberStatus.SUCCESS, sentinel.ieee)), - ) - with p1 as lookup_mock, p2 as handle_join_mock: - await app._handle_no_such_device(sentinel.nwk) - assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] - assert handle_join_mock.call_count == 1 - assert handle_join_mock.call_args[0][0] == sentinel.nwk - assert handle_join_mock.call_args[0][1] == sentinel.ieee - - async def test_cleanup_tc_link_key(app): """Test cleaning up tc link key.""" ezsp = app._ezsp @@ -1673,26 +1458,24 @@ async def test_set_mfg_id(ieee, expected_mfg_id, app): app.handle_join = MagicMock() app.cleanup_tc_link_key = AsyncMock() - app.ezsp_callback_handler( - "trustCenterJoinHandler", - [ - 1, - t.EUI64.convert(ieee), - t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - t.EmberJoinDecision.NO_ACTION, - sentinel.parent, - ], + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=t.EUI64.convert(ieee), + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=sentinel.parent, + ) ) # preempt - app.ezsp_callback_handler( - "trustCenterJoinHandler", - [ - 1, - t.EUI64.convert(ieee), - t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - t.EmberJoinDecision.NO_ACTION, - sentinel.parent, - ], + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=t.EUI64.convert(ieee), + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=sentinel.parent, + ) ) await asyncio.sleep(0.20) if expected_mfg_id is not None: @@ -2683,3 +2466,166 @@ async def test_set_tx_power(app: ControllerApplication) -> None: assert result == 12.0 assert app._ezsp.setRadioPower.mock_calls == [call(power=12)] assert mock_update.mock_calls == [call(app._ezsp, tx_power=12)] + + +async def test_reset_resubscribes_events(app: ControllerApplication) -> None: + """Test that _reset unsubscribes, resets, and resubscribes to protocol events.""" + app._ezsp.stop_ezsp = MagicMock() + app._ezsp.startup_reset = AsyncMock() + app._ezsp.write_config = AsyncMock() + + # Add a dummy callback to verify unsubscribe is called + unsubscribe_mock = MagicMock() + app._protocol_on_remove_callbacks.append(unsubscribe_mock) + + await app._reset() + + # Verify unsubscribe was called + assert unsubscribe_mock.mock_calls == [call()] + + # Verify EZSP reset sequence + assert len(app._ezsp.stop_ezsp.mock_calls) == 1 + assert len(app._ezsp.startup_reset.mock_calls) == 1 + assert len(app._ezsp.write_config.mock_calls) == 1 + + # Verify we resubscribed (callbacks list should have 5 entries now) + assert len(app._protocol_on_remove_callbacks) == 5 + + +def test_on_packet_received_unicast(app: ControllerApplication) -> None: + """Test _on_packet_received with unicast message (dst=None gets replaced).""" + app.state.node_info.nwk = zigpy_t.NWK(0x0000) + + packet_received_mock = MagicMock() + app.packet_received = packet_received_mock + + # Unicast packets have dst=None, protocol handler doesn't know our NWK + event = PacketReceivedEvent( + packet=zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x1234), + ), + src_ep=1, + dst=None, # Will be replaced with our NWK + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"test"), + lqi=200, + rssi=-40, + ) + ) + + app._on_packet_received(event) + + # Verify packet_received was called with dst replaced + assert packet_received_mock.mock_calls == [ + call( + zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x1234), + ), + src_ep=1, + dst=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x0000), + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"test"), + lqi=200, + rssi=-40, + ) + ) + ] + + +def test_on_packet_received_broadcast(app: ControllerApplication) -> None: + """Test _on_packet_received with broadcast message.""" + packet_received_mock = MagicMock() + app.packet_received = packet_received_mock + + event = PacketReceivedEvent( + packet=zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x1234), + ), + src_ep=1, + dst=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Broadcast, + address=zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"broadcast"), + lqi=200, + rssi=-40, + ) + ) + + app._on_packet_received(event) + + # Verify packet_received was called with the same packet (dst already set) + assert packet_received_mock.mock_calls == [call(event.packet)] + + +def test_on_packet_received_multicast(app: ControllerApplication) -> None: + """Test _on_packet_received with multicast message.""" + packet_received_mock = MagicMock() + app.packet_received = packet_received_mock + + event = PacketReceivedEvent( + packet=zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x1234), + ), + src_ep=1, + dst=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Group, + address=0x5678, + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"multicast"), + lqi=200, + rssi=-40, + ) + ) + + app._on_packet_received(event) + + # Verify packet_received was called with the same packet (dst already set) + assert packet_received_mock.mock_calls == [call(event.packet)] + + +async def test_on_message_sent_via_binding(app: ControllerApplication) -> None: + """Test _on_message_sent with OUTGOING_VIA_BINDING message type.""" + # Create a pending request future + future = asyncio.get_running_loop().create_future() + app._pending_requests[(0x1234, 0x42)] = future + + event = MessageSentEvent( + status=t.sl_Status.OK, + message_type=t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING, + destination=0x1234, + aps_frame=t.EmberApsFrame(), + message_tag=0x42, + message_contents=b"test", + ) + + app._on_message_sent(event) + + # Verify the future was resolved + assert future.done() + assert future.result() == (t.sl_Status.OK, "message send success") diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index 3906eb5e..8550c2c5 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -3,8 +3,15 @@ from unittest.mock import AsyncMock, MagicMock, call, patch import pytest +import zigpy.types from bellows.ezsp import EZSP +from bellows.ezsp.protocol import ( + IdConflictEvent, + PacketReceivedEvent, + RouteRecordEvent, + TrustCenterJoinEvent, +) import bellows.ezsp.v4 import bellows.ezsp.v9 from bellows.ezsp.v9.commands import GetTokenDataRsp @@ -206,9 +213,9 @@ async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - prot_hndl._handle_callback.assert_not_called() - assert "Fragment reassembly not complete. waiting for more data." in caplog.text - mock_ack.assert_called_once_with(sender, aps_frame, 2, 0) + assert len(prot_hndl._handle_callback.mock_calls) == 1 + assert "Fragment reassembly not complete, waiting for more data" in caplog.text + assert mock_ack.mock_calls == [call(sender, aps_frame, 2, 0)] async def test_incoming_fragmented_message_complete(prot_hndl, caplog): @@ -221,27 +228,34 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x01\x02\xee\xff\xf8\x6f\x1d\xff\xff\x07" + b"message" ) # fragment index 1 - sender = 0x1D6F aps_frame_1 = t.EmberApsFrame( profileId=260, - clusterId=65281, + clusterId=0xFF01, sourceEndpoint=2, destinationEndpoint=2, - options=33088, # Includes APS_OPTION_FRAGMENT - groupId=512, # fragment_count=2, fragment_index=0 + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + | t.EmberApsOption.APS_OPTION_FRAGMENT + ), + groupId=0x0200, # fragment_count=2, fragment_index=0 sequence=238, ) + aps_frame_2 = t.EmberApsFrame( profileId=260, - clusterId=65281, + clusterId=0xFF01, sourceEndpoint=2, destinationEndpoint=2, - options=33088, - groupId=513, # fragment_count=2, fragment_index=1 + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + | t.EmberApsOption.APS_OPTION_FRAGMENT + ), + groupId=0x0201, # fragment_count=2, fragment_index=1 sequence=238, ) - reassembled = b"complete message" with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack: mock_ack.return_value = None @@ -250,43 +264,240 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): # Packet 1 prot_hndl(packet1) assert len(prot_hndl._fragment_ack_tasks) == 1 - ack_task = next(iter(prot_hndl._fragment_ack_tasks)) - await asyncio.gather(ack_task) # Ensure task completes and triggers callback - assert ( - len(prot_hndl._fragment_ack_tasks) == 0 - ), "Done callback should have removed task" - - prot_hndl._handle_callback.assert_not_called() - assert ( - "Reassembled fragmented message. Proceeding with normal handling." - not in caplog.text - ) - mock_ack.assert_called_with(sender, aps_frame_1, 2, 0) + await asyncio.gather( + *prot_hndl._fragment_ack_tasks + ) # Ensure task completes and triggers callback + assert len(prot_hndl._fragment_ack_tasks) == 0 # Packet 2 prot_hndl(packet2) assert len(prot_hndl._fragment_ack_tasks) == 1 - ack_task = next(iter(prot_hndl._fragment_ack_tasks)) - await asyncio.gather(ack_task) # Ensure task completes and triggers callback - assert ( - len(prot_hndl._fragment_ack_tasks) == 0 - ), "Done callback should have removed task" + await asyncio.gather( + *prot_hndl._fragment_ack_tasks + ) # Ensure task completes and triggers callback + assert len(prot_hndl._fragment_ack_tasks) == 0 + + assert "Reassembled fragmented message, proceeding with handling" in caplog.text + assert mock_ack.mock_calls == [ + call(0x1D6F, aps_frame_1, 2, 0), + call(0x1D6F, aps_frame_2, 2, 1), + ] + + +def test_incoming_message_broadcast(prot_hndl) -> None: + """Test handling of incoming broadcast message.""" + handler = MagicMock() + prot_hndl.on_event(PacketReceivedEvent.event_type, handler) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x0000, + sequence=0x42, + ) - prot_hndl._handle_callback.assert_called_once_with( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_UNICAST, # 0x00 - aps_frame_2, # Parsed APS frame - 255, # lastHopLqi: 0xFF - -8, # lastHopRssi: 0xF8 - sender, # 0x1D6F - 255, # bindingIndex: 0xFF - 255, # addressIndex: 0xFF - reassembled, # Reassembled payload - ], + # v4 field order: type, apsFrame, lqi, rssi, sender, bindingIndex, addressIndex, message + prot_hndl.handle_parsed_callback( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_BROADCAST, + aps_frame, + 200, # lqi + -40, # rssi + t.EmberNodeId(0x1234), # sender + 0, # binding_index + 0, # address_index + b"broadcast message", + ], + ) + + assert handler.mock_calls == [ + call( + PacketReceivedEvent( + packet=zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(0x1234), + ), + src_ep=1, + dst=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy.types.SerializableBytes(b"broadcast message"), + lqi=200, + rssi=-40, + ) + ) ) - assert ( - "Reassembled fragmented message. Proceeding with normal handling." - in caplog.text + ] + + +def test_incoming_message_multicast(prot_hndl) -> None: + """Test handling of incoming multicast message.""" + handler = MagicMock() + prot_hndl.on_event(PacketReceivedEvent.event_type, handler) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x5678, + sequence=0x42, + ) + + prot_hndl.handle_parsed_callback( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_MULTICAST, + aps_frame, + 200, + -40, + t.EmberNodeId(0x1234), + 0, + 0, + b"multicast message", + ], + ) + + assert handler.mock_calls == [ + call( + PacketReceivedEvent( + packet=zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(0x1234), + ), + src_ep=1, + dst=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, + address=0x5678, + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy.types.SerializableBytes(b"multicast message"), + lqi=200, + rssi=-40, + ) + ) + ) + ] + + +def test_incoming_message_ignored_type(prot_hndl, caplog) -> None: + """Test that unknown message types are ignored.""" + handler = MagicMock() + prot_hndl.on_event(PacketReceivedEvent.event_type, handler) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x0000, + sequence=0x42, + ) + + caplog.set_level(logging.DEBUG) + prot_hndl.handle_parsed_callback( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_MANY_TO_ONE_ROUTE_REQUEST, + aps_frame, + 200, + -40, + t.EmberNodeId(0x1234), + 0, + 0, + b"ignored message", + ], + ) + + # No event should be emitted for ignored message types + assert len(handler.mock_calls) == 0 + assert "Ignoring message type" in caplog.text + + +def test_trust_center_join_handler(prot_hndl) -> None: + """Test trustCenterJoinHandler callback.""" + handler = MagicMock() + prot_hndl.on_event(TrustCenterJoinEvent.event_type, handler) + + ieee = t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11") + prot_hndl.handle_parsed_callback( + "trustCenterJoinHandler", + { + "newNodeId": t.EmberNodeId(0x1234), + "newNodeEui64": ieee, + "status": t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + "policyDecision": t.EmberJoinDecision.NO_ACTION, + "parentOfNewNodeId": t.EmberNodeId(0x0000), + }.values(), + ) + + assert handler.mock_calls == [ + call( + TrustCenterJoinEvent( + nwk=t.EmberNodeId(0x1234), + ieee=ieee, + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=t.EmberNodeId(0x0000), + ) + ) + ] + + +def test_incoming_route_record_handler(prot_hndl) -> None: + """Test incomingRouteRecordHandler callback.""" + handler = MagicMock() + prot_hndl.on_event(RouteRecordEvent.event_type, handler) + + ieee = t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11") + prot_hndl.handle_parsed_callback( + "incomingRouteRecordHandler", + { + "source": t.EmberNodeId(0x1234), + "sourceEui": ieee, + "lastHopLqi": t.uint8_t(200), + "lastHopRssi": t.int8s(-40), + "relayList": [t.EmberNodeId(0x0001), t.EmberNodeId(0x0002)], + }.values(), + ) + + assert handler.mock_calls == [ + call( + RouteRecordEvent( + nwk=t.EmberNodeId(0x1234), + ieee=ieee, + lqi=t.uint8_t(200), + rssi=t.int8s(-40), + relays=[t.EmberNodeId(0x0001), t.EmberNodeId(0x0002)], + ) ) - mock_ack.assert_called_with(sender, aps_frame_2, 2, 1) + ] + + +def test_id_conflict_handler(prot_hndl) -> None: + """Test idConflictHandler callback.""" + handler = MagicMock() + prot_hndl.on_event(IdConflictEvent.event_type, handler) + + prot_hndl.handle_parsed_callback( + "idConflictHandler", + {"conflictingId": t.EmberNodeId(0x1234)}.values(), + ) + + assert handler.mock_calls == [call(IdConflictEvent(nwk=t.EmberNodeId(0x1234)))] diff --git a/tests/test_ezsp_v14.py b/tests/test_ezsp_v14.py index 49bf152b..9eab1e96 100644 --- a/tests/test_ezsp_v14.py +++ b/tests/test_ezsp_v14.py @@ -3,7 +3,9 @@ import pytest import zigpy.exceptions import zigpy.state +import zigpy.types +from bellows.ezsp.protocol import MessageSentEvent, PacketReceivedEvent import bellows.ezsp.v14 import bellows.types as t @@ -226,3 +228,111 @@ async def test_send_broadcast(ezsp_f) -> None: message=b"hello", ) ] + + +def test_handle_parsed_callback_incoming_message(ezsp_f) -> None: + """Test handle_parsed_callback for incomingMessageHandler.""" + handler = MagicMock() + ezsp_f.on_event(PacketReceivedEvent.event_type, handler) + + ezsp_f.handle_parsed_callback( + "incomingMessageHandler", + { + "message_type": t.EmberIncomingMessageType.INCOMING_UNICAST, + "aps_frame": t.EmberApsFrame( + profileId=260, + clusterId=8, + sourceEndpoint=1, + destinationEndpoint=1, + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + ), + groupId=0, + sequence=168, + ), + "nwk": 0x1174, + "eui64": t.EUI64.convert("00:00:00:00:00:00:00:00"), + "binding_index": 255, + "address_index": 13, + "lqi": 192, + "rssi": -63, + "timestamp": 1333671578, + "message": b"\x18,\x0b\x04\x00", + }.values(), + ) + + assert handler.mock_calls == [ + call( + PacketReceivedEvent( + packet=zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(0x1174), + ), + src_ep=1, + dst=None, + dst_ep=1, + tsn=168, + profile_id=0x0104, + cluster_id=0x0008, + data=zigpy.types.SerializableBytes(b"\x18,\x0b\x04\x00"), + lqi=192, + rssi=-63, + ) + ) + ) + ] + + +def test_handle_parsed_callback_message_sent(ezsp_f) -> None: + """Test handle_parsed_callback for messageSentHandler.""" + handler = MagicMock() + ezsp_f.on_event(MessageSentEvent.event_type, handler) + + ezsp_f.handle_parsed_callback( + "messageSentHandler", + { + "status": t.sl_Status.OK, + "message_type": t.EmberOutgoingMessageType.OUTGOING_DIRECT, + "nwk": 0x0E0D, + "aps_frame": t.EmberApsFrame( + profileId=260, + clusterId=513, + sourceEndpoint=1, + destinationEndpoint=1, + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + ), + groupId=0, + sequence=236, + ), + "message_tag": 103, + "message": b"", + }.values(), + ) + + assert handler.mock_calls == [ + call( + MessageSentEvent( + status=t.sl_Status.OK, + message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, + destination=t.EmberNodeId(0x0E0D), + aps_frame=t.EmberApsFrame( + profileId=260, + clusterId=513, + sourceEndpoint=1, + destinationEndpoint=1, + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + ), + groupId=0, + sequence=236, + ), + message_tag=103, + message_contents=b"", + ) + ) + ]