11package com .introproventures .graphql .jpa .query .autoconfigure ;
22
3+ import static graphql .Assert .assertTrue ;
4+ import static graphql .schema .FieldCoordinates .coordinates ;
5+ import static graphql .util .TraversalControl .CONTINUE ;
6+
7+ import java .util .Collection ;
38import java .util .List ;
49import java .util .Objects ;
10+ import java .util .Optional ;
511import java .util .stream .Collectors ;
612import java .util .stream .Stream ;
713
814import org .springframework .beans .factory .config .AbstractFactoryBean ;
915
16+ import graphql .Internal ;
17+ import graphql .schema .DataFetcher ;
18+ import graphql .schema .FieldCoordinates ;
19+ import graphql .schema .GraphQLCodeRegistry ;
1020import graphql .schema .GraphQLFieldDefinition ;
21+ import graphql .schema .GraphQLFieldsContainer ;
22+ import graphql .schema .GraphQLInterfaceType ;
1123import graphql .schema .GraphQLObjectType ;
1224import graphql .schema .GraphQLSchema ;
25+ import graphql .schema .GraphQLType ;
26+ import graphql .schema .GraphQLTypeVisitorStub ;
27+ import graphql .schema .GraphQLUnionType ;
28+ import graphql .schema .PropertyDataFetcher ;
29+ import graphql .schema .TypeResolver ;
30+ import graphql .schema .TypeTraverser ;
31+ import graphql .util .TraversalControl ;
32+ import graphql .util .TraverserContext ;
1333
1434public class GraphQLSchemaFactoryBean extends AbstractFactoryBean <GraphQLSchema >{
15-
16- private static final String QUERY_NAME = "Query" ;
35+
36+ private static final String QUERY_NAME = "Query" ;
1737 private static final String QUERY_DESCRIPTION = "" ;
1838 private static final String SUBSCRIPTION_NAME = "Subscription" ;
1939 private static final String SUBSCRIPTION_DESCRIPTION = "" ;
@@ -22,85 +42,118 @@ public class GraphQLSchemaFactoryBean extends AbstractFactoryBean<GraphQLSchema>
2242
2343
2444 private final GraphQLSchema [] managedGraphQLSchemas ;
25-
26- private String queryName = QUERY_NAME ;
27- private String queryDescription = QUERY_DESCRIPTION ;
45+
46+ private String queryName = QUERY_NAME ;
47+ private String queryDescription = QUERY_DESCRIPTION ;
2848
2949 private String subscriptionName = SUBSCRIPTION_NAME ;
3050 private String subscriptionDescription = SUBSCRIPTION_DESCRIPTION ;
3151
3252 private String mutationName = MUTATION_NAME ;
3353 private String mutationDescription = MUTATION_DESCRIPTION ;
3454
35-
36- public GraphQLSchemaFactoryBean (GraphQLSchema [] managedGraphQLSchemas ) {
37- this .managedGraphQLSchemas = managedGraphQLSchemas ;
38- }
39-
40- @ Override
41- protected GraphQLSchema createInstance () throws Exception {
42-
43- GraphQLSchema .Builder schemaBuilder = GraphQLSchema .newSchema ();
44-
45- List <GraphQLFieldDefinition > mutations = Stream .of (managedGraphQLSchemas )
46- .map (GraphQLSchema ::getMutationType )
47- .filter (Objects ::nonNull )
48- .map (GraphQLObjectType ::getFieldDefinitions )
49- .flatMap (children -> children .stream ())
50- .collect (Collectors .toList ());
51-
52- List <GraphQLFieldDefinition > queries = Stream .of (managedGraphQLSchemas )
53- .map (GraphQLSchema ::getQueryType )
54- .filter (Objects ::nonNull )
55- .filter (it -> !it .getName ().equals ("null" )) // filter out null placeholders
56- .map (GraphQLObjectType ::getFieldDefinitions )
57- .flatMap (children -> children .stream ())
58- .collect (Collectors .toList ());
59-
60- List <GraphQLFieldDefinition > subscriptions = Stream .of (managedGraphQLSchemas )
61- .map (GraphQLSchema ::getSubscriptionType )
62- .filter (Objects ::nonNull )
63- .map (GraphQLObjectType ::getFieldDefinitions )
64- .flatMap (children -> children .stream ())
65- .collect (Collectors .toList ());
66-
67- if (!mutations .isEmpty ())
68- schemaBuilder .mutation (GraphQLObjectType .newObject ()
55+
56+ public GraphQLSchemaFactoryBean (GraphQLSchema [] managedGraphQLSchemas ) {
57+ this .managedGraphQLSchemas = managedGraphQLSchemas ;
58+ }
59+
60+ @ Override
61+ protected GraphQLSchema createInstance () throws Exception {
62+
63+ GraphQLSchema .Builder schemaBuilder = GraphQLSchema .newSchema ();
64+ GraphQLCodeRegistry .Builder codeRegistryBuilder = GraphQLCodeRegistry .newCodeRegistry ();
65+ TypeTraverser typeTraverser = new TypeTraverser ();
66+
67+ List <GraphQLFieldDefinition > mutations = Stream .of (managedGraphQLSchemas )
68+ .filter (it -> it .getMutationType () != null )
69+ .peek (schema -> {
70+ schema .getCodeRegistry ().transform (builderConsumer -> {
71+ typeTraverser .depthFirst (new CodeRegistryVisitor (builderConsumer ,
72+ codeRegistryBuilder ,
73+ schema .getMutationType (),
74+ mutationName ),
75+ schema .getMutationType ());
76+ });
77+ })
78+ .map (GraphQLSchema ::getMutationType )
79+ .filter (Objects ::nonNull )
80+ .map (GraphQLObjectType ::getFieldDefinitions )
81+ .flatMap (Collection ::stream )
82+ .collect (Collectors .toList ());
83+
84+ List <GraphQLFieldDefinition > queries = Stream .of (managedGraphQLSchemas )
85+ .filter (it -> Optional .ofNullable (it .getQueryType ())
86+ .map (GraphQLType ::getName )
87+ .filter (name -> !"null" .equals (name )) // filter out null placeholders
88+ .isPresent ())
89+ .peek (schema -> {
90+ schema .getCodeRegistry ().transform (builderConsumer -> {
91+ typeTraverser .depthFirst (new CodeRegistryVisitor (builderConsumer ,
92+ codeRegistryBuilder ,
93+ schema .getQueryType (),
94+ queryName ),
95+ schema .getQueryType ());
96+ });
97+ })
98+ .map (GraphQLSchema ::getQueryType )
99+ .map (GraphQLObjectType ::getFieldDefinitions )
100+ .flatMap (Collection ::stream )
101+ .collect (Collectors .toList ());
102+
103+ List <GraphQLFieldDefinition > subscriptions = Stream .of (managedGraphQLSchemas )
104+ .filter (it -> it .getSubscriptionType () != null )
105+ .peek (schema -> {
106+ schema .getCodeRegistry ().transform (builderConsumer -> {
107+ typeTraverser .depthFirst (new CodeRegistryVisitor (builderConsumer ,
108+ codeRegistryBuilder ,
109+ schema .getSubscriptionType (),
110+ subscriptionName ),
111+ schema .getSubscriptionType ());
112+ });
113+ })
114+ .map (GraphQLSchema ::getSubscriptionType )
115+ .map (GraphQLObjectType ::getFieldDefinitions )
116+ .flatMap (Collection ::stream )
117+ .collect (Collectors .toList ());
118+
119+ if (!mutations .isEmpty ())
120+ schemaBuilder .mutation (GraphQLObjectType .newObject ()
69121 .name (this .mutationName )
70122 .description (this .mutationDescription )
71- .fields (mutations ));
123+ .fields (mutations ));
72124
73- if (!queries .isEmpty ())
74- schemaBuilder .query (GraphQLObjectType .newObject ()
75- .name (this .queryName )
76- .description (this .queryDescription )
77- .fields (queries ));
125+ if (!queries .isEmpty ())
126+ schemaBuilder .query (GraphQLObjectType .newObject ()
127+ .name (this .queryName )
128+ .description (this .queryDescription )
129+ .fields (queries ));
78130
79- if (!subscriptions .isEmpty ())
80- schemaBuilder .subscription (GraphQLObjectType .newObject ()
131+ if (!subscriptions .isEmpty ())
132+ schemaBuilder .subscription (GraphQLObjectType .newObject ()
81133 .name (this .subscriptionName )
82134 .description (this .subscriptionDescription )
83- .fields (subscriptions ));
84-
85- return schemaBuilder .build ();
86- }
87-
88- @ Override
89- public Class <?> getObjectType () {
90- return GraphQLSchema .class ;
91- }
92-
93- public GraphQLSchemaFactoryBean setQueryName (String name ) {
94- this .queryName = name ;
95-
96- return this ;
97- }
98-
99- public GraphQLSchemaFactoryBean setQueryDescription (String description ) {
100- this .queryDescription = description ;
101-
102- return this ;
103- }
135+ .fields (subscriptions ));
136+
137+ return schemaBuilder .codeRegistry (codeRegistryBuilder .build ())
138+ .build ();
139+ }
140+
141+ @ Override
142+ public Class <?> getObjectType () {
143+ return GraphQLSchema .class ;
144+ }
145+
146+ public GraphQLSchemaFactoryBean setQueryName (String name ) {
147+ this .queryName = name ;
148+
149+ return this ;
150+ }
151+
152+ public GraphQLSchemaFactoryBean setQueryDescription (String description ) {
153+ this .queryDescription = description ;
154+
155+ return this ;
156+ }
104157
105158 public GraphQLSchemaFactoryBean setSubscriptionName (String subscriptionName ) {
106159 this .subscriptionName = subscriptionName ;
@@ -125,5 +178,65 @@ public GraphQLSchemaFactoryBean setMutationDescription(String mutationDescriptio
125178
126179 return this ;
127180 }
128-
181+
182+ /**
183+ * This ensure that all fields have data fetchers and that unions and interfaces have type resolvers
184+ */
185+ @ Internal
186+ class CodeRegistryVisitor extends GraphQLTypeVisitorStub {
187+ private final GraphQLCodeRegistry .Builder source ;
188+ private final GraphQLCodeRegistry .Builder codeRegistry ;
189+ private final GraphQLFieldsContainer containerType ;
190+ private final String typeName ;
191+
192+ CodeRegistryVisitor (GraphQLCodeRegistry .Builder context ,
193+ GraphQLCodeRegistry .Builder codeRegistry ,
194+ GraphQLFieldsContainer containerType ,
195+ String typeName ) {
196+ this .source = context ;
197+ this .codeRegistry = codeRegistry ;
198+ this .containerType = containerType ;
199+ this .typeName = typeName ;
200+ }
201+
202+ @ Override
203+ public TraversalControl visitGraphQLFieldDefinition (GraphQLFieldDefinition node , TraverserContext <GraphQLType > context ) {
204+ GraphQLFieldsContainer parentContainerType = (GraphQLFieldsContainer ) context .getParentContext ().thisNode ();
205+ FieldCoordinates coordinates = parentContainerType .equals (containerType ) ? coordinates (typeName , node .getName ())
206+ : coordinates (parentContainerType , node );
207+
208+ DataFetcher <?> dataFetcher = source .getDataFetcher (parentContainerType ,
209+ node );
210+ if (dataFetcher == null ) {
211+ dataFetcher = new PropertyDataFetcher <>(node .getName ());
212+ }
213+
214+ codeRegistry .dataFetcherIfAbsent (coordinates ,
215+ dataFetcher );
216+ return CONTINUE ;
217+ }
218+
219+ @ Override
220+ public TraversalControl visitGraphQLInterfaceType (GraphQLInterfaceType node , TraverserContext <GraphQLType > context ) {
221+ TypeResolver typeResolver = codeRegistry .getTypeResolver (node );
222+
223+ if (typeResolver != null ) {
224+ codeRegistry .typeResolverIfAbsent (node , typeResolver );
225+ }
226+ assertTrue (codeRegistry .getTypeResolver (node ) != null , "You MUST provide a type resolver for the interface type '" + node .getName () + "'" );
227+ return CONTINUE ;
228+ }
229+
230+ @ Override
231+ public TraversalControl visitGraphQLUnionType (GraphQLUnionType node , TraverserContext <GraphQLType > context ) {
232+ TypeResolver typeResolver = codeRegistry .getTypeResolver (node );
233+ if (typeResolver != null ) {
234+ codeRegistry .typeResolverIfAbsent (node , typeResolver );
235+ }
236+ assertTrue (codeRegistry .getTypeResolver (node ) != null , "You MUST provide a type resolver for the union type '" + node .getName () + "'" );
237+ return CONTINUE ;
238+ }
239+ }
240+
241+
129242}
0 commit comments