1818
1919import java .io .ByteArrayInputStream ;
2020import java .io .ByteArrayOutputStream ;
21+ import java .io .Closeable ;
2122import java .io .IOException ;
2223import java .io .InputStream ;
2324import java .io .OutputStream ;
2425import java .net .URI ;
2526import java .time .Duration ;
2627import java .util .Arrays ;
2728import java .util .Collections ;
29+ import java .util .LinkedHashMap ;
2830import java .util .List ;
2931import java .util .Map ;
3032import java .util .concurrent .ConcurrentHashMap ;
4345import reactor .core .scheduler .Scheduler ;
4446import reactor .core .scheduler .Schedulers ;
4547
48+ import org .springframework .graphql .execution .ThreadLocalAccessor ;
4649import org .springframework .graphql .server .WebGraphQlHandler ;
4750import org .springframework .graphql .server .WebGraphQlRequest ;
4851import org .springframework .graphql .server .WebGraphQlResponse ;
5356import org .springframework .http .HttpOutputMessage ;
5457import org .springframework .http .converter .GenericHttpMessageConverter ;
5558import org .springframework .http .converter .HttpMessageConverter ;
59+ import org .springframework .http .server .ServerHttpRequest ;
60+ import org .springframework .http .server .ServerHttpResponse ;
5661import org .springframework .lang .Nullable ;
5762import org .springframework .util .Assert ;
5863import org .springframework .util .CollectionUtils ;
5964import org .springframework .web .socket .CloseStatus ;
6065import org .springframework .web .socket .SubProtocolCapable ;
6166import org .springframework .web .socket .TextMessage ;
67+ import org .springframework .web .socket .WebSocketHandler ;
6268import org .springframework .web .socket .WebSocketSession ;
6369import org .springframework .web .socket .handler .ExceptionWebSocketHandlerDecorator ;
6470import org .springframework .web .socket .handler .TextWebSocketHandler ;
71+ import org .springframework .web .socket .server .HandshakeHandler ;
72+ import org .springframework .web .socket .server .HandshakeInterceptor ;
73+ import org .springframework .web .socket .server .support .WebSocketHttpRequestHandler ;
6574
6675/**
6776 * WebSocketHandler for GraphQL based on
@@ -81,7 +90,9 @@ public class GraphQlWebSocketHandler extends TextWebSocketHandler implements Sub
8190
8291 private final WebGraphQlHandler graphQlHandler ;
8392
84- private final WebSocketGraphQlInterceptor webSocketInterceptor ;
93+ private final ContextHandshakeInterceptor contextHandshakeInterceptor ;
94+
95+ private final WebSocketGraphQlInterceptor webSocketGraphQlInterceptor ;
8596
8697 private final Duration initTimeoutDuration ;
8798
@@ -103,7 +114,8 @@ public GraphQlWebSocketHandler(
103114 Assert .notNull (converter , "HttpMessageConverter for JSON is required" );
104115
105116 this .graphQlHandler = graphQlHandler ;
106- this .webSocketInterceptor = this .graphQlHandler .webSocketInterceptor ();
117+ this .contextHandshakeInterceptor = new ContextHandshakeInterceptor (graphQlHandler .getThreadLocalAccessor ());
118+ this .webSocketGraphQlInterceptor = this .graphQlHandler .getWebSocketInterceptor ();
107119 this .initTimeoutDuration = connectionInitTimeout ;
108120 this .converter = converter ;
109121 }
@@ -113,6 +125,18 @@ public List<String> getSubProtocols() {
113125 return SUB_PROTOCOL_LIST ;
114126 }
115127
128+ /**
129+ * Return a {@link WebSocketHttpRequestHandler} that uses this instance as
130+ * its {@link WebGraphQlHandler} and adds a {@link HandshakeInterceptor} to
131+ * propagate context.
132+ */
133+ public WebSocketHttpRequestHandler asWebSocketHttpRequestHandler (HandshakeHandler handshakeHandler ) {
134+ WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler (this , handshakeHandler );
135+ handler .setHandshakeInterceptors (Collections .singletonList (this .contextHandshakeInterceptor ));
136+ return handler ;
137+ }
138+
139+
116140 @ Override
117141 public void afterConnectionEstablished (WebSocketSession session ) {
118142 if ("graphql-ws" .equalsIgnoreCase (session .getAcceptedProtocol ())) {
@@ -137,8 +161,15 @@ public void afterConnectionEstablished(WebSocketSession session) {
137161
138162 }
139163
164+ @ SuppressWarnings ({"unused" , "try" })
140165 @ Override
141166 protected void handleTextMessage (WebSocketSession session , TextMessage webSocketMessage ) throws Exception {
167+ try (Closeable closeable = this .contextHandshakeInterceptor .restoreThreadLocalValue (session )) {
168+ handleInternal (session , webSocketMessage );
169+ }
170+ }
171+
172+ private void handleInternal (WebSocketSession session , TextMessage webSocketMessage ) throws IOException {
142173 GraphQlWebSocketMessage message = decode (webSocketMessage );
143174 String id = message .getId ();
144175 Map <String , Object > payload = message .getPayload ();
@@ -174,7 +205,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket
174205 if (subscription != null ) {
175206 subscription .cancel ();
176207 }
177- this .webSocketInterceptor .handleCancelledSubscription (session .getId (), id )
208+ this .webSocketGraphQlInterceptor .handleCancelledSubscription (session .getId (), id )
178209 .block (Duration .ofSeconds (10 ));
179210 }
180211 return ;
@@ -183,7 +214,7 @@ protected void handleTextMessage(WebSocketSession session, TextMessage webSocket
183214 GraphQlStatus .closeSession (session , GraphQlStatus .TOO_MANY_INIT_REQUESTS_STATUS );
184215 return ;
185216 }
186- this .webSocketInterceptor .handleConnectionInitialization (session .getId (), payload )
217+ this .webSocketGraphQlInterceptor .handleConnectionInitialization (session .getId (), payload )
187218 .defaultIfEmpty (Collections .emptyMap ())
188219 .publishOn (sessionState .getScheduler ()) // Serial blocking send via single thread
189220 .doOnNext (ackPayload -> {
@@ -285,7 +316,7 @@ public void afterConnectionClosed(WebSocketSession session, CloseStatus closeSta
285316 info .dispose ();
286317 Map <String , Object > connectionInitPayload = info .getConnectionInitPayload ();
287318 if (connectionInitPayload != null ) {
288- this .webSocketInterceptor .handleConnectionClosed (id , closeStatus .getCode (), connectionInitPayload );
319+ this .webSocketGraphQlInterceptor .handleConnectionClosed (id , closeStatus .getCode (), connectionInitPayload );
289320 }
290321 }
291322 }
@@ -296,6 +327,57 @@ public boolean supportsPartialMessages() {
296327 }
297328
298329
330+ /**
331+ * {@code HandshakeInterceptor} that propagates ThreadLocal context through
332+ * the attributes map in {@code WebSocketSession}.
333+ */
334+ private static class ContextHandshakeInterceptor implements HandshakeInterceptor {
335+
336+ private static final String SAVED_CONTEXT_KEY = ContextHandshakeInterceptor .class .getName ();
337+
338+ @ Nullable
339+ private final ThreadLocalAccessor accessor ;
340+
341+ ContextHandshakeInterceptor (@ Nullable ThreadLocalAccessor accessor ) {
342+ this .accessor = accessor ;
343+ }
344+
345+ @ Override
346+ public boolean beforeHandshake (
347+ ServerHttpRequest request , ServerHttpResponse response , WebSocketHandler wsHandler ,
348+ Map <String , Object > attributes ) {
349+
350+ if (this .accessor != null ) {
351+ Map <String , Object > valuesMap = new LinkedHashMap <>();
352+ this .accessor .extractValues (valuesMap );
353+ attributes .put (SAVED_CONTEXT_KEY , valuesMap );
354+ }
355+ return true ;
356+ }
357+
358+ @ Override
359+ public void afterHandshake (
360+ ServerHttpRequest request , ServerHttpResponse response , WebSocketHandler wsHandler ,
361+ @ Nullable Exception exception ) {
362+ }
363+
364+ @ SuppressWarnings ("unchecked" )
365+ public Closeable restoreThreadLocalValue (WebSocketSession session ) {
366+ if (this .accessor != null ) {
367+ Map <String , Object > valuesMap = (Map <String , Object >) session .getAttributes ().get (SAVED_CONTEXT_KEY );
368+ // Uncomment when Boot is updated to use HandshakeInterceptor
369+ // Assert.state(valuesMap != null, "No context");
370+ if (valuesMap != null ) {
371+ this .accessor .restoreValues (valuesMap );
372+ return () -> this .accessor .resetValues (valuesMap );
373+ }
374+ }
375+ return () -> {};
376+ }
377+
378+ }
379+
380+
299381 private static class GraphQlStatus {
300382
301383 private static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus (4400 , "Invalid message" );
0 commit comments