8686
8787__all__ = [
8888 "extend_schema" ,
89+ "extend_schema_impl" ,
8990 "get_description" ,
9091 "ASTDefinitionBuilder" ,
9192]
@@ -125,6 +126,18 @@ def extend_schema(
125126
126127 assert_valid_sdl_extension (document_ast , schema )
127128
129+ schema_kwargs = schema .to_kwargs ()
130+ extended_kwargs = extend_schema_impl (schema_kwargs , document_ast , assume_valid )
131+ return (
132+ schema if schema_kwargs is extended_kwargs else GraphQLSchema (** extended_kwargs )
133+ )
134+
135+
136+ def extend_schema_impl (
137+ schema_kwargs : Dict [str , Any ], document_ast : DocumentNode , assume_valid = False
138+ ) -> Dict [str , Any ]:
139+ # Note: schema_kwargs should become a TypedDict once we require Python 3.8
140+
128141 # Collect the type definitions and extensions found in the document.
129142 type_defs : List [TypeDefinitionNode ] = []
130143 type_extensions_map : DefaultDict [str , Any ] = defaultdict (list )
@@ -159,7 +172,7 @@ def extend_schema(
159172 and not schema_extensions
160173 and not schema_def
161174 ):
162- return schema
175+ return schema_kwargs
163176
164177 # Below are functions used for producing this schema that have closed over this
165178 # scope and have access to the schema, cache, and newly defined types.
@@ -359,23 +372,15 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
359372 ast_builder = ASTDefinitionBuilder (resolve_type )
360373
361374 type_map = ast_builder .build_type_map (type_defs , type_extensions_map )
362- for existing_type_name , existing_type in schema . type_map . items () :
363- type_map [existing_type_name ] = extend_named_type (existing_type )
375+ for existing_type in schema_kwargs [ "types" ] :
376+ type_map [existing_type . name ] = extend_named_type (existing_type )
364377
365378 # Get the extended root operation types.
366- operation_types : Dict [OperationType , GraphQLObjectType ] = {}
367- if schema .query_type :
368- operation_types [OperationType .QUERY ] = cast (
369- GraphQLObjectType , replace_named_type (schema .query_type )
370- )
371- if schema .mutation_type :
372- operation_types [OperationType .MUTATION ] = cast (
373- GraphQLObjectType , replace_named_type (schema .mutation_type )
374- )
375- if schema .subscription_type :
376- operation_types [OperationType .SUBSCRIPTION ] = cast (
377- GraphQLObjectType , replace_named_type (schema .subscription_type )
378- )
379+ operation_types : Dict [OperationType , GraphQLNamedType ] = {}
380+ for operation_type in OperationType :
381+ original_type = schema_kwargs [operation_type .value ]
382+ if original_type :
383+ operation_types [operation_type ] = replace_named_type (original_type )
379384 # Then, incorporate schema definition and all schema extensions.
380385 if schema_def :
381386 operation_types .update (ast_builder .get_operation_types ([schema_def ]))
@@ -384,26 +389,27 @@ def resolve_type(type_name: str) -> GraphQLNamedType:
384389
385390 # Then produce and return a Schema with these types.
386391 get_operation = operation_types .get
387- return GraphQLSchema (
388- # Note: While this could make early assertions to get the correctly
389- # typed values, that would throw immediately while type system
390- # validation with validateSchema() will produce more actionable results.
391- query = get_operation (OperationType .QUERY ),
392- mutation = get_operation (OperationType .MUTATION ),
393- subscription = get_operation (OperationType .SUBSCRIPTION ),
394- types = type_map .values (),
395- directives = [replace_directive (directive ) for directive in schema .directives ]
392+ return {
393+ "query" : get_operation (OperationType .QUERY ),
394+ "mutation" : get_operation (OperationType .MUTATION ),
395+ "subscription" : get_operation (OperationType .SUBSCRIPTION ),
396+ "types" : type_map .values (),
397+ "directives" : [
398+ replace_directive (directive ) for directive in schema_kwargs ["directives" ]
399+ ]
396400 + ast_builder .build_directives (directive_defs ),
397- ast_node = schema_def or schema .ast_node ,
398- extension_ast_nodes = (
401+ "extensions" : None ,
402+ "ast_node" : schema_def or schema_kwargs ["ast_node" ],
403+ "extension_ast_nodes" : (
399404 (
400- schema . extension_ast_nodes
405+ schema_kwargs [ " extension_ast_nodes" ]
401406 or cast (FrozenList [SchemaExtensionNode ], FrozenList ())
402407 )
403408 + schema_extensions
404409 )
405410 or None ,
406- )
411+ "assume_valid" : assume_valid ,
412+ }
407413
408414
409415def default_type_resolver (type_name : str , * _args ) -> NoReturn :
@@ -427,15 +433,14 @@ def get_operation_types(
427433 # Note: While this could make early assertions to get the correctly
428434 # typed values below, that would throw immediately while type system
429435 # validation with validate_schema() will produce more actionable results.
430- op_types : Dict [OperationType , GraphQLObjectType ] = {}
431- for node in nodes :
432- if node .operation_types :
433- for operation_type in node .operation_types :
434- type_name = operation_type .type .name .value
435- op_types [operation_type .operation ] = cast (
436- GraphQLObjectType , self ._resolve_type (type_name )
437- )
438- return op_types
436+ return {
437+ operation_type .operation : cast (
438+ GraphQLObjectType , self ._resolve_type (operation_type .type .name .value )
439+ )
440+ for node in nodes
441+ if node .operation_types
442+ for operation_type in node .operation_types
443+ }
439444
440445 def get_named_type (self , node : NamedTypeNode ) -> GraphQLNamedType :
441446 name = node .name .value
0 commit comments