Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,8 @@ import {
ListToolsResultSchema,
type LoggingLevel,
McpError,
type Notification,
type ReadResourceRequest,
ReadResourceResultSchema,
type Request,
type Result,
type ServerCapabilities,
SUPPORTED_PROTOCOL_VERSIONS,
type SubscribeRequest,
Expand All @@ -48,7 +45,10 @@ import {
ResourceListChangedNotificationSchema,
ListChangedOptions,
ListChangedOptionsBaseSchema,
type ListChangedHandlers
type ListChangedHandlers,
type Request,
type Notification,
type Result
} from '../types.js';
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js';
Expand Down Expand Up @@ -368,14 +368,14 @@ export class Client<
}

const { params } = validatedRequest.data;
const mode = params.mode ?? 'form';
params.mode = params.mode ?? 'form';
const { supportsFormMode, supportsUrlMode } = getSupportedElicitationModes(this._capabilities.elicitation);

if (mode === 'form' && !supportsFormMode) {
if (params.mode === 'form' && !supportsFormMode) {
throw new McpError(ErrorCode.InvalidParams, 'Client does not support form-mode elicitation requests');
}

if (mode === 'url' && !supportsUrlMode) {
if (params.mode === 'url' && !supportsUrlMode) {
throw new McpError(ErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests');
}

Expand Down Expand Up @@ -404,9 +404,9 @@ export class Client<
}

const validatedResult = validationResult.data;
const requestedSchema = mode === 'form' ? (params.requestedSchema as JsonSchemaType) : undefined;
const requestedSchema = params.mode === 'form' ? (params.requestedSchema as JsonSchemaType) : undefined;

if (mode === 'form' && validatedResult.action === 'accept' && validatedResult.content && requestedSchema) {
if (params.mode === 'form' && validatedResult.action === 'accept' && validatedResult.content && requestedSchema) {
if (this._capabilities.elicitation?.form?.applyDefaults) {
try {
applyElicitationDefaults(requestedSchema, validatedResult.content);
Expand Down
4 changes: 2 additions & 2 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Transport, FetchLike, createFetchWithInit, normalizeHeaders } from '../shared/transport.js';
import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from '../types.js';
import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResultResponse, JSONRPCMessage, JSONRPCMessageSchema } from '../types.js';
import { auth, AuthResult, extractWWWAuthenticateParams, OAuthClientProvider, UnauthorizedError } from './auth.js';
import { EventSourceParserStream } from 'eventsource-parser/stream';

Expand Down Expand Up @@ -350,7 +350,7 @@ export class StreamableHTTPClientTransport implements Transport {
if (!event.event || event.event === 'message') {
try {
const message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
if (isJSONRPCResponse(message)) {
if (isJSONRPCResultResponse(message)) {
// Mark that we received a response - no need to reconnect for this request
receivedResponse = true;
if (replayMessageId !== undefined) {
Expand Down
5 changes: 3 additions & 2 deletions src/examples/server/simpleTaskInteractive.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ import {
ListToolsRequestSchema,
CallToolRequestSchema,
GetTaskRequestSchema,
GetTaskPayloadRequestSchema
GetTaskPayloadRequestSchema,
GetTaskPayloadResult
} from '../../types.js';
import { TaskMessageQueue, QueuedMessage, QueuedRequest, isTerminal, CreateTaskOptions } from '../../experimental/tasks/interfaces.js';
import { InMemoryTaskStore } from '../../experimental/tasks/stores/in-memory.js';
Expand Down Expand Up @@ -618,7 +619,7 @@ const createServer = (): Server => {
});

// Handle tasks/result
server.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra): Promise<Result> => {
server.setRequestHandler(GetTaskPayloadRequestSchema, async (request, extra): Promise<GetTaskPayloadResult> => {
const { taskId } = request.params;
console.log(`[Server] tasks/result called for task ${taskId}`);
return taskResultHandler.handle(taskId, server, extra.sessionId ?? '');
Expand Down
12 changes: 6 additions & 6 deletions src/experimental/tasks/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@

import {
Task,
Request,
RequestId,
Result,
JSONRPCRequest,
JSONRPCNotification,
JSONRPCResponse,
JSONRPCError,
JSONRPCResultResponse,
JSONRPCErrorResponse,
ServerRequest,
ServerNotification,
CallToolResult,
GetTaskResult,
ToolExecution
ToolExecution,
Request
} from '../../types.js';
import { CreateTaskResult } from './types.js';
import type { RequestHandlerExtra, RequestTaskStore } from '../../shared/protocol.js';
Expand Down Expand Up @@ -124,13 +124,13 @@ export interface QueuedNotification extends BaseQueuedMessage {
export interface QueuedResponse extends BaseQueuedMessage {
type: 'response';
/** The actual JSONRPC response */
message: JSONRPCResponse;
message: JSONRPCResultResponse;
}

export interface QueuedError extends BaseQueuedMessage {
type: 'error';
/** The actual JSONRPC error */
message: JSONRPCError;
message: JSONRPCErrorResponse;
}

/**
Expand Down
2 changes: 1 addition & 1 deletion src/experimental/tasks/stores/in-memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* @experimental
*/

import { Task, Request, RequestId, Result } from '../../../types.js';
import { Task, RequestId, Result, Request } from '../../../types.js';
import { TaskStore, isTerminal, TaskMessageQueue, QueuedMessage, CreateTaskOptions } from '../interfaces.js';
import { randomBytes } from 'node:crypto';

Expand Down
8 changes: 4 additions & 4 deletions src/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ import {
LoggingLevelSchema,
type LoggingMessageNotification,
McpError,
type Notification,
type Request,
type ResourceUpdatedNotification,
type Result,
type ServerCapabilities,
type ServerNotification,
type ServerRequest,
Expand All @@ -40,7 +37,10 @@ import {
type ToolUseContent,
CallToolRequestSchema,
CallToolResultSchema,
CreateTaskResultSchema
CreateTaskResultSchema,
type Request,
type Notification,
type Result
} from '../types.js';
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js';
Expand Down
12 changes: 6 additions & 6 deletions src/server/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import {
MessageExtraInfo,
RequestInfo,
isInitializeRequest,
isJSONRPCError,
isJSONRPCRequest,
isJSONRPCResponse,
isJSONRPCResultResponse,
JSONRPCMessage,
JSONRPCMessageSchema,
RequestId,
SUPPORTED_PROTOCOL_VERSIONS,
DEFAULT_NEGOTIATED_PROTOCOL_VERSION
DEFAULT_NEGOTIATED_PROTOCOL_VERSION,
isJSONRPCErrorResponse
} from '../types.js';
import getRawBody from 'raw-body';
import contentType from 'content-type';
Expand Down Expand Up @@ -871,7 +871,7 @@ export class StreamableHTTPServerTransport implements Transport {

async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise<void> {
let requestId = options?.relatedRequestId;
if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) {
// If the message is a response, use the request ID from the message
requestId = message.id;
}
Expand All @@ -881,7 +881,7 @@ export class StreamableHTTPServerTransport implements Transport {
// Those will be sent via dedicated response SSE streams
if (requestId === undefined) {
// For standalone SSE streams, we can only send requests and notifications
if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) {
throw new Error('Cannot send a response on a standalone SSE stream unless resuming a previous client request');
}

Expand Down Expand Up @@ -924,7 +924,7 @@ export class StreamableHTTPServerTransport implements Transport {
}
}

if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) {
this._requestResponseMap.set(requestId, message);
const relatedIds = Array.from(this._requestToStreamMapping.entries())
.filter(([_, streamId]) => this._streamMapping.get(streamId) === response)
Expand Down
47 changes: 26 additions & 21 deletions src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@ import {
ListTasksResultSchema,
CancelTaskRequestSchema,
CancelTaskResultSchema,
isJSONRPCError,
isJSONRPCErrorResponse,
isJSONRPCRequest,
isJSONRPCResponse,
isJSONRPCResultResponse,
isJSONRPCNotification,
JSONRPCError,
JSONRPCErrorResponse,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
McpError,
Notification,
PingRequestSchema,
Progress,
ProgressNotification,
ProgressNotificationSchema,
RELATED_TASK_META_KEY,
Request,
RequestId,
Result,
ServerCapabilities,
Expand All @@ -41,7 +39,11 @@ import {
CancelledNotification,
Task,
TaskStatusNotification,
TaskStatusNotificationSchema
TaskStatusNotificationSchema,
Request,
Notification,
JSONRPCResultResponse,
isTaskAugmentedRequestParams
} from '../types.js';
import { Transport, TransportSendOptions } from './transport.js';
import { AuthInfo } from '../server/auth/types.js';
Expand Down Expand Up @@ -324,7 +326,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
> = new Map();
private _requestHandlerAbortControllers: Map<RequestId, AbortController> = new Map();
private _notificationHandlers: Map<string, (notification: JSONRPCNotification) => Promise<void>> = new Map();
private _responseHandlers: Map<number, (response: JSONRPCResponse | Error) => void> = new Map();
private _responseHandlers: Map<number, (response: JSONRPCResultResponse | Error) => void> = new Map();
private _progressHandlers: Map<number, ProgressCallback> = new Map();
private _timeoutInfo: Map<number, TimeoutInfo> = new Map();
private _pendingDebouncedNotifications = new Set<string>();
Expand All @@ -335,7 +337,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
private _taskStore?: TaskStore;
private _taskMessageQueue?: TaskMessageQueue;

private _requestResolvers: Map<RequestId, (response: JSONRPCResponse | Error) => void> = new Map();
private _requestResolvers: Map<RequestId, (response: JSONRPCResultResponse | Error) => void> = new Map();

/**
* Callback for when the connection is closed for any reason.
Expand Down Expand Up @@ -408,18 +410,18 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
const requestId = message.id;

// Lookup resolver in _requestResolvers map
const resolver = this._requestResolvers.get(requestId);
const resolver = this._requestResolvers.get(requestId as RequestId);

if (resolver) {
// Remove resolver from map after invocation
this._requestResolvers.delete(requestId);
this._requestResolvers.delete(requestId as RequestId);

// Invoke resolver with response or error
if (queuedMessage.type === 'response') {
resolver(message as JSONRPCResponse);
resolver(message as JSONRPCResultResponse);
} else {
// Convert JSONRPCError to McpError
const errorMessage = message as JSONRPCError;
const errorMessage = message as JSONRPCErrorResponse;
const error = new McpError(
errorMessage.error.code,
errorMessage.error.message,
Expand Down Expand Up @@ -546,6 +548,9 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
}

private async _oncancel(notification: CancelledNotification): Promise<void> {
if (!notification.params.requestId) {
return;
}
// Handle request cancellation
const controller = this._requestHandlerAbortControllers.get(notification.params.requestId);
controller?.abort(notification.params.reason);
Expand Down Expand Up @@ -616,7 +621,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
const _onmessage = this._transport?.onmessage;
this._transport.onmessage = (message, extra) => {
_onmessage?.(message, extra);
if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) {
this._onresponse(message);
} else if (isJSONRPCRequest(message)) {
this._onrequest(message, extra);
Expand Down Expand Up @@ -675,7 +680,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
const relatedTaskId = request.params?._meta?.[RELATED_TASK_META_KEY]?.taskId;

if (handler === undefined) {
const errorResponse: JSONRPCError = {
const errorResponse: JSONRPCErrorResponse = {
jsonrpc: '2.0',
id: request.id,
error: {
Expand Down Expand Up @@ -706,7 +711,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
const abortController = new AbortController();
this._requestHandlerAbortControllers.set(request.id, abortController);

const taskCreationParams = request.params?.task;
const taskCreationParams = isTaskAugmentedRequestParams(request.params) ? request.params.task : undefined;
const taskStore = this._taskStore ? this.requestTaskStore(request, capturedTransport?.sessionId) : undefined;

const fullExtra: RequestHandlerExtra<SendRequestT, SendNotificationT> = {
Expand Down Expand Up @@ -791,7 +796,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
return;
}

const errorResponse: JSONRPCError = {
const errorResponse: JSONRPCErrorResponse = {
jsonrpc: '2.0',
id: request.id,
error: {
Expand Down Expand Up @@ -852,14 +857,14 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
handler(params);
}

private _onresponse(response: JSONRPCResponse | JSONRPCError): void {
private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void {
const messageId = Number(response.id);

// Check if this is a response to a queued request
const resolver = this._requestResolvers.get(messageId);
if (resolver) {
this._requestResolvers.delete(messageId);
if (isJSONRPCResponse(response)) {
if (isJSONRPCResultResponse(response)) {
resolver(response);
} else {
const error = new McpError(response.error.code, response.error.message, response.error.data);
Expand All @@ -879,7 +884,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e

// Keep progress handler alive for CreateTaskResult responses
let isTaskResponse = false;
if (isJSONRPCResponse(response) && response.result && typeof response.result === 'object') {
if (isJSONRPCResultResponse(response) && response.result && typeof response.result === 'object') {
const result = response.result as Record<string, unknown>;
if (result.task && typeof result.task === 'object') {
const task = result.task as Record<string, unknown>;
Expand All @@ -894,7 +899,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
this._progressHandlers.delete(messageId);
}

if (isJSONRPCResponse(response)) {
if (isJSONRPCResultResponse(response)) {
handler(response);
} else {
const error = McpError.fromError(response.error.code, response.error.message, response.error.data);
Expand Down Expand Up @@ -1191,7 +1196,7 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
const relatedTaskId = relatedTask?.taskId;
if (relatedTaskId) {
// Store the response resolver for this request so responses can be routed back
const responseResolver = (response: JSONRPCResponse | Error) => {
const responseResolver = (response: JSONRPCResultResponse | Error) => {
const handler = this._responseHandlers.get(messageId);
if (handler) {
handler(response);
Expand Down
Loading
Loading