|
26 | 26 | * Adafruit CircuitPython firmware for the supported boards: |
27 | 27 | https://github.com/adafruit/circuitpython/releases |
28 | 28 |
|
| 29 | +* Adafruit's Connection Manager library: |
| 30 | + https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager |
| 31 | +
|
29 | 32 | """ |
30 | 33 | import errno |
31 | 34 | import struct |
32 | 35 | import time |
33 | 36 | from random import randint |
34 | 37 |
|
| 38 | +from adafruit_connectionmanager import ( |
| 39 | + get_connection_manager, |
| 40 | + SocketGetOSError, |
| 41 | + SocketConnectMemoryError, |
| 42 | +) |
| 43 | + |
35 | 44 | try: |
36 | 45 | from typing import List, Optional, Tuple, Type, Union |
37 | 46 | except ImportError: |
|
78 | 87 | _default_sock = None # pylint: disable=invalid-name |
79 | 88 | _fake_context = None # pylint: disable=invalid-name |
80 | 89 |
|
| 90 | +TemporaryError = (SocketGetOSError, SocketConnectMemoryError) |
| 91 | + |
81 | 92 |
|
82 | 93 | class MMQTTException(Exception): |
83 | 94 | """MiniMQTT Exception class.""" |
84 | 95 |
|
85 | | - # pylint: disable=unnecessary-pass |
86 | | - # pass |
87 | | - |
88 | | - |
89 | | -class TemporaryError(Exception): |
90 | | - """Temporary error class used for handling reconnects.""" |
91 | | - |
92 | | - |
93 | | -# Legacy ESP32SPI Socket API |
94 | | -def set_socket(sock, iface=None) -> None: |
95 | | - """Legacy API for setting the socket and network interface. |
96 | | -
|
97 | | - :param sock: socket object. |
98 | | - :param iface: internet interface object |
99 | | -
|
100 | | - """ |
101 | | - global _default_sock # pylint: disable=invalid-name, global-statement |
102 | | - global _fake_context # pylint: disable=invalid-name, global-statement |
103 | | - _default_sock = sock |
104 | | - if iface: |
105 | | - _default_sock.set_interface(iface) |
106 | | - _fake_context = _FakeSSLContext(iface) |
107 | | - |
108 | | - |
109 | | -class _FakeSSLSocket: |
110 | | - def __init__(self, socket, tls_mode) -> None: |
111 | | - self._socket = socket |
112 | | - self._mode = tls_mode |
113 | | - self.settimeout = socket.settimeout |
114 | | - self.send = socket.send |
115 | | - self.recv = socket.recv |
116 | | - self.close = socket.close |
117 | | - |
118 | | - def connect(self, address): |
119 | | - """connect wrapper to add non-standard mode parameter""" |
120 | | - try: |
121 | | - return self._socket.connect(address, self._mode) |
122 | | - except RuntimeError as error: |
123 | | - raise OSError(errno.ENOMEM) from error |
124 | | - |
125 | | - |
126 | | -class _FakeSSLContext: |
127 | | - def __init__(self, iface) -> None: |
128 | | - self._iface = iface |
129 | | - |
130 | | - def wrap_socket(self, socket, server_hostname=None) -> _FakeSSLSocket: |
131 | | - """Return the same socket""" |
132 | | - # pylint: disable=unused-argument |
133 | | - return _FakeSSLSocket(socket, self._iface.TLS_MODE) |
134 | | - |
135 | 96 |
|
136 | 97 | class NullLogger: |
137 | 98 | """Fake logger class that does not do anything""" |
138 | 99 |
|
139 | 100 | # pylint: disable=unused-argument |
140 | 101 | def nothing(self, msg: str, *args) -> None: |
141 | 102 | """no action""" |
142 | | - pass |
143 | 103 |
|
144 | 104 | def __init__(self) -> None: |
145 | 105 | for log_level in ["debug", "info", "warning", "error", "critical"]: |
@@ -194,6 +154,7 @@ def __init__( |
194 | 154 | user_data=None, |
195 | 155 | use_imprecise_time: Optional[bool] = None, |
196 | 156 | ) -> None: |
| 157 | + self._connection_manager = get_connection_manager(socket_pool) |
197 | 158 | self._socket_pool = socket_pool |
198 | 159 | self._ssl_context = ssl_context |
199 | 160 | self._sock = None |
@@ -300,77 +261,6 @@ def get_monotonic_time(self) -> float: |
300 | 261 |
|
301 | 262 | return time.monotonic() |
302 | 263 |
|
303 | | - # pylint: disable=too-many-branches |
304 | | - def _get_connect_socket(self, host: str, port: int, *, timeout: int = 1): |
305 | | - """Obtains a new socket and connects to a broker. |
306 | | -
|
307 | | - :param str host: Desired broker hostname |
308 | | - :param int port: Desired broker port |
309 | | - :param int timeout: Desired socket timeout, in seconds |
310 | | - """ |
311 | | - # For reconnections - check if we're using a socket already and close it |
312 | | - if self._sock: |
313 | | - self._sock.close() |
314 | | - self._sock = None |
315 | | - |
316 | | - # Legacy API - use the interface's socket instead of a passed socket pool |
317 | | - if self._socket_pool is None: |
318 | | - self._socket_pool = _default_sock |
319 | | - |
320 | | - # Legacy API - fake the ssl context |
321 | | - if self._ssl_context is None: |
322 | | - self._ssl_context = _fake_context |
323 | | - |
324 | | - if not isinstance(port, int): |
325 | | - raise RuntimeError("Port must be an integer") |
326 | | - |
327 | | - if self._is_ssl and not self._ssl_context: |
328 | | - raise RuntimeError( |
329 | | - "ssl_context must be set before using adafruit_mqtt for secure MQTT." |
330 | | - ) |
331 | | - |
332 | | - if self._is_ssl: |
333 | | - self.logger.info(f"Establishing a SECURE SSL connection to {host}:{port}") |
334 | | - else: |
335 | | - self.logger.info(f"Establishing an INSECURE connection to {host}:{port}") |
336 | | - |
337 | | - addr_info = self._socket_pool.getaddrinfo( |
338 | | - host, port, 0, self._socket_pool.SOCK_STREAM |
339 | | - )[0] |
340 | | - |
341 | | - try: |
342 | | - sock = self._socket_pool.socket(addr_info[0], addr_info[1]) |
343 | | - except OSError as exc: |
344 | | - # Do not consider this for back-off. |
345 | | - self.logger.warning( |
346 | | - f"Failed to create socket for host {addr_info[0]} and port {addr_info[1]}" |
347 | | - ) |
348 | | - raise TemporaryError from exc |
349 | | - |
350 | | - connect_host = addr_info[-1][0] |
351 | | - if self._is_ssl: |
352 | | - sock = self._ssl_context.wrap_socket(sock, server_hostname=host) |
353 | | - connect_host = host |
354 | | - sock.settimeout(timeout) |
355 | | - |
356 | | - last_exception = None |
357 | | - try: |
358 | | - sock.connect((connect_host, port)) |
359 | | - except MemoryError as exc: |
360 | | - sock.close() |
361 | | - self.logger.warning(f"Failed to allocate memory for connect: {exc}") |
362 | | - # Do not consider this for back-off. |
363 | | - raise TemporaryError from exc |
364 | | - except OSError as exc: |
365 | | - sock.close() |
366 | | - last_exception = exc |
367 | | - |
368 | | - if last_exception: |
369 | | - raise last_exception |
370 | | - |
371 | | - self._backwards_compatible_sock = not hasattr(sock, "recv_into") |
372 | | - return sock |
373 | | - |
374 | 264 | def __enter__(self): |
375 | 265 | return self |
376 | 266 |
|
@@ -593,8 +483,15 @@ def _connect( |
593 | 483 | time.sleep(self._reconnect_timeout) |
594 | 484 |
|
595 | 485 | # Get a new socket |
596 | | - self._sock = self._get_connect_socket( |
597 | | - self.broker, self.port, timeout=self._socket_timeout |
| 486 | + self._sock = self._connection_manager.get_socket( |
| 487 | + self.broker, |
| 488 | + self.port, |
| 489 | + "mqtt:", |
| 490 | + timeout=self._socket_timeout, |
| 491 | + is_ssl=self._is_ssl, |
| 492 | + ssl_context=self._ssl_context, |
| 493 | + max_retries=1, # setting to 1 since we want to handle backoff internally |
| 494 | + exception_passthrough=True, |
598 | 495 | ) |
599 | 496 |
|
600 | 497 | # Fixed Header |
@@ -689,7 +586,7 @@ def disconnect(self) -> None: |
689 | 586 | except RuntimeError as e: |
690 | 587 | self.logger.warning(f"Unable to send DISCONNECT packet: {e}") |
691 | 588 | self.logger.debug("Closing socket") |
692 | | - self._sock.close() |
| 589 | + self._connection_manager.free_socket(self._sock) |
693 | 590 | self._is_connected = False |
694 | 591 | self._subscribed_topics = [] |
695 | 592 | if self.on_disconnect is not None: |
|
0 commit comments