diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index c2c949e8..7cdea439 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -61,6 +61,8 @@ func TestEndToEnd(t *testing.T) { var ct, st Transport = NewInMemoryTransports() // Channels to check if notification callbacks happened. + // These test asynchronous sending of notifications after a small delay (see + // Server.sendNotification). notificationChans := map[string]chan int{} for _, name := range []string{"initialized", "roots", "tools", "prompts", "resources", "progress_server", "progress_client", "resource_updated", "subscribe", "unsubscribe", "elicitation_complete"} { notificationChans[name] = make(chan int, 1) @@ -1671,14 +1673,15 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { } func TestSynchronousNotifications(t *testing.T) { - var toolsChanged atomic.Bool + var toolsChanged atomic.Int32 clientOpts := &ClientOptions{ ToolListChangedHandler: func(ctx context.Context, req *ToolListChangedRequest) { - toolsChanged.Store(true) + toolsChanged.Add(1) }, CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { - if !toolsChanged.Load() { - return nil, fmt.Errorf("didn't get a tools changed notification") + // See the comment after "from server" below. + if n := toolsChanged.Load(); n != 1 { + return nil, fmt.Errorf("got %d tools-changed notification, wanted 1", n) } // TODO(rfindley): investigate the error returned from this test if // CreateMessageResult is new(CreateMessageResult): it's a mysterious @@ -1695,14 +1698,15 @@ func TestSynchronousNotifications(t *testing.T) { }, } server := NewServer(testImpl, serverOpts) - cs, ss, cleanup := basicClientServerConnection(t, client, server, func(s *Server) { + addTool := func(s *Server) { AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { if !rootsChanged.Load() { return nil, nil, fmt.Errorf("didn't get root change notification") } return new(CallToolResult), nil, nil }) - }) + } + cs, ss, cleanup := basicClientServerConnection(t, client, server, addTool) defer cleanup() t.Run("from client", func(t *testing.T) { @@ -1717,7 +1721,13 @@ func TestSynchronousNotifications(t *testing.T) { }) t.Run("from server", func(t *testing.T) { - server.RemoveTools("tool") + // Despite all this tool-changed activity, we expect only one notification. + for range 10 { + server.RemoveTools("tool") + addTool(server) + } + + time.Sleep(notificationDelay * 2) // Wait for delayed notification. if _, err := ss.CreateMessage(context.Background(), new(CreateMessageParams)); err != nil { t.Errorf("CreateMessage failed: %v", err) } diff --git a/mcp/server.go b/mcp/server.go index 254c2d5e..a5c27587 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -51,6 +51,7 @@ type Server struct { sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool + pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } // ServerOptions is used to configure behavior of the server. @@ -149,6 +150,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), + pendingNotifications: make(map[string]*time.Timer), } } @@ -158,15 +160,13 @@ func (s *Server) AddPrompt(p *Prompt, h PromptHandler) { // (It's possible an item was replaced with an identical one, but not worth checking.) s.changeAndNotify( notificationPromptListChanged, - &PromptListChangedParams{}, func() bool { s.prompts.add(&serverPrompt{p, h}); return true }) } // RemovePrompts removes the prompts with the given names. // It is not an error to remove a nonexistent prompt. func (s *Server) RemovePrompts(names ...string) { - s.changeAndNotify(notificationPromptListChanged, &PromptListChangedParams{}, - func() bool { return s.prompts.remove(names...) }) + s.changeAndNotify(notificationPromptListChanged, func() bool { return s.prompts.remove(names...) }) } // AddTool adds a [Tool] to the server, or replaces one with the same name. @@ -235,8 +235,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { // (It's possible a tool was replaced with an identical one, but not worth checking.) // TODO: Batch these changes by size and time? The typescript SDK doesn't. // TODO: Surface notify error here? best not, in case we need to batch. - s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, - func() bool { s.tools.add(st); return true }) + s.changeAndNotify(notificationToolListChanged, func() bool { s.tools.add(st); return true }) } func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { @@ -419,14 +418,13 @@ func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { // RemoveTools removes the tools with the given names. // It is not an error to remove a nonexistent tool. func (s *Server) RemoveTools(names ...string) { - s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, - func() bool { return s.tools.remove(names...) }) + s.changeAndNotify(notificationToolListChanged, func() bool { return s.tools.remove(names...) }) } // AddResource adds a [Resource] to the server, or replaces one with the same URI. // AddResource panics if the resource URI is invalid or not absolute (has an empty scheme). func (s *Server) AddResource(r *Resource, h ResourceHandler) { - s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, + s.changeAndNotify(notificationResourceListChanged, func() bool { if _, err := url.Parse(r.URI); err != nil { panic(err) // url.Parse includes the URI in the error @@ -439,14 +437,13 @@ func (s *Server) AddResource(r *Resource, h ResourceHandler) { // RemoveResources removes the resources with the given URIs. // It is not an error to remove a nonexistent resource. func (s *Server) RemoveResources(uris ...string) { - s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, - func() bool { return s.resources.remove(uris...) }) + s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resources.remove(uris...) }) } // AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces one with the same URI. // AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme). func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { - s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, + s.changeAndNotify(notificationResourceListChanged, func() bool { // Validate the URI template syntax _, err := uritemplate.New(t.URITemplate) @@ -461,8 +458,7 @@ func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { // RemoveResourceTemplates removes the resource templates with the given URI templates. // It is not an error to remove a nonexistent resource. func (s *Server) RemoveResourceTemplates(uriTemplates ...string) { - s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, - func() bool { return s.resourceTemplates.remove(uriTemplates...) }) + s.changeAndNotify(notificationResourceListChanged, func() bool { return s.resourceTemplates.remove(uriTemplates...) }) } func (s *Server) capabilities() *ServerCapabilities { @@ -497,18 +493,43 @@ func (s *Server) complete(ctx context.Context, req *CompleteRequest) (*CompleteR return s.opts.CompletionHandler(ctx, req) } +// Map from notification name to its corresponding params. The params have no fields, +// so a single struct can be reused. +var changeNotificationParams = map[string]Params{ + notificationToolListChanged: &ToolListChangedParams{}, + notificationPromptListChanged: &PromptListChangedParams{}, + notificationResourceListChanged: &ResourceListChangedParams{}, +} + +// How long to wait before sending a change notification. +const notificationDelay = 10 * time.Millisecond + // changeAndNotify is called when a feature is added or removed. // It calls change, which should do the work and report whether a change actually occurred. -// If there was a change, it notifies a snapshot of the sessions. -func (s *Server) changeAndNotify(notification string, params Params, change func() bool) { - var sessions []*ServerSession - // Lock for the change, but not for the notification. +// If there was a change, it sets a timer to send a notification. +// This debounces change notifications: a single notification is sent after +// multiple changes occur in close proximity. +func (s *Server) changeAndNotify(notification string, change func() bool) { s.mu.Lock() + defer s.mu.Unlock() if change() { - sessions = slices.Clone(s.sessions) + // Stop the outstanding delayed call, if any. + if t := s.pendingNotifications[notification]; t != nil { + t.Stop() + } + // + s.pendingNotifications[notification] = time.AfterFunc(notificationDelay, func() { s.notifySessions(notification) }) } - s.mu.Unlock() - notifySessions(sessions, notification, params) +} + +// notifySessions sends the notification n to all existing sessions. +// It is called asynchronously by changeAndNotify. +func (s *Server) notifySessions(n string) { + s.mu.Lock() + sessions := slices.Clone(s.sessions) + s.pendingNotifications[n] = nil + s.mu.Unlock() // Don't hold the lock during notification: it causes deadlock. + notifySessions(sessions, n, changeNotificationParams[n]) } // Sessions returns an iterator that yields the current set of server sessions. @@ -1068,7 +1089,6 @@ func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*Eli resolved, err := schema.Resolve(nil) if err != nil { - fmt.Printf(" resolve err: %s", err) return nil, err } if err := resolved.Validate(res.Content); err != nil { diff --git a/mcp/shared.go b/mcp/shared.go index 3fac40b2..27fc9edf 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -349,6 +349,8 @@ func notifySessions[S Session, P Params](sessions []S, method string, params P) if sessions == nil { return } + // Notify with the background context, so the messages are sent on the + // standalone stream. // TODO: make this timeout configurable, or call handleNotify asynchronously. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel()