Skip to content

Commit a8ac735

Browse files
committed
mcp: batch server change notifications
Batch notifications for changes in feature sets (adding or removing tools, prompts, resources, or resource templates). Delay sending the notifications until a certain number have occurred, or until some time has passed. This PR handles server-side change notifications only. We could do the same for client-side notifications (e.g. roots changed), but at the moment we don't need to. BUG: notifications sent after a connection closes result in an error message printed out. I'm not sure how to avoid that. It would be racy to check for a closed connection just before sending, though it would reduce the number of messages considerably. Fixes #649.
1 parent 21fb03d commit a8ac735

File tree

2 files changed

+62
-20
lines changed

2 files changed

+62
-20
lines changed

mcp/mcp_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ func TestEndToEnd(t *testing.T) {
6161
var ct, st Transport = NewInMemoryTransports()
6262

6363
// Channels to check if notification callbacks happened.
64+
// These test asynchronous sending of notifications after a small delay (see
65+
// Server.sendNotification).
6466
notificationChans := map[string]chan int{}
6567
for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe", "elicitation_complete"} {
6668
notificationChans[name] = make(chan int, 1)
@@ -1695,14 +1697,15 @@ func TestSynchronousNotifications(t *testing.T) {
16951697
},
16961698
}
16971699
server := NewServer(testImpl, serverOpts)
1698-
cs, ss, cleanup := basicClientServerConnection(t, client, server, func(s *Server) {
1700+
addTool := func(s *Server) {
16991701
AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) {
17001702
if !rootsChanged.Load() {
17011703
return nil, nil, fmt.Errorf("didn't get root change notification")
17021704
}
17031705
return new(CallToolResult), nil, nil
17041706
})
1705-
})
1707+
}
1708+
cs, ss, cleanup := basicClientServerConnection(t, client, server, addTool)
17061709
defer cleanup()
17071710

17081711
t.Run("from client", func(t *testing.T) {
@@ -1717,7 +1720,11 @@ func TestSynchronousNotifications(t *testing.T) {
17171720
})
17181721

17191722
t.Run("from server", func(t *testing.T) {
1720-
server.RemoveTools("tool")
1723+
// Because server change notifications are batched, we must generate a lot of them.
1724+
for range maxPendingNotifications/2 + 1 {
1725+
server.RemoveTools("tool")
1726+
addTool(server)
1727+
}
17211728
if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil {
17221729
t.Errorf("CreateMessage failed: %v", err)
17231730
}

mcp/server.go

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ type Server struct {
5151
sendingMethodHandler_ MethodHandler
5252
receivingMethodHandler_ MethodHandler
5353
resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool
54+
pendingNotifications map[string]int // notification name -> count of unsent changes
5455
}
5556

5657
// ServerOptions is used to configure behavior of the server.
@@ -149,6 +150,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server {
149150
sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession],
150151
receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession],
151152
resourceSubscriptions: make(map[string]map[*ServerSession]bool),
153+
pendingNotifications: make(map[string]int),
152154
}
153155
}
154156

@@ -158,15 +160,13 @@ func (s *Server) AddPrompt(p *Prompt, h PromptHandler) {
158160
// (It's possible an item was replaced with an identical one, but not worth checking.)
159161
s.changeAndNotify(
160162
notificationPromptListChanged,
161-
&PromptListChangedParams{},
162163
func() bool { s.prompts.add(&serverPrompt{p, h}); return true })
163164
}
164165

165166
// RemovePrompts removes the prompts with the given names.
166167
// It is not an error to remove a nonexistent prompt.
167168
func (s *Server) RemovePrompts(names ...string) {
168-
s.changeAndNotify(notificationPromptListChanged, &PromptListChangedParams{},
169-
func() bool { return s.prompts.remove(names...) })
169+
s.changeAndNotify(notificationPromptListChanged, func() bool { return s.prompts.remove(names...) })
170170
}
171171

172172
// AddTool adds a [Tool] to the server, or replaces one with the same name.
@@ -235,8 +235,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
235235
// (It's possible a tool was replaced with an identical one, but not worth checking.)
236236
// TODO: Batch these changes by size and time? The typescript SDK doesn't.
237237
// TODO: Surface notify error here? best not, in case we need to batch.
238-
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
239-
func() bool { s.tools.add(st); return true })
238+
s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true })
240239
}
241240

242241
func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) {
@@ -419,14 +418,13 @@ func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
419418
// RemoveTools removes the tools with the given names.
420419
// It is not an error to remove a nonexistent tool.
421420
func (s *Server) RemoveTools(names ...string) {
422-
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
423-
func() bool { return s.tools.remove(names...) })
421+
s.changeAndNotify(notificationToolListChanged, func() bool { return s.tools.remove(names...) })
424422
}
425423

426424
// AddResource adds a [Resource] to the server, or replaces one with the same URI.
427425
// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme).
428426
func (s *Server) AddResource(r *Resource, h ResourceHandler) {
429-
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
427+
s.changeAndNotify(notificationResourceListChanged,
430428
func() bool {
431429
if _, err := url.Parse(r.URI); err != nil {
432430
panic(err) // url.Parse includes the URI in the error
@@ -439,14 +437,13 @@ func (s *Server) AddResource(r *Resource, h ResourceHandler) {
439437
// RemoveResources removes the resources with the given URIs.
440438
// It is not an error to remove a nonexistent resource.
441439
func (s *Server) RemoveResources(uris ...string) {
442-
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
443-
func() bool { return s.resources.remove(uris...) })
440+
s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resources.remove(uris...) })
444441
}
445442

446443
// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI.
447444
// AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme).
448445
func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) {
449-
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
446+
s.changeAndNotify(notificationResourceListChanged,
450447
func() bool {
451448
// Validate the URI template syntax
452449
_, err := uritemplate.New(t.URITemplate)
@@ -461,8 +458,7 @@ func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) {
461458
// RemoveResourceTemplates removes the resource templates with the given URI templates.
462459
// It is not an error to remove a nonexistent resource.
463460
func (s *Server) RemoveResourceTemplates(uriTemplates ...string) {
464-
s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{},
465-
func() bool { return s.resourceTemplates.remove(uriTemplates...) })
461+
s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resourceTemplates.remove(uriTemplates...) })
466462
}
467463

468464
func (s *Server) capabilities() *ServerCapabilities {
@@ -497,18 +493,57 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR
497493
return s.opts.CompletionHandler(ctx, req)
498494
}
499495

496+
// Map from notification name to its corresponding params. The params have no fields,
497+
// so a single struct can be reused.
498+
var changeNotificationParams = map[string]Params{
499+
notificationToolListChanged: &ToolListChangedParams{},
500+
notificationPromptListChanged: &PromptListChangedParams{},
501+
notificationResourceListChanged: &ResourceListChangedParams{},
502+
}
503+
504+
// The maximum number of change notifications of a particular type (e.g. tools-changed)
505+
// that can be pending.
506+
const maxPendingNotifications = 10
507+
508+
// How long to wait before sending a change notification.
509+
var notificationDelay = 50 * time.Millisecond
510+
500511
// changeAndNotify is called when a feature is added or removed.
501512
// It calls change, which should do the work and report whether a change actually occurred.
502513
// If there was a change, it notifies a snapshot of the sessions.
503-
func (s *Server) changeAndNotify(notification string, params Params, change func() bool) {
514+
func (s *Server) changeAndNotify(notification string, change func() bool) {
504515
var sessions []*ServerSession
505-
// Lock for the change, but not for the notification.
516+
send := false
506517
s.mu.Lock()
507518
if change() {
508-
sessions = slices.Clone(s.sessions)
519+
pending := s.pendingNotifications[notification]
520+
if pending >= maxPendingNotifications {
521+
send = true
522+
pending = 0
523+
// Make a local copy of the session list so we can use it without holding the lock.
524+
sessions = slices.Clone(s.sessions)
525+
} else {
526+
pending++
527+
if pending == 1 {
528+
time.AfterFunc(notificationDelay, func() { s.sendNotification(notification) })
529+
}
530+
}
531+
s.pendingNotifications[notification] = pending
532+
}
533+
s.mu.Unlock() // Don't hold lock during notifications.
534+
if send {
535+
notifySessions(sessions, notification, changeNotificationParams[notification])
509536
}
537+
}
538+
539+
// sendNotification is called asynchronously to ensure that notifications are sent
540+
// soon after they occur.
541+
func (s *Server) sendNotification(n string) {
542+
s.mu.Lock()
543+
sessions := slices.Clone(s.sessions)
544+
s.pendingNotifications[n] = 0
510545
s.mu.Unlock()
511-
notifySessions(sessions, notification, params)
546+
notifySessions(sessions, n, changeNotificationParams[n])
512547
}
513548

514549
// Sessions returns an iterator that yields the current set of server sessions.

0 commit comments

Comments
 (0)