@@ -458,6 +458,16 @@ def metadata(self):
458458 return self .__metadata .copy ()
459459
460460
461+ def _negotiate_creds (all_credentials ):
462+ """Return one credential that needs mechanism negotiation, if any.
463+ """
464+ if all_credentials :
465+ for creds in all_credentials .values ():
466+ if creds .mechanism == 'DEFAULT' and creds .username :
467+ return creds
468+ return None
469+
470+
461471class SocketInfo (object ):
462472 """Store a socket with some metadata.
463473
@@ -488,13 +498,16 @@ def __init__(self, sock, pool, address, id):
488498 self .compression_settings = pool .opts .compression_settings
489499 self .compression_context = None
490500 self .socket_checker = SocketChecker ()
501+ # Support for mechanism negotiation on the initial handshake.
502+ # Maps credential to saslSupportedMechs.
503+ self .negotiated_mechanisms = {}
491504
492505 # The pool's generation changes with each reset() so we can close
493506 # sockets created before the last reset.
494507 self .generation = pool .generation
495508 self .ready = False
496509
497- def ismaster (self , metadata , cluster_time ):
510+ def ismaster (self , metadata , cluster_time , all_credentials = None ):
498511 cmd = SON ([('ismaster' , 1 )])
499512 if not self .performed_handshake :
500513 cmd ['client' ] = metadata
@@ -504,6 +517,12 @@ def ismaster(self, metadata, cluster_time):
504517 if self .max_wire_version >= 6 and cluster_time is not None :
505518 cmd ['$clusterTime' ] = cluster_time
506519
520+ # XXX: Simplify in PyMongo 4.0 when all_credentials is always a single
521+ # unchangeable value per MongoClient.
522+ creds = _negotiate_creds (all_credentials )
523+ if creds :
524+ cmd ['saslSupportedMechs' ] = creds .source + '.' + creds .username
525+
507526 ismaster = IsMaster (self .command ('admin' , cmd , publish_events = False ))
508527 self .is_writable = ismaster .is_writable
509528 self .max_wire_version = ismaster .max_wire_version
@@ -520,6 +539,8 @@ def ismaster(self, metadata, cluster_time):
520539
521540 self .performed_handshake = True
522541 self .op_msg_enabled = ismaster .max_wire_version >= 6
542+ if creds :
543+ self .negotiated_mechanisms [creds ] = ismaster .sasl_supported_mechs
523544 return ismaster
524545
525546 def command (self , dbname , spec , slave_ok = False ,
@@ -701,8 +722,7 @@ def check_auth(self, all_credentials):
701722 self .authset .discard (credentials )
702723
703724 for credentials in cached - authset :
704- auth .authenticate (credentials , self )
705- self .authset .add (credentials )
725+ self .authenticate (credentials )
706726
707727 # CMAP spec says to publish the ready event only after authenticating
708728 # the connection.
@@ -721,6 +741,8 @@ def authenticate(self, credentials):
721741 """
722742 auth .authenticate (credentials , self )
723743 self .authset .add (credentials )
744+ # negotiated_mechanisms are no longer needed.
745+ self .negotiated_mechanisms .pop (credentials , None )
724746
725747 def validate_session (self , client , session ):
726748 """Validate this session before use with client.
@@ -1026,7 +1048,7 @@ def reset(self):
10261048 def close (self ):
10271049 self ._reset (close = True )
10281050
1029- def remove_stale_sockets (self , reference_generation ):
1051+ def remove_stale_sockets (self , reference_generation , all_credentials ):
10301052 """Removes stale sockets then adds new ones if pool is too small and
10311053 has not been reset. The `reference_generation` argument specifies the
10321054 `generation` at the point in time this operation was requested on the
@@ -1050,7 +1072,7 @@ def remove_stale_sockets(self, reference_generation):
10501072 if not self ._socket_semaphore .acquire (False ):
10511073 break
10521074 try :
1053- sock_info = self .connect ()
1075+ sock_info = self .connect (all_credentials )
10541076 with self .lock :
10551077 # Close connection and return if the pool was reset during
10561078 # socket creation or while acquiring the pool lock.
@@ -1061,7 +1083,7 @@ def remove_stale_sockets(self, reference_generation):
10611083 finally :
10621084 self ._socket_semaphore .release ()
10631085
1064- def connect (self ):
1086+ def connect (self , all_credentials = None ):
10651087 """Connect to Mongo and return a new SocketInfo.
10661088
10671089 Can raise ConnectionFailure or CertificateError.
@@ -1081,9 +1103,6 @@ def connect(self):
10811103 try :
10821104 sock = _configured_socket (self .address , self .opts )
10831105 except socket .error as error :
1084- if sock is not None :
1085- sock .close ()
1086-
10871106 if self .enabled_for_cmap :
10881107 listeners .publish_connection_closed (
10891108 self .address , conn_id , ConnectionClosedReason .ERROR )
@@ -1092,7 +1111,7 @@ def connect(self):
10921111
10931112 sock_info = SocketInfo (sock , self , self .address , conn_id )
10941113 if self .handshake :
1095- sock_info .ismaster (self .opts .metadata , None )
1114+ sock_info .ismaster (self .opts .metadata , None , all_credentials )
10961115 self .is_writable = sock_info .is_writable
10971116
10981117 return sock_info
@@ -1123,29 +1142,23 @@ def get_socket(self, all_credentials, checkout=False):
11231142 listeners = self .opts .event_listeners
11241143 if self .enabled_for_cmap :
11251144 listeners .publish_connection_check_out_started (self .address )
1126- # First get a socket, then attempt authentication. Simplifies
1127- # semaphore management in the face of network errors during auth.
1128- sock_info = self ._get_socket_no_auth ()
1129- checked_auth = False
1145+
1146+ sock_info = self ._get_socket (all_credentials )
1147+
1148+ if self .enabled_for_cmap :
1149+ listeners .publish_connection_checked_out (
1150+ self .address , sock_info .id )
11301151 try :
1131- sock_info .check_auth (all_credentials )
1132- checked_auth = True
1133- if self .enabled_for_cmap :
1134- listeners .publish_connection_checked_out (
1135- self .address , sock_info .id )
11361152 yield sock_info
11371153 except :
11381154 # Exception in caller. Decrement semaphore.
1139- self .return_socket (sock_info , publish_checkin = checked_auth )
1140- if self .enabled_for_cmap and not checked_auth :
1141- self .opts .event_listeners .publish_connection_check_out_failed (
1142- self .address , ConnectionCheckOutFailedReason .CONN_ERROR )
1155+ self .return_socket (sock_info )
11431156 raise
11441157 else :
11451158 if not checkout :
11461159 self .return_socket (sock_info )
11471160
1148- def _get_socket_no_auth (self ):
1161+ def _get_socket (self , all_credentials ):
11491162 """Get or create a SocketInfo. Can raise ConnectionFailure."""
11501163 # We use the pid here to avoid issues with fork / multiprocessing.
11511164 # See test.test_client:TestClient.test_fork for an example of
@@ -1177,10 +1190,11 @@ def _get_socket_no_auth(self):
11771190 sock_info = self .sockets .popleft ()
11781191 except IndexError :
11791192 # Can raise ConnectionFailure or CertificateError.
1180- sock_info = self .connect ()
1193+ sock_info = self .connect (all_credentials )
11811194 else :
11821195 if self ._perished (sock_info ):
11831196 sock_info = None
1197+ sock_info .check_auth (all_credentials )
11841198 except Exception :
11851199 self ._socket_semaphore .release ()
11861200 with self .lock :
@@ -1193,16 +1207,14 @@ def _get_socket_no_auth(self):
11931207
11941208 return sock_info
11951209
1196- def return_socket (self , sock_info , publish_checkin = True ):
1210+ def return_socket (self , sock_info ):
11971211 """Return the socket to the pool, or if it's closed discard it.
11981212
11991213 :Parameters:
12001214 - `sock_info`: The socket to check into the pool.
1201- - `publish_checkin`: If False, a ConnectionCheckedInEvent will not
1202- be published.
12031215 """
12041216 listeners = self .opts .event_listeners
1205- if self .enabled_for_cmap and publish_checkin :
1217+ if self .enabled_for_cmap :
12061218 listeners .publish_connection_checked_in (self .address , sock_info .id )
12071219 if self .pid != os .getpid ():
12081220 self .reset ()
0 commit comments