Skip to content

Commit b7796e1

Browse files
authored
PYTHON-3807 add types to mongo_client.py (#1315)
1 parent 883d57f commit b7796e1

File tree

9 files changed

+246
-114
lines changed

9 files changed

+246
-114
lines changed

pymongo/aggregation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pymongo.pool import Connection
3333
from pymongo.read_preferences import _ServerMode
3434
from pymongo.server import Server
35-
from pymongo.typings import _Pipeline
35+
from pymongo.typings import _DocumentType, _Pipeline
3636

3737

3838
class _AggregationCommand:
@@ -132,11 +132,11 @@ def get_read_preference(
132132

133133
def get_cursor(
134134
self,
135-
session: ClientSession,
135+
session: Optional[ClientSession],
136136
server: Server,
137137
conn: Connection,
138138
read_preference: _ServerMode,
139-
) -> CommandCursor:
139+
) -> CommandCursor[_DocumentType]:
140140
# Serialize command.
141141
cmd = SON([("aggregate", self._aggregation_target), ("pipeline", self._pipeline)])
142142
cmd.update(self._options)

pymongo/client_session.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@
150150
MutableMapping,
151151
NoReturn,
152152
Optional,
153-
Tuple,
154153
Type,
155154
TypeVar,
156155
)
@@ -180,6 +179,7 @@
180179

181180
from pymongo.pool import Connection
182181
from pymongo.server import Server
182+
from pymongo.typings import _Address
183183

184184

185185
class SessionOptions:
@@ -399,7 +399,7 @@ def __init__(self, opts: Optional[TransactionOptions], client: MongoClient):
399399
self.opts = opts
400400
self.state = _TxnState.NONE
401401
self.sharded = False
402-
self.pinned_address: Optional[Tuple[str, Optional[int]]] = None
402+
self.pinned_address: Optional[_Address] = None
403403
self.conn_mgr: Optional[_ConnectionManager] = None
404404
self.recovery_token = None
405405
self.attempt = 0
@@ -839,7 +839,9 @@ def _finish_transaction_with_retry(self, command_name: str) -> Dict[str, Any]:
839839
- `command_name`: Either "commitTransaction" or "abortTransaction".
840840
"""
841841

842-
def func(session: ClientSession, conn: Connection, retryable: bool) -> Dict[str, Any]:
842+
def func(
843+
session: Optional[ClientSession], conn: Connection, retryable: bool
844+
) -> Dict[str, Any]:
843845
return self._finish_transaction(conn, command_name)
844846

845847
return self._client._retry_internal(True, func, self, None)
@@ -947,7 +949,7 @@ def _starting_transaction(self) -> bool:
947949
return self._transaction.starting()
948950

949951
@property
950-
def _pinned_address(self) -> Optional[Tuple[str, Optional[int]]]:
952+
def _pinned_address(self) -> Optional[_Address]:
951953
"""The mongos address this transaction was created on."""
952954
if self._transaction.active():
953955
return self._transaction.pinned_address
@@ -1043,7 +1045,7 @@ def __copy__(self) -> NoReturn:
10431045
class _EmptyServerSession:
10441046
__slots__ = "dirty", "started_retryable_write"
10451047

1046-
def __init__(self):
1048+
def __init__(self) -> None:
10471049
self.dirty = False
10481050
self.started_retryable_write = False
10491051

pymongo/collection.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,7 @@
7474
_IndexKeyHint,
7575
_IndexList,
7676
)
77-
from pymongo.read_preferences import (
78-
Primary,
79-
PrimaryPreferred,
80-
ReadPreference,
81-
_ServerMode,
82-
)
77+
from pymongo.read_preferences import ReadPreference, _ServerMode
8378
from pymongo.results import (
8479
BulkWriteResult,
8580
DeleteResult,
@@ -264,7 +259,7 @@ def __init__(
264259

265260
def _conn_for_reads(
266261
self, session: ClientSession
267-
) -> ContextManager[Tuple[Connection, Union[PrimaryPreferred, Primary]]]:
262+
) -> ContextManager[Tuple[Connection, _ServerMode]]:
268263
return self.__database.client._conn_for_reads(self._read_preference_for(session), session)
269264

270265
def _conn_for_writes(self, session: Optional[ClientSession]) -> ContextManager[Connection]:
@@ -433,6 +428,24 @@ def database(self) -> Database[_DocumentType]:
433428
"""
434429
return self.__database
435430

