Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bellows/ezsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def frame_received(self, data: bytes) -> None:
try:
self._protocol(data)
except Exception:
LOGGER.warning("Failed to parse frame, ignoring")
LOGGER.warning("Failed to parse frame. This is a bug!", exc_info=True)

async def get_board_info(
self,
Expand Down
247 changes: 197 additions & 50 deletions bellows/ezsp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
from asyncio import timeout as asyncio_timeout
import binascii
from collections.abc import AsyncGenerator, Callable, Iterable
from dataclasses import dataclass
import functools
import logging
import time
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Final

from zigpy.datastructures import PriorityDynamicBoundedSemaphore
from zigpy.event.event_base import EventBase
import zigpy.state
import zigpy.types

from bellows.config import CONF_EZSP_POLICIES
from bellows.exception import InvalidCommandError
Expand All @@ -27,13 +30,62 @@
MAX_COMMAND_CONCURRENCY = 1


class ProtocolHandler(abc.ABC):
@dataclass(frozen=True, kw_only=True)
class MessageSentEvent:
event_type: Final[str] = "message_sent"

status: t.sl_Status
message_type: t.EmberOutgoingMessageType
destination: t.uint16_t
aps_frame: t.EmberApsFrame
message_tag: t.uint8_t
message_contents: t.LVBytes


@dataclass(frozen=True, kw_only=True)
class PacketReceivedEvent:
event_type: Final[str] = "packet_received"

packet: zigpy.types.ZigbeePacket


@dataclass(frozen=True, kw_only=True)
class TrustCenterJoinEvent:
event_type: Final[str] = "trust_center_join"

nwk: t.EmberNodeId
ieee: t.EUI64
device_update_status: t.EmberDeviceUpdate
decision: t.EmberJoinDecision
parent_nwk: t.EmberNodeId


@dataclass(frozen=True, kw_only=True)
class RouteRecordEvent:
event_type: Final[str] = "route_record"

nwk: t.EmberNodeId
ieee: t.EUI64
lqi: t.uint8_t
rssi: t.int8s
relays: t.LVList[t.EmberNodeId]


@dataclass(frozen=True, kw_only=True)
class IdConflictEvent:
event_type: Final[str] = "id_conflict"

nwk: t.EmberNodeId


class ProtocolHandler(EventBase, abc.ABC):
"""EZSP protocol specific handler."""

COMMANDS = {}
VERSION = None

def __init__(self, cb_handler: Callable, gateway: Gateway) -> None:
super().__init__()
self._handle_callback = cb_handler
self._awaiting = {}
self._gw = gateway
Expand Down Expand Up @@ -179,52 +231,6 @@ def __call__(self, data: bytes) -> None:
if data:
LOGGER.debug("Frame contains trailing data: %s", data)

if (
frame_name == "incomingMessageHandler"
and result[1].options & t.EmberApsOption.APS_OPTION_FRAGMENT
):
# Extract received APS frame and sender
aps_frame = result[1]
sender = result[4]

# The fragment count and index are encoded in the groupId field
fragment_count = (aps_frame.groupId >> 8) & 0xFF
fragment_index = aps_frame.groupId & 0xFF

(
complete,
reassembled,
frag_count,
frag_index,
) = self._fragment_manager.handle_incoming_fragment(
sender_nwk=sender,
aps_sequence=aps_frame.sequence,
profile_id=aps_frame.profileId,
cluster_id=aps_frame.clusterId,
fragment_count=fragment_count,
fragment_index=fragment_index,
payload=result[7],
)

ack_task = asyncio.create_task(
self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)
) # APS Ack

self._fragment_ack_tasks.add(ack_task)
ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t))

if not complete:
# Do not pass partial data up the stack
LOGGER.debug("Fragment reassembly not complete. waiting for more data.")
return

# Replace partial data with fully reassembled data
result[7] = reassembled

LOGGER.debug(
"Reassembled fragmented message. Proceeding with normal handling."
)

if sequence in self._awaiting:
expected_id, schema, future = self._awaiting.pop(sequence)
try:
Expand All @@ -246,8 +252,20 @@ def __call__(self, data: bytes) -> None:
sequence,
self.COMMANDS_BY_ID.get(expected_id, [expected_id])[0],
)
else:
self._handle_callback(frame_name, result)

return

self.handle_parsed_callback(frame_name, result)

# Legacy callback system for CLI tools
self._handle_callback(frame_name, result)

def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None:
"""Dispatch a callback frame to the appropriate handler method."""
handler = getattr(self, f"_handle_{frame_name}", None)

if handler is not None:
handler(*args)

async def _send_fragment_ack(
self,
Expand Down Expand Up @@ -275,6 +293,135 @@ async def _send_fragment_ack(
status = await self.sendReply(sender, ackFrame, b"")
return status[0]

def _handle_incoming_message(
self,
message_type: t.EmberIncomingMessageType,
aps_frame: t.EmberApsFrame,
sender: zigpy.types.NWK,
eui64: zigpy.types.EUI64 | None,
binding_index: t.uint8_t,
address_index: t.uint8_t,
lqi: t.uint8_t,
rssi: t.int8s,
timestamp: t.uint32_t | None,
message: t.LVBytes,
) -> None:
"""Handle incomingMessageHandler callback and maybe return a packet."""

if aps_frame.options & t.EmberApsOption.APS_OPTION_FRAGMENT:
fragment_count = (aps_frame.groupId >> 8) & 0xFF
fragment_index = aps_frame.groupId & 0xFF

(
complete,
reassembled,
frag_count,
frag_index,
) = self._fragment_manager.handle_incoming_fragment(
sender_nwk=sender,
aps_sequence=aps_frame.sequence,
profile_id=aps_frame.profileId,
cluster_id=aps_frame.clusterId,
fragment_count=fragment_count,
fragment_index=fragment_index,
payload=message,
)

ack_task = asyncio.create_task(
self._send_fragment_ack(sender, aps_frame, frag_count, frag_index)
)
self._fragment_ack_tasks.add(ack_task)
ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t))

if not complete:
LOGGER.debug("Fragment reassembly not complete, waiting for more data")
return

LOGGER.debug("Reassembled fragmented message, proceeding with handling")
message = reassembled

# Determine destination address based on message type
if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST:
dst = zigpy.types.AddrModeAddress(
addr_mode=zigpy.types.AddrMode.Broadcast,
address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR,
)
elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST:
dst = zigpy.types.AddrModeAddress(
addr_mode=zigpy.types.AddrMode.Group,
address=aps_frame.groupId,
)
elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST:
dst = None # We don't know the current NWK here
else:
LOGGER.debug("Ignoring message type: %r", message_type)
return

self.emit(
PacketReceivedEvent.event_type,
PacketReceivedEvent(
packet=zigpy.types.ZigbeePacket(
src=zigpy.types.AddrModeAddress(
addr_mode=zigpy.types.AddrMode.NWK,
address=zigpy.types.NWK(sender),
),
src_ep=aps_frame.sourceEndpoint,
dst=dst,
dst_ep=aps_frame.destinationEndpoint,
tsn=aps_frame.sequence,
profile_id=aps_frame.profileId,
cluster_id=aps_frame.clusterId,
data=zigpy.types.SerializableBytes(message),
lqi=lqi,
rssi=rssi,
)
),
)

def _handle_trustCenterJoinHandler(
self,
nwk: t.EmberNodeId,
ieee: t.EUI64,
device_update_status: t.EmberDeviceUpdate,
decision: t.EmberJoinDecision,
parent_nwk: t.EmberNodeId,
) -> None:
self.emit(
TrustCenterJoinEvent.event_type,
TrustCenterJoinEvent(
nwk=nwk,
ieee=ieee,
device_update_status=device_update_status,
decision=decision,
parent_nwk=parent_nwk,
),
)

def _handle_incomingRouteRecordHandler(
self,
nwk: t.EmberNodeId,
ieee: t.EUI64,
lqi: t.uint8_t,
rssi: t.int8s,
relays: t.LVList[t.EmberNodeId],
) -> None:
self.emit(
RouteRecordEvent.event_type,
RouteRecordEvent(
nwk=nwk,
ieee=ieee,
lqi=lqi,
rssi=rssi,
relays=relays,
),
)

def _handle_idConflictHandler(self, nwk: t.EmberNodeId) -> None:
self.emit(
IdConflictEvent.event_type,
IdConflictEvent(nwk=nwk),
)

def __getattr__(self, name: str) -> Callable:
if name not in self.COMMANDS:
raise AttributeError(f"{name} not found in COMMANDS")
Expand Down
52 changes: 52 additions & 0 deletions bellows/ezsp/v14/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@
from __future__ import annotations

from collections.abc import AsyncGenerator
import logging

import voluptuous as vol
from zigpy.exceptions import NetworkNotFormed
import zigpy.state
import zigpy.types

import bellows.config
import bellows.types as t

from . import commands, config
from ..protocol import MessageSentEvent
from ..v13 import EZSPv13

LOGGER = logging.getLogger(__name__)


class EZSPv14(EZSPv13):
"""EZSP Version 14 Protocol version handler."""
Expand Down Expand Up @@ -144,3 +149,50 @@ async def send_broadcast(
)

return status, sequence

def _handle_incomingMessageHandler(
self,
message_type: t.EmberIncomingMessageType,
aps_frame: t.EmberApsFrame,
sender: t.EmberNodeId,
eui64: t.EUI64,
binding_index: t.uint8_t,
address_index: t.uint8_t,
lqi: t.uint8_t,
rssi: t.int8s,
timestamp: t.uint32_t,
message: t.LVBytes,
) -> None:
self._handle_incoming_message(
message_type=message_type,
aps_frame=aps_frame,
sender=sender,
eui64=None,
binding_index=binding_index,
address_index=address_index,
lqi=lqi,
rssi=rssi,
timestamp=None,
message=message,
)

def _handle_messageSentHandler(
self,
status: t.sl_Status,
message_type: t.EmberOutgoingMessageType,
destination: t.EmberNodeId,
aps_frame: t.EmberApsFrame,
message_tag: t.uint8_t,
message: t.LVBytes,
) -> None:
self.emit(
MessageSentEvent.event_type,
MessageSentEvent(
status=status,
message_type=message_type,
destination=destination,
aps_frame=aps_frame,
message_tag=message_tag,
message_contents=message,
),
)
Loading
Loading