diff --git a/src/server/context.ts b/src/server/context.ts new file mode 100644 index 000000000..91b075615 --- /dev/null +++ b/src/server/context.ts @@ -0,0 +1,339 @@ +import { + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + ElicitResultSchema, + JSONRPCRequest, + LoggingMessageNotification, + Notification, + Request, + RequestId, + RequestInfo, + RequestMeta, + Result, + ServerNotification, + ServerRequest +} from '../types.js'; +import { RequestHandlerExtra, RequestOptions, RequestTaskStore } from '../shared/protocol.js'; +import { Server } from './index.js'; +import { AuthInfo } from './auth/types.js'; +import { AnySchema, SchemaOutput } from './zod-compat.js'; + +/** + * Interface for sending logging messages to the client via {@link LoggingMessageNotification}. + */ +export interface LoggingMessageNotificationSenderInterface { + /** + * Sends a logging message to the client. + */ + log(params: LoggingMessageNotification['params'], sessionId?: string): Promise; + /** + * Sends a debug log message to the client. + */ + debug(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends an info log message to the client. + */ + info(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends a warning log message to the client. + */ + warning(message: string, extraLogData?: Record, sessionId?: string): Promise; + /** + * Sends an error log message to the client. + */ + error(message: string, extraLogData?: Record, sessionId?: string): Promise; +} + +export class ServerLogger implements LoggingMessageNotificationSenderInterface { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + constructor(private readonly server: Server) {} + + /** + * Sends a logging message. + */ + public async log(params: LoggingMessageNotification['params'], sessionId?: string) { + await this.server.sendLoggingMessage(params, sessionId); + } + + /** + * Sends a debug log message. + */ + public async debug(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'debug', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an info log message. + */ + public async info(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'info', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends a warning log message. + */ + public async warning(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'warning', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } + + /** + * Sends an error log message. + */ + public async error(message: string, extraLogData?: Record, sessionId?: string) { + await this.log( + { + level: 'error', + data: { + ...extraLogData, + message + }, + logger: 'server' + }, + sessionId + ); + } +} + +export interface ContextInterface + extends RequestHandlerExtra { + elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise; + requestSampling: (params: CreateMessageRequest['params'], options?: RequestOptions) => Promise; + loggingNotification: LoggingMessageNotificationSenderInterface; +} +/** + * A context object that is passed to request handlers. + * + * Implements the RequestHandlerExtra interface for backwards compatibility. + */ +export class Context + implements ContextInterface +{ + private readonly server: Server; + + /** + * The request context. + * A type-safe context that is passed to request handlers. + */ + private readonly requestCtx: RequestHandlerExtra; + + /** + * The MCP context - Contains information about the current MCP request and session. + */ + public readonly mcpContext: { + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + */ + requestId: RequestId; + /** + * The method of the request. + */ + method: string; + /** + * The metadata of the request. + */ + _meta?: RequestMeta; + /** + * The session ID of the request. + */ + sessionId?: string; + }; + + public readonly task: + | { + id: string | undefined; + store: RequestTaskStore | undefined; + requestedTtl: number | null | undefined; + } + | undefined; + + public readonly stream: + | { + /** + * Closes the SSE stream for this request, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior during long-running operations. + */ + closeSSEStream: (() => void) | undefined; + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using StreamableHTTPServerTransport with eventStore configured. + * Use this to implement polling behavior for server-initiated notifications. + */ + closeStandaloneSSEStream: (() => void) | undefined; + } + | undefined; + + public readonly loggingNotification: LoggingMessageNotificationSenderInterface; + + constructor(args: { + server: Server; + request: JSONRPCRequest; + requestCtx: RequestHandlerExtra; + }) { + this.server = args.server; + this.requestCtx = args.requestCtx; + this.mcpContext = { + requestId: args.requestCtx.requestId, + method: args.request.method, + _meta: args.requestCtx._meta, + sessionId: args.requestCtx.sessionId + }; + + this.task = { + id: args.requestCtx.taskId, + store: args.requestCtx.taskStore, + requestedTtl: args.requestCtx.taskRequestedTtl + }; + + this.loggingNotification = new ServerLogger(args.server); + + this.stream = { + closeSSEStream: args.requestCtx.closeSSEStream, + closeStandaloneSSEStream: args.requestCtx.closeStandaloneSSEStream + }; + } + + /** + * The JSON-RPC ID of the request being handled. + * This can be useful for tracking or logging purposes. + * + * @deprecated Use {@link mcpContext.requestId} instead. + */ + public get requestId(): RequestId { + return this.requestCtx.requestId; + } + + public get signal(): AbortSignal { + return this.requestCtx.signal; + } + + public get authInfo(): AuthInfo | undefined { + return this.requestCtx.authInfo; + } + + public get requestInfo(): RequestInfo | undefined { + return this.requestCtx.requestInfo; + } + + /** + * @deprecated Use {@link mcpContext._meta} instead. + */ + public get _meta(): RequestMeta | undefined { + return this.requestCtx._meta; + } + + /** + * @deprecated Use {@link mcpContext.sessionId} instead. + */ + public get sessionId(): string | undefined { + return this.mcpContext.sessionId; + } + + /** + * @deprecated Use {@link task.id} instead. + */ + public get taskId(): string | undefined { + return this.requestCtx.taskId; + } + + /** + * @deprecated Use {@link task.store} instead. + */ + public get taskStore(): RequestTaskStore | undefined { + return this.requestCtx.taskStore; + } + + /** + * @deprecated Use {@link task.requestedTtl} instead. + */ + public get taskRequestedTtl(): number | undefined { + return this.requestCtx.taskRequestedTtl ?? undefined; + } + + /** + * @deprecated Use {@link stream.closeSSEStream} instead. + */ + public get closeSSEStream(): (() => void) | undefined { + return this.requestCtx.closeSSEStream; + } + + /** + * @deprecated Use {@link stream.closeStandaloneSSEStream} instead. + */ + public get closeStandaloneSSEStream(): (() => void) | undefined { + return this.requestCtx.closeStandaloneSSEStream; + } + + /** + * Sends a notification that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + public sendNotification = (notification: NotificationT | ServerNotification): Promise => { + return this.requestCtx.sendNotification(notification); + }; + + /** + * Sends a request that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + public sendRequest = ( + request: RequestT | ServerRequest, + resultSchema: U, + options?: RequestOptions + ): Promise> => { + return this.requestCtx.sendRequest(request, resultSchema, { ...options, relatedRequestId: this.requestId }); + }; + + /** + * Sends a request to sample an LLM via the client. + */ + public requestSampling(params: CreateMessageRequest['params'], options?: RequestOptions) { + return this.server.createMessage(params, options); + } + + /** + * Sends an elicitation request to the client. + */ + public async elicitInput(params: ElicitRequest['params'], options?: RequestOptions): Promise { + const request: ElicitRequest = { + method: 'elicitation/create', + params + }; + return await this.server.request(request, ElicitResultSchema, { ...options, relatedRequestId: this.requestId }); + } +} diff --git a/src/server/index.ts b/src/server/index.ts index aa1a62d00..fa0d038f1 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -40,7 +40,10 @@ import { type ToolUseContent, CallToolRequestSchema, CallToolResultSchema, - CreateTaskResultSchema + CreateTaskResultSchema, + JSONRPCRequest, + TaskCreationParams, + MessageExtraInfo } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js'; @@ -56,6 +59,9 @@ import { import { RequestHandlerExtra } from '../shared/protocol.js'; import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; import { assertToolsCallTaskCapability, assertClientRequestTaskCapability } from '../experimental/tasks/helpers.js'; +import { Context } from './context.js'; +import { TaskStore } from '../experimental/index.js'; +import { Transport } from '../shared/transport.js'; export type ServerOptions = ProtocolOptions & { /** @@ -219,9 +225,31 @@ export class Server< requestSchema: T, handler: ( request: SchemaOutput, - extra: RequestHandlerExtra + extra: Context ) => ServerResult | ResultT | Promise ): void { + // Wrap the handler to ensure the extra is a Context and return a decorated handler that can be passed to the base implementation + + // Factory function to create a handler decorator that ensures the extra is a Context and returns a decorated handler that can be passed to the base implementation + const handlerDecoratorFactory = ( + innerHandler: ( + request: SchemaOutput, + extra: Context + ) => ServerResult | ResultT | Promise + ) => { + const decoratedHandler = ( + request: SchemaOutput, + extra: RequestHandlerExtra + ) => { + if (!this.isContextExtra(extra)) { + throw new Error('Internal error: Expected Context for request handler extra'); + } + return innerHandler(request, extra); + }; + + return decoratedHandler; + }; + const shape = getObjectShape(requestSchema); const methodSchema = shape?.method; if (!methodSchema) { @@ -259,7 +287,7 @@ export class Server< const { params } = validatedRequest.data; - const result = await Promise.resolve(handler(request, extra)); + const result = await Promise.resolve(handlerDecoratorFactory(handler)(request, extra)); // When task creation is requested, validate and return CreateTaskResult if (params.task) { @@ -286,11 +314,18 @@ export class Server< }; // Install the wrapped handler - return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + return super.setRequestHandler(requestSchema, handlerDecoratorFactory(wrappedHandler)); } // Other handlers use default behavior - return super.setRequestHandler(requestSchema, handler); + return super.setRequestHandler(requestSchema, handlerDecoratorFactory(handler)); + } + + // Runtime type guard: ensure extra is our Context + private isContextExtra( + extra: RequestHandlerExtra + ): extra is Context { + return extra instanceof Context; } protected assertCapabilityForMethod(method: RequestT['method']): void { @@ -468,6 +503,25 @@ export class Server< return this._capabilities; } + protected createRequestExtra(args: { + request: JSONRPCRequest; + taskStore: TaskStore | undefined; + relatedTaskId: string | undefined; + taskCreationParams: TaskCreationParams | undefined; + abortController: AbortController; + capturedTransport: Transport | undefined; + extra?: MessageExtraInfo; + }): RequestHandlerExtra { + const base = super.createRequestExtra(args) as RequestHandlerExtra; + + // Expose a Context instance to handlers, which implements RequestHandlerExtra + return new Context({ + server: this, + request: args.request, + requestCtx: base + }); + } + async ping() { return this.request({ method: 'ping' }, EmptyResultSchema); } diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 7e61b4364..6d4e1a5ef 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -63,6 +63,7 @@ import { validateAndWarnToolName } from '../shared/toolNameValidation.js'; import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcp-server.js'; import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; import { ZodOptional } from 'zod'; +import { ContextInterface } from './context.js'; /** * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. @@ -324,7 +325,7 @@ export class McpServer { private async executeToolHandler( tool: RegisteredTool, args: unknown, - extra: RequestHandlerExtra + extra: ContextInterface ): Promise { const handler = tool.handler as AnyToolHandler; const isTaskHandler = 'createTask' in handler; @@ -1135,7 +1136,7 @@ export class McpServer { /** * Registers a prompt with a config object and callback. */ - registerPrompt( + registerPrompt( name: string, config: { title?: string; @@ -1270,7 +1271,7 @@ export class ResourceTemplate { export type BaseToolCallback< SendResultT extends Result, - Extra extends RequestHandlerExtra, + Extra extends ContextInterface, Args extends undefined | ZodRawShapeCompat | AnySchema > = Args extends ZodRawShapeCompat ? (args: ShapeOutput, extra: Extra) => SendResultT | Promise @@ -1290,7 +1291,7 @@ export type BaseToolCallback< */ export type ToolCallback = BaseToolCallback< CallToolResult, - RequestHandlerExtra, + ContextInterface, Args >; @@ -1409,7 +1410,7 @@ export type ResourceMetadata = Omit; * Callback to list all resources matching a given template. */ export type ListResourcesCallback = ( - extra: RequestHandlerExtra + extra: ContextInterface ) => ListResourcesResult | Promise; /** @@ -1417,7 +1418,7 @@ export type ListResourcesCallback = ( */ export type ReadResourceCallback = ( uri: URL, - extra: RequestHandlerExtra + extra: ContextInterface ) => ReadResourceResult | Promise; export type RegisteredResource = { @@ -1445,7 +1446,7 @@ export type RegisteredResource = { export type ReadResourceTemplateCallback = ( uri: URL, variables: Variables, - extra: RequestHandlerExtra + extra: ContextInterface ) => ReadResourceResult | Promise; export type RegisteredResourceTemplate = { @@ -1470,8 +1471,8 @@ export type RegisteredResourceTemplate = { type PromptArgsRawShape = ZodRawShapeCompat; export type PromptCallback = Args extends PromptArgsRawShape - ? (args: ShapeOutput, extra: RequestHandlerExtra) => GetPromptResult | Promise - : (extra: RequestHandlerExtra) => GetPromptResult | Promise; + ? (args: ShapeOutput, extra: ContextInterface) => GetPromptResult | Promise + : (extra: ContextInterface) => GetPromptResult | Promise; export type RegisteredPrompt = { title?: string; diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index e195478f2..87d70b10d 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -231,6 +231,8 @@ export interface RequestTaskStore { /** * Extra data given to request handlers. + * + * @deprecated Use {@link ContextInterface} from {@link Context} instead. Future major versions will remove this type. */ export type RequestHandlerExtra = { /** @@ -709,43 +711,15 @@ export abstract class Protocol = { - signal: abortController.signal, - sessionId: capturedTransport?.sessionId, - _meta: request.params?._meta, - sendNotification: async notification => { - // Include related-task metadata if this request is part of a task - const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; - if (relatedTaskId) { - notificationOptions.relatedTask = { taskId: relatedTaskId }; - } - await this.notification(notification, notificationOptions); - }, - sendRequest: async (r, resultSchema, options?) => { - // Include related-task metadata if this request is part of a task - const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } - - // Set task status to input_required when sending a request within a task context - // Use the taskId from options (explicit) or fall back to relatedTaskId (inherited) - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } - - return await this.request(r, resultSchema, requestOptions); - }, - authInfo: extra?.authInfo, - requestId: request.id, - requestInfo: extra?.requestInfo, - taskId: relatedTaskId, - taskStore: taskStore, - taskRequestedTtl: taskCreationParams?.ttl, - closeSSEStream: extra?.closeSSEStream, - closeStandaloneSSEStream: extra?.closeStandaloneSSEStream - }; + const fullExtra: RequestHandlerExtra = this.createRequestExtra({ + request, + taskStore, + relatedTaskId, + taskCreationParams, + abortController, + capturedTransport, + extra + }); // Starting with Promise.resolve() puts any synchronous errors into the monad as well. Promise.resolve() @@ -823,6 +797,60 @@ export abstract class Protocol { + const { request, taskStore, relatedTaskId, taskCreationParams, abortController, capturedTransport, extra } = args; + + return { + signal: abortController.signal, + sessionId: capturedTransport?.sessionId, + _meta: request.params?._meta, + sendNotification: async notification => { + // Include related-task metadata if this request is part of a task + const notificationOptions: NotificationOptions = { relatedRequestId: request.id }; + if (relatedTaskId) { + notificationOptions.relatedTask = { taskId: relatedTaskId }; + } + await this.notification(notification, notificationOptions); + }, + sendRequest: async (r, resultSchema, options?) => { + // Include related-task metadata if this request is part of a task + const requestOptions: RequestOptions = { ...options, relatedRequestId: request.id }; + if (relatedTaskId && !requestOptions.relatedTask) { + requestOptions.relatedTask = { taskId: relatedTaskId }; + } + + // Set task status to input_required when sending a request within a task context + // Use the taskId from options (explicit) or fall back to relatedTaskId (inherited) + const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; + if (effectiveTaskId && taskStore) { + await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); + } + + return await this.request(r, resultSchema, requestOptions); + }, + authInfo: extra?.authInfo, + requestId: request.id, + requestInfo: extra?.requestInfo, + taskId: relatedTaskId, + taskStore: taskStore, + taskRequestedTtl: taskCreationParams?.ttl, + closeSSEStream: extra?.closeSSEStream, + closeStandaloneSSEStream: extra?.closeStandaloneSSEStream + } as RequestHandlerExtra; + } + private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); diff --git a/test/server/context.test.ts b/test/server/context.test.ts new file mode 100644 index 000000000..48f547601 --- /dev/null +++ b/test/server/context.test.ts @@ -0,0 +1,272 @@ +import { z } from 'zod/v4'; +import { Client } from '../../src/client/index.js'; +import { McpServer, ResourceTemplate } from '../../src/server/mcp.js'; +import { Context } from '../../src/server/context.js'; +import { + CallToolResultSchema, + GetPromptResultSchema, + ListResourcesResultSchema, + LoggingMessageNotificationSchema, + ReadResourceResultSchema, + ServerNotification, + ServerRequest +} from '../../src/types.js'; +import { InMemoryTransport } from '../../src/inMemory.js'; +import { RequestHandlerExtra } from '../../src/shared/protocol.js'; + +describe('Context', () => { + /*** + * Test: `extra` provided to callbacks is Context (parameterized) + */ + type Seen = { isContext: boolean; hasRequestId: boolean }; + const contextCases: Array<[string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise]> = + [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + (_args: { name: string }, extra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + mcpServer.registerResource('ctx-resource', 'test://res/1', { title: 'ctx-resource' }, async (_uri, extra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + }); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + const template = new ResourceTemplate('test://items/{id}', { + list: async extra => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _extra) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + mcpServer.registerPrompt('ctx-prompt', {}, async extra => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { messages: [] }; + }); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass Context as extra to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + + const logLevelsThroughContext = ['debug', 'info', 'warning', 'error'] as const; + + //it.each for each log level, test that logging message is sent to client + it.each(logLevelsThroughContext)('should send logging message to client for %s level from Context', async level => { + const mcpServer = new McpServer( + { name: 'ctx-test', version: '1.0' }, + { + capabilities: { + logging: {} + } + } + ); + const client = new Client( + { name: 'ctx-client', version: '1.0' }, + { + capabilities: {} + } + ); + + let seen = 0; + + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + seen++; + expect(notification.params.level).toBe(level); + expect(notification.params.data).toBe('Test message'); + expect(notification.params.test).toBe('test'); + expect(notification.params.sessionId).toBe('sample-session-id'); + return; + }); + + mcpServer.registerTool('ctx-log-test', { inputSchema: z.object({ name: z.string() }) }, async (_args: { name: string }, extra) => { + await extra.loggingNotification[level]('Test message', { test: 'test' }, 'sample-session-id'); + await extra.loggingNotification.log( + { + level, + data: 'Test message', + logger: 'test-logger-namespace' + }, + 'sample-session-id' + ); + return { content: [{ type: 'text', text: 'ok' }] }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { name: 'ctx-log-test', arguments: { name: 'ctx-log-test-name' } } + }, + CallToolResultSchema + ); + + // two messages should have been sent - one from the .log method and one from the .debug/info/warning/error method + expect(seen).toBe(2); + + expect(result.content).toHaveLength(1); + expect(result.content[0]).toMatchObject({ + type: 'text', + text: 'ok' + }); + }); + describe('Legacy RequestHandlerExtra API', () => { + const contextCases: Array< + [string, (mcpServer: McpServer, seen: Seen) => void | Promise, (client: Client) => Promise] + > = [ + [ + 'tool', + (mcpServer, seen) => { + mcpServer.registerTool( + 'ctx-tool', + { + inputSchema: z.object({ name: z.string() }) + }, + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + (_args: { name: string }, extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { content: [{ type: 'text', text: 'ok' }] }; + } + ); + }, + client => + client.request( + { + method: 'tools/call', + params: { + name: 'ctx-tool', + arguments: { + name: 'ctx-tool-name' + } + } + }, + CallToolResultSchema + ) + ], + [ + 'resource', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + mcpServer.registerResource( + 'ctx-resource', + 'test://res/1', + { title: 'ctx-resource' }, + async (_uri, extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { contents: [{ uri: 'test://res/1', mimeType: 'text/plain', text: 'hello' }] }; + } + ); + }, + client => client.request({ method: 'resources/read', params: { uri: 'test://res/1' } }, ReadResourceResultSchema) + ], + [ + 'resource template list', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + const template = new ResourceTemplate('test://items/{id}', { + list: async (extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { resources: [] }; + } + }); + mcpServer.registerResource('ctx-template', template, { title: 'ctx-template' }, async (_uri, _vars, _extra) => ({ + contents: [] + })); + }, + client => client.request({ method: 'resources/list', params: {} }, ListResourcesResultSchema) + ], + [ + 'prompt', + (mcpServer, seen) => { + // The test is to ensure that the extra is compatible with the RequestHandlerExtra type + mcpServer.registerPrompt('ctx-prompt', {}, async (extra: RequestHandlerExtra) => { + seen.isContext = extra instanceof Context; + seen.hasRequestId = !!extra.requestId; + return { messages: [] }; + }); + }, + client => client.request({ method: 'prompts/get', params: { name: 'ctx-prompt', arguments: {} } }, GetPromptResultSchema) + ] + ]; + + test.each(contextCases)('should pass Context as extra to %s callbacks', async (_kind, register, trigger) => { + const mcpServer = new McpServer({ name: 'ctx-test', version: '1.0' }); + const client = new Client({ name: 'ctx-client', version: '1.0' }); + + const seen: Seen = { isContext: false, hasRequestId: false }; + + await register(mcpServer, seen); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await trigger(client); + + expect(seen.isContext).toBe(true); + expect(seen.hasRequestId).toBe(true); + }); + }); +}); diff --git a/test/server/mcp.test.ts b/test/server/mcp.test.ts index f6c2124e1..4be2d24f4 100644 --- a/test/server/mcp.test.ts +++ b/test/server/mcp.test.ts @@ -17,12 +17,15 @@ import { ReadResourceResultSchema, type TextContent, UrlElicitationRequiredError, - ErrorCode + ErrorCode, + ServerRequest, + ServerNotification } from '../../src/types.js'; import { completable } from '../../src/server/completable.js'; import { McpServer, ResourceTemplate } from '../../src/server/mcp.js'; import { InMemoryTaskStore } from '../../src/experimental/tasks/stores/in-memory.js'; import { zodTestMatrix, type ZodMatrixEntry } from '../../src/__fixtures__/zodTestMatrix.js'; +import { Context, ContextInterface } from '../../src/server/context.js'; function createLatch() { let latch = false; @@ -243,7 +246,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { sendNotification: () => { throw new Error('Not implemented'); } - }); + } as unknown as ContextInterface); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); }); @@ -4387,17 +4390,20 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { }) } }, - async ({ department, name }) => ({ - messages: [ - { - role: 'assistant', - content: { - type: 'text', - text: `Hello ${name}, welcome to the ${department} team!` + async ({ department, name }, extra: ContextInterface) => { + expect(extra).toBeInstanceOf(Context); + return { + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}, welcome to the ${department} team!` + } } - } - ] - }) + ] + }; + } ); const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); diff --git a/test/server/streamableHttp.test.ts b/test/server/streamableHttp.test.ts index 9fc2d3017..0161d82fb 100644 --- a/test/server/streamableHttp.test.ts +++ b/test/server/streamableHttp.test.ts @@ -2285,8 +2285,8 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => { // Verify we received the notification that was sent while disconnected expect(allText).toContain('Missed while disconnected'); - }); - }, 10000); + }, 10000); + }); // Test onsessionclosed callback describe('StreamableHTTPServerTransport onsessionclosed callback', () => {