431+
# @overload
432+
# def with_options(
433+
# self,
434+
# codec_options: None = None,
435+
# read_preference: Optional[_ServerMode] = None,
436+
# write_concern: Optional[WriteConcern] = None,
437+
# read_concern: Optional[ReadConcern] = None,
438+
# ) -> Collection[Dict[str, Any]]: ...
439+
440+
# @overload
441+
# def with_options(
442+
# self,
443+
# codec_options: bson.CodecOptions[_DocumentType],
444+
# read_preference: Optional[_ServerMode] = None,
445+
# write_concern: Optional[WriteConcern] = None,
446+
# read_concern: Optional[ReadConcern] = None,
447+
# ) -> Collection[_DocumentType]: ...
448+
436449
def with_options(
437450
self,
438451
codec_options: Optional[bson.CodecOptions[_DocumentTypeArg]] = None,
@@ -597,7 +610,7 @@ def _insert_one(
597610
command["comment"] = comment
598611

599612
def _insert_command(
600-
session: ClientSession, conn: Connection, retryable_write: bool
613+
session: Optional[ClientSession], conn: Connection, retryable_write: bool
601614
) -> None:
602615
if bypass_doc_val:
603616
command["bypassDocumentValidation"] = True
@@ -861,7 +874,7 @@ def _update_retryable(
861874
session: Optional[ClientSession] = None,
862875
let: Optional[Mapping[str, Any]] = None,
863876
comment: Optional[Any] = None,
864-
) -> Mapping[str, Any]:
877+
) -> Optional[Mapping[str, Any]]:
865878
"""Internal update / replace helper."""
866879

867880
def _update(
@@ -1737,7 +1750,7 @@ def find_raw_batches(self, *args: Any, **kwargs: Any) -> RawBatchCursor[_Documen
17371750

17381751
def _count_cmd(
17391752
self,
1740-
session: ClientSession,
1753+
session: Optional[ClientSession],
17411754
conn: Connection,
17421755
read_preference: Optional[_ServerMode],
17431756
cmd: Mapping[str, Any],
@@ -1766,7 +1779,7 @@ def _aggregate_one_result(
17661779
read_preference: Optional[_ServerMode],
17671780
cmd: Mapping[str, Any],
17681781
collation: Optional[_CollationIn],
1769-
session: ClientSession,
1782+
session: Optional[ClientSession],
17701783
) -> Optional[Mapping[str, Any]]:
17711784
"""Internal helper to run an aggregate that returns a single result."""
17721785
result = self._command(
@@ -1819,7 +1832,7 @@ def estimated_document_count(self, comment: Optional[Any] = None, **kwargs: Any)
18191832
kwargs["comment"] = comment
18201833

18211834
def _cmd(
1822-
session: ClientSession,
1835+
session: Optional[ClientSession],
18231836
server: Server,
18241837
conn: Connection,
18251838
read_preference: Optional[_ServerMode],
@@ -1908,7 +1921,7 @@ def count_documents(
19081921
cmd.update(kwargs)
19091922

19101923
def _cmd(
1911-
session: ClientSession,
1924+
session: Optional[ClientSession],
19121925
server: Server,
19131926
conn: Connection,
19141927
read_preference: Optional[_ServerMode],
@@ -1922,7 +1935,7 @@ def _cmd(
19221935

19231936
def _retryable_non_cursor_read(
19241937
self,
1925-
func: Callable[[ClientSession, Server, Connection, Optional[_ServerMode]], T],
1938+
func: Callable[[Optional[ClientSession], Server, Connection, Optional[_ServerMode]], T],
19261939
session: Optional[ClientSession],
19271940
) -> T:
19281941
"""Non-cursor read helper to handle implicit session creation."""
@@ -2276,18 +2289,19 @@ def list_indexes(
22762289
.. versionadded:: 3.0
22772290
"""
22782291
codec_options: CodecOptions = CodecOptions(SON)
2279-
coll = self.with_options(
2280-
codec_options=codec_options, read_preference=ReadPreference.PRIMARY
2292+
coll = cast(
2293+
Collection[MutableMapping[str, Any]],
2294+
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),
22812295
)
22822296
read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY
22832297
explicit_session = session is not None
22842298

22852299
def _cmd(
2286-
session: ClientSession,
2300+
session: Optional[ClientSession],
22872301
server: Server,
22882302
conn: Connection,
22892303
read_preference: _ServerMode,
2290-
) -> CommandCursor[_DocumentType]:
2304+
) -> CommandCursor[MutableMapping[str, Any]]:
22912305
cmd = SON([("listIndexes", self.__name), ("cursor", {})])
22922306
if comment is not None:
22932307
cmd["comment"] = comment
@@ -2404,7 +2418,7 @@ def list_search_indexes(
24042418

24052419
return self.__database.client._retryable_read(
24062420
cmd.get_cursor,
2407-
cmd.get_read_preference(session),
2421+
cmd.get_read_preference(session), # type: ignore[arg-type]
24082422
session,
24092423
retryable=not cmd._performs_write,
24102424
)
@@ -2618,7 +2632,7 @@ def _aggregate(
26182632
let: Optional[Mapping[str, Any]] = None,
26192633
comment: Optional[Any] = None,
26202634
**kwargs: Any,
2621-
) -> Union[CommandCursor[_DocumentType], RawBatchCursor[_DocumentType]]:
2635+
) -> CommandCursor[_DocumentType]:
26222636
if comment is not None:
26232637
kwargs["comment"] = comment
26242638
cmd = aggregation_command(
@@ -2633,7 +2647,7 @@ def _aggregate(
26332647

26342648
return self.__database.client._retryable_read(
26352649
cmd.get_cursor,
2636-
cmd.get_read_preference(session),
2650+
cmd.get_read_preference(session), # type: ignore[arg-type]
26372651
session,
26382652
retryable=not cmd._performs_write,
26392653
)
@@ -2724,18 +2738,15 @@ def aggregate(
27242738
https://mongodb.com/docs/manual/reference/command/aggregate
27252739
"""
27262740
with self.__database.client._tmp_session(session, close=False) as s:
2727-
return cast(
2728-
CommandCursor[_DocumentType],
2729-
self._aggregate(
2730-
_CollectionAggregationCommand,
2731-
pipeline,
2732-
CommandCursor,
2733-
session=s,
2734-
explicit_session=session is not None,
2735-
let=let,
2736-
comment=comment,
2737-
**kwargs,
2738-
),
2741+
return self._aggregate(
2742+
_CollectionAggregationCommand,
2743+
pipeline,
2744+
CommandCursor,
2745+
session=s,
2746+
explicit_session=session is not None,
2747+
let=let,
2748+
comment=comment,
2749+
**kwargs,
27392750
)
27402751

27412752
def aggregate_raw_batches(
@@ -3047,7 +3058,7 @@ def distinct(
30473058
cmd["comment"] = comment
30483059

30493060
def _cmd(
3050-
session: ClientSession,
3061+
session: Optional[ClientSession],
30513062
server: Server,
30523063
conn: Connection,
30533064
read_preference: Optional[_ServerMode],
@@ -3112,7 +3123,7 @@ def __find_and_modify(
31123123
write_concern = self._write_concern_for_cmd(cmd, session)
31133124

31143125
def _find_and_modify(
3115-
session: ClientSession, conn: Connection, retryable_write: bool
3126+
session: Optional[ClientSession], conn: Connection, retryable_write: bool
31163127
) -> Any:
31173128
acknowledged = write_concern.acknowledged
31183129
if array_filters is not None:

pymongo/command_cursor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def __send_message(self, operation: _GetMore) -> None:
205205
self.__id = cursor["id"]
206206
else:
207207
documents = response.docs
208+
assert isinstance(response.data, _OpReply)
208209
self.__id = response.data.cursor_id
209210

210211
if self.__id == 0:

pymongo/cursor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,21 @@
5050
from pymongo.message import (
5151
_CursorAddress,
5252
_GetMore,
53+
_OpMsg,
54+
_OpReply,
5355
_Query,
5456
_RawBatchGetMore,
5557
_RawBatchQuery,
5658
)
5759
from pymongo.response import PinnedResponse
58-
from pymongo.typings import _CollationIn, _DocumentType
60+
from pymongo.typings import _Address, _CollationIn, _DocumentType
5961

6062
if TYPE_CHECKING:
6163
from _typeshed import SupportsItems
6264

6365
from bson.codec_options import CodecOptions
6466
from pymongo.client_session import ClientSession
6567
from pymongo.collection import Collection
66-
from pymongo.message import _OpMsg, _OpReply
6768
from pymongo.pool import Connection
6869
from pymongo.read_preferences import _ServerMode
6970

@@ -298,7 +299,7 @@ def __init__(
298299
self.__empty = False
299300

300301
self.__data: deque = deque()
301-
self.__address = None
302+
self.__address: Optional[_Address] = None
302303
self.__retrieved = 0
303304

304305
self.__codec_options = collection.codec_options
@@ -1108,6 +1109,7 @@ def __send_message(self, operation: Union[_Query, _GetMore]) -> None:
11081109
self.__data = deque(docs)
11091110
self.__retrieved += len(docs)
11101111
else:
1112+
assert isinstance(response.data, _OpReply)
11111113
self.__id = response.data.cursor_id
11121114
self.__data = deque(docs)
11131115
self.__retrieved += response.data.number_returned

pymongo/database.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def aggregate(
553553
user_fields={"cursor": {"firstBatch": 1}},
554554
)
555555
return self.client._retryable_read(
556-
cmd.get_cursor, cmd.get_read_preference(s), s, retryable=not cmd._performs_write
556+
cmd.get_cursor, cmd.get_read_preference(s), s, retryable=not cmd._performs_write # type: ignore[arg-type]
557557
)
558558

559559
def watch(
@@ -1033,9 +1033,12 @@ def _list_collections(
10331033
session: Optional[ClientSession],
10341034
read_preference: _ServerMode,
10351035
**kwargs: Any,
1036-
) -> CommandCursor:
1036+
) -> CommandCursor[MutableMapping[str, Any]]:
10371037
"""Internal listCollections helper."""
1038-
coll = self.get_collection("$cmd", read_preference=read_preference)
1038+
coll = cast(
1039+
Collection[MutableMapping[str, Any]],
1040+
self.get_collection("$cmd", read_preference=read_preference),
1041+
)
10391042
cmd = SON([("listCollections", 1), ("cursor", {})])
10401043
cmd.update(kwargs)
10411044
with self.__client._tmp_session(session, close=False) as tmp_session:
@@ -1059,7 +1062,7 @@ def list_collections(
10591062
filter: Optional[Mapping[str, Any]] = None,
10601063
comment: Optional[Any] = None,
10611064
**kwargs: Any,
1062-
) -> CommandCursor[Dict[str, Any]]:
1065+
) -> CommandCursor[MutableMapping[str, Any]]:
10631066
"""Get a cursor over the collections of this database.
10641067
10651068
:Parameters:
@@ -1092,7 +1095,7 @@ def _cmd(
10921095
server: Server,
10931096
conn: Connection,
10941097
read_preference: _ServerMode,
1095-
) -> CommandCursor[_DocumentType]:
1098+
) -> CommandCursor[MutableMapping[str, Any]]:
10961099
return self._list_collections(conn, session, read_preference=read_preference, **kwargs)
10971100

10981101
return self.__client._retryable_read(_cmd, read_pref, session)

0 commit comments

Comments
 (0)