3737from select import select
3838from socket import create_connection , SHUT_RDWR , error as SocketError
3939from struct import pack as struct_pack , unpack as struct_unpack , unpack_from as struct_unpack_from
40- from threading import Lock
40+ from threading import RLock
4141
4242from .constants import DEFAULT_USER_AGENT , KNOWN_HOSTS , MAGIC_PREAMBLE , TRUST_DEFAULT , TRUST_ON_FIRST_USE
4343from .exceptions import ProtocolError , Unauthorized , ServiceUnavailable
@@ -378,15 +378,26 @@ class ConnectionPool(object):
378378 """ A collection of connections to one or more server addresses.
379379 """
380380
381+ closed = False
382+
381383 def __init__ (self , connector ):
382384 self .connector = connector
383385 self .connections = {}
384- self .lock = Lock ()
386+ self .lock = RLock ()
387+
388+ def __enter__ (self ):
389+ return self
390+
391+ def __exit__ (self , exc_type , exc_value , traceback ):
392+ self .close ()
385393
386394 def acquire (self , address ):
387395 """ Acquire a connection to a given address from the pool.
388396 This method is thread safe.
389397 """
398+ if self .closed :
399+ raise ServiceUnavailable ("This connection pool is closed so no new "
400+ "connections may be acquired" )
390401 with self .lock :
391402 try :
392403 connections = self .connections [address ]
@@ -411,18 +422,25 @@ def release(self, connection):
411422 with self .lock :
412423 connection .in_use = False
413424
425+ def remove (self , address ):
426+ """ Remove an address from the connection pool, if present, closing
427+ all connections to that address.
428+ """
429+ with self .lock :
430+ for connection in self .connections .pop (address , ()):
431+ try :
432+ connection .close ()
433+ except IOError :
434+ pass
435+
414436 def close (self ):
415437 """ Close all connections and empty the pool.
416438 This method is thread safe.
417439 """
418440 with self .lock :
419- for _ , connections in self .connections .items ():
420- for connection in connections :
421- try :
422- connection .close ()
423- except IOError :
424- pass
425- self .connections .clear ()
441+ self .closed = True
442+ for address in list (self .connections ):
443+ self .remove (address )
426444
427445
428446class CertificateStore (object ):
0 commit comments