6161 const (0x05 ): "Connection Refused - Unauthorized" ,
6262}
6363
64- _the_interface = None # pylint: disable=invalid-name
65- _the_sock = None # pylint: disable=invalid-name
66-
64+ _default_sock = None # pylint: disable=invalid-name
65+ _fake_context = None # pylint: disable=invalid-name
6766
6867class MMQTTException (Exception ):
6968 """MiniMQTT Exception class."""
@@ -74,18 +73,17 @@ class MMQTTException(Exception):
7473
7574# Legacy ESP32SPI Socket API
7675def set_socket (sock , iface = None ):
77- """Legacy API for setting the socket and network interface, use a Session instead.
78-
76+ """Legacy API for setting the socket and network interface.
7977 :param sock: socket object.
8078 :param iface: internet interface object
79+
8180 """
82- global _the_sock # pylint: disable=invalid-name, global-statement
83- _the_sock = sock
81+ global _default_sock # pylint: disable=invalid-name, global-statement
82+ global _fake_context # pylint: disable=invalid-name, global-statement
83+ _default_sock = sock
8484 if iface :
85- global _the_interface # pylint: disable=invalid-name, global-statement
86- _the_interface = iface
87- _the_sock .set_interface (iface )
88-
85+ _default_sock .set_interface (iface )
86+ _fake_context = _FakeSSLContext (iface )
8987
9088class _FakeSSLSocket :
9189 def __init__ (self , socket , tls_mode ):
@@ -103,7 +101,6 @@ def connect(self, address):
103101 except RuntimeError as error :
104102 raise OSError (errno .ENOMEM ) from error
105103
106-
107104class _FakeSSLContext :
108105 def __init__ (self , iface ):
109106 self ._iface = iface
@@ -144,18 +141,7 @@ def __init__(
144141 ):
145142
146143 self ._socket_pool = socket_pool
147- # Legacy API - if we do not have a socket pool, use default socket
148- if self ._socket_pool is None :
149- self ._socket_pool = _the_sock
150-
151144 self ._ssl_context = ssl_context
152- # Legacy API - if we do not have SSL context, fake it
153- if self ._ssl_context is None :
154- self ._ssl_context = _FakeSSLContext (_the_interface )
155-
156- # Hang onto open sockets so that we can reuse them
157- self ._socket_free = {}
158- self ._open_sockets = {}
159145 self ._sock = None
160146 self ._backwards_compatible_sock = False
161147
@@ -214,93 +200,53 @@ def __init__(
214200 self .on_subscribe = None
215201 self .on_unsubscribe = None
216202
217- # Socket helpers
218- def _free_socket (self , socket ):
219- """Frees a socket for re-use."""
220- if socket not in self ._open_sockets .values ():
221- raise RuntimeError ("Socket not from MQTT client." )
222- self ._socket_free [socket ] = True
223-
224- def _close_socket (self , socket ):
225- """Closes a slocket."""
226- socket .close ()
227- del self ._socket_free [socket ]
228- key = None
229- for k in self ._open_sockets :
230- if self ._open_sockets [k ] == socket :
231- key = k
232- break
233- if key :
234- del self ._open_sockets [key ]
235-
236- def _free_sockets (self ):
237- """Closes all free sockets."""
238- free_sockets = []
239- for sock in self ._socket_free :
240- if self ._socket_free [sock ]:
241- free_sockets .append (sock )
242- for sock in free_sockets :
243- self ._close_socket (sock )
244203
245204 # pylint: disable=too-many-branches
246205 def _get_socket (self , host , port , * , timeout = 1 ):
247- key = (host , port )
248- if key in self ._open_sockets :
249- sock = self ._open_sockets [key ]
250- if self ._socket_free [sock ]:
251- self ._socket_free [sock ] = False
252- return sock
253- if port == 8883 and not self ._ssl_context :
254- raise RuntimeError (
255- "ssl_context must be set before using adafruit_mqtt for secure MQTT."
256- )
206+ # For reconnections - check if we're using a socket already and close it
207+ if self ._sock :
208+ self ._sock .close ()
257209
258210 # Legacy API - use a default socket instead of socket pool
259211 if self ._socket_pool is None :
260- self ._socket_pool = _the_sock
212+ self ._socket_pool = _default_sock
213+
214+ # Legacy API - fake the ssl context
215+ if self ._ssl_context is None :
216+ self ._ssl_context = _fake_context
217+
218+ if port == 8883 and self ._ssl_context is None :
219+ raise RuntimeError (
220+ "ssl_context must be set before using adafruit_mqtt for secure MQTT."
221+ )
261222
262223 addr_info = self ._socket_pool .getaddrinfo (
263224 host , port , 0 , self ._socket_pool .SOCK_STREAM
264225 )[0 ]
226+
265227 retry_count = 0
266228 sock = None
267- while retry_count < 5 and sock is None :
268- if retry_count > 0 :
269- if any (self ._socket_free .items ()):
270- self ._free_sockets ()
271- else :
272- raise RuntimeError ("Sending request failed" )
273- retry_count += 1
274229
275- try :
276- sock = self ._socket_pool .socket (
277- addr_info [0 ], addr_info [1 ], addr_info [2 ]
278- )
279- except OSError :
280- continue
281-
282- connect_host = addr_info [- 1 ][0 ]
283- if port == 8883 :
284- sock = self ._ssl_context .wrap_socket (sock , server_hostname = host )
285- connect_host = host
286- sock .settimeout (timeout )
230+ sock = self ._socket_pool .socket (
231+ addr_info [0 ], addr_info [1 ], addr_info [2 ]
232+ )
287233
288- try :
289- sock .connect ((connect_host , port ))
290- except MemoryError :
291- sock .close ()
292- sock = None
293- except OSError :
294- sock .close ()
295- sock = None
234+ connect_host = addr_info [- 1 ][0 ]
235+ if port == 8883 :
236+ sock = self ._ssl_context .wrap_socket (sock , server_hostname = host )
237+ connect_host = host
238+ sock .settimeout (timeout )
296239
297- if sock is None :
298- raise RuntimeError ("Repeated socket failures" )
240+ try :
241+ sock .connect ((connect_host , port ))
242+ except MemoryError as err :
243+ sock .close ()
244+ raise MemoryError (err )
245+ except OSError as err :
246+ sock .close ()
247+ raise OSError (err )
299248
300249 self ._backwards_compatible_sock = not hasattr (sock , "recv_into" )
301-
302- self ._open_sockets [key ] = sock
303- self ._socket_free [sock ] = False
304250 return sock
305251
306252 def __enter__ (self ):
0 commit comments