From d189d4f56c56fcc0466a414b52afbadb79497322 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Thu, 4 Dec 2025 20:14:27 +0100 Subject: [PATCH 1/7] Add PrimingEventMessage type and update Protocol to handle it Co-authored-by: @zshea --- kotlin-sdk-core/api/kotlin-sdk-core.api | 10 ++++++++++ .../modelcontextprotocol/kotlin/sdk/shared/Protocol.kt | 2 ++ .../io/modelcontextprotocol/kotlin/sdk/types/common.kt | 4 ++-- .../modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt | 8 +++++++- .../kotlin/sdk/types/serializers.kt | 1 + 5 files changed, 22 insertions(+), 3 deletions(-) diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index 49b88cb3..f503fce6 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -988,6 +988,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/ClientResult$Default } public final class io/modelcontextprotocol/kotlin/sdk/types/CommonKt { + public static final field DEFAULT_NEGOTIATED_PROTOCOL_VERSION Ljava/lang/String; public static final field LATEST_PROTOCOL_VERSION Ljava/lang/String; public static final fun ProgressToken (J)Lio/modelcontextprotocol/kotlin/sdk/types/RequestId; public static final fun ProgressToken (Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/types/RequestId; @@ -2947,6 +2948,15 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/PingRequestBuilder : public synthetic fun build$kotlin_sdk_core ()Lio/modelcontextprotocol/kotlin/sdk/types/Request; } +public final class io/modelcontextprotocol/kotlin/sdk/types/PrimingEventMessage : io/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage { + public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/types/PrimingEventMessage; + public fun equals (Ljava/lang/Object;)Z + public fun getJsonrpc ()Ljava/lang/String; + public fun hashCode ()I + public final fun serializer ()Lkotlinx/serialization/KSerializer; + public fun toString ()Ljava/lang/String; +} + public final class io/modelcontextprotocol/kotlin/sdk/types/Progress { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/Progress$Companion; public fun (DLjava/lang/Double;Ljava/lang/String;)V diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 7dd04bc0..4e2ea29a 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -13,6 +13,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.Method import io.modelcontextprotocol.kotlin.sdk.types.Notification import io.modelcontextprotocol.kotlin.sdk.types.PingRequest +import io.modelcontextprotocol.kotlin.sdk.types.PrimingEventMessage import io.modelcontextprotocol.kotlin.sdk.types.Progress import io.modelcontextprotocol.kotlin.sdk.types.ProgressNotification import io.modelcontextprotocol.kotlin.sdk.types.ProgressToken @@ -249,6 +250,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio is JSONRPCRequest -> onRequest(message) is JSONRPCNotification -> onNotification(message) is JSONRPCError -> onResponse(null, message) + is PrimingEventMessage -> Unit } } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/common.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/common.kt index f715be73..e22dab7e 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/common.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/common.kt @@ -1,7 +1,5 @@ package io.modelcontextprotocol.kotlin.sdk.types -import io.modelcontextprotocol.kotlin.sdk.types.Icon.Theme.Dark -import io.modelcontextprotocol.kotlin.sdk.types.Icon.Theme.Light import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonObject @@ -12,6 +10,8 @@ import kotlinx.serialization.json.JsonObject public const val LATEST_PROTOCOL_VERSION: String = "2025-06-18" +public const val DEFAULT_NEGOTIATED_PROTOCOL_VERSION: String = "2025-03-26" + public val SUPPORTED_PROTOCOL_VERSIONS: List = listOf( LATEST_PROTOCOL_VERSION, "2025-03-26", diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt index 8a580eb0..30bd815f 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt @@ -85,6 +85,12 @@ public sealed interface JSONRPCMessage { public val jsonrpc: String } +@Serializable +public data object PrimingEventMessage : JSONRPCMessage { + @EncodeDefault + override val jsonrpc: String = JSONRPC_VERSION +} + // ============================================================================ // JSONRPCRequest // ============================================================================ @@ -197,7 +203,7 @@ public data class JSONRPCResponse(val id: RequestId, val result: RequestResult = * @property error Details about the error that occurred, including error code and message. */ @Serializable -public data class JSONRPCError(val id: RequestId, val error: RPCError) : JSONRPCMessage { +public data class JSONRPCError(val id: RequestId?, val error: RPCError) : JSONRPCMessage { @EncodeDefault override val jsonrpc: String = JSONRPC_VERSION } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt index 721a2aac..f3000105 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt @@ -384,6 +384,7 @@ internal object JSONRPCMessagePolymorphicSerializer : "result" in jsonObject -> JSONRPCResponse.serializer() "method" in jsonObject && "id" in jsonObject -> JSONRPCRequest.serializer() "method" in jsonObject -> JSONRPCNotification.serializer() + jsonObject.isEmpty() || jsonObject.keys == setOf("jsonrpc") -> PrimingEventMessage.serializer() else -> throw SerializationException("Invalid JSONRPCMessage type: ${jsonObject.keys}") } } From 4667bf0ee47d8449fd3e073d1223aa74dd4ff7b6 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Thu, 4 Dec 2025 20:18:36 +0100 Subject: [PATCH 2/7] Add StreamableHttpServerTransport implementation Co-authored-by: @zshea --- kotlin-sdk-server/api/kotlin-sdk-server.api | 33 +- .../server/StreamableHttpServerTransport.kt | 684 ++++++++++++++++++ 2 files changed, 713 insertions(+), 4 deletions(-) create mode 100644 kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index e5183f07..16739cc9 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -1,7 +1,17 @@ +public abstract interface class io/modelcontextprotocol/kotlin/sdk/server/EventStore { + public abstract fun getStreamIdForEventId (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun replayEventsAfter (Ljava/lang/String;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun storeEvent (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Routing;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function1;)V + public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt : io/modelcontextprotocol/kotlin/sdk/server/Feature { @@ -82,8 +92,6 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server { public final fun onConnect (Lkotlin/jvm/functions/Function0;)V public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V public final fun ping (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun removeNotificationHandler (Lio/modelcontextprotocol/kotlin/sdk/types/Method;)V - public final fun removeNotificationHandler (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/Method;)V public final fun removePrompt (Ljava/lang/String;)Z public final fun removePrompts (Ljava/util/List;)I public final fun removeResource (Ljava/lang/String;)Z @@ -95,8 +103,6 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server { public final fun sendResourceListChanged (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun sendResourceUpdated (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun sendToolListChanged (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; - public final fun setNotificationHandler (Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function1;)V - public final fun setNotificationHandler (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function1;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/ServerOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { @@ -151,6 +157,25 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTranspor public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { + public static final field STANDALONE_SSE_STREAM_ID Ljava/lang/String; + public fun ()V + public fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Ljava/lang/Long;)V + public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Ljava/lang/Long;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun closeSseStream (Lio/modelcontextprotocol/kotlin/sdk/types/RequestId;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun getSessionId ()Ljava/lang/String; + public final fun handleDeleteRequest (Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handleGetRequest (Lio/ktor/server/sse/ServerSSESession;Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handlePostRequest (Lio/ktor/server/sse/ServerSSESession;Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun handleRequest (Lio/ktor/server/sse/ServerSSESession;Lio/ktor/server/application/ApplicationCall;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun send (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lio/modelcontextprotocol/kotlin/sdk/shared/TransportSendOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun setOnSessionClosed (Lkotlin/jvm/functions/Function1;)V + public final fun setOnSessionInitialized (Lkotlin/jvm/functions/Function1;)V + public final fun setSessionIdGenerator (Lkotlin/jvm/functions/Function0;)V + public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt { public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V public static final fun mcpWebSocket (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function0;)V diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt new file mode 100644 index 00000000..d1b9ce47 --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -0,0 +1,684 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.request.contentType +import io.ktor.server.request.header +import io.ktor.server.request.httpMethod +import io.ktor.server.request.receiveText +import io.ktor.server.response.header +import io.ktor.server.response.respond +import io.ktor.server.response.respondNullable +import io.ktor.server.sse.ServerSSESession +import io.ktor.util.collections.ConcurrentMap +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions +import io.modelcontextprotocol.kotlin.sdk.types.DEFAULT_NEGOTIATED_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.PrimingEventMessage +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.RequestId +import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS +import kotlinx.coroutines.job +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.decodeFromJsonElement +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +internal const val MCP_SESSION_ID_HEADER = "mcp-session-id" +private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" +private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" +private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB + +/** + * Interface for resumability support via event storage + */ +public interface EventStore { + /** + * Stores an event for later retrieval + * @param streamId ID of the stream the event belongs to + * @param message The JSON-RPC message to store + * @returns The generated event ID for the stored event + */ + public suspend fun storeEvent(streamId: String, message: JSONRPCMessage): String + + /** + * Replays events after the specified event ID + * @param lastEventId The last event ID that was received + * @param sender Function to send events + * @return The stream ID for the replayed events + */ + public suspend fun replayEventsAfter( + lastEventId: String, + sender: suspend (eventId: String, message: JSONRPCMessage) -> Unit, + ): String + + /** + * Returns the stream ID associated with [eventId], or null if the event is unknown. + * Default implementation is a no-op which disables extra validation during replay. + */ + public suspend fun getStreamIdForEventId(eventId: String): String? +} + +/** + * A holder for an active request call. + * If enableJsonResponse is true, session is null. + * Otherwise, session is not null. + */ +private data class SessionContext(val session: ServerSSESession?, val call: ApplicationCall) + +/** + * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. + * It supports both SSE streaming and direct HTTP responses. + * + * In stateful mode: + * - Session ID is generated and included in response headers + * - Session ID is always included in initialization responses + * - Requests with invalid session IDs are rejected with 404 Not Found + * - Non-initialization requests without a session ID are rejected with 400 Bad Request + * - State is maintained in-memory (connections, message history) + * + * In stateless mode: + * - No Session ID is included in any responses + * - No session validation is performed + * + * @param enableJsonResponse If true, the server will return JSON responses instead of starting an SSE stream. + * This can be useful for simple request/response scenarios without streaming. + * Default is false (SSE streams are preferred). + * @param enableDnsRebindingProtection Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. + * @param allowedHosts List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + * @param allowedOrigins List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + * @param eventStore Event store for resumability support + * If provided, resumability will be enabled, allowing clients to reconnect and resume messages + * @param retryIntervalMillis Retry interval (in milliseconds) advertised via SSE priming events to hint the client when to reconnect. + * Applies only when an [eventStore] is configured. Defaults to `null` (no retry hint). + */ +@OptIn(ExperimentalUuidApi::class, ExperimentalAtomicApi::class) +public class StreamableHttpServerTransport( + private val enableJsonResponse: Boolean = false, + private val enableDnsRebindingProtection: Boolean = false, + private val allowedHosts: List? = null, + private val allowedOrigins: List? = null, + private val eventStore: EventStore? = null, + private val retryIntervalMillis: Long? = null, +) : AbstractTransport() { + public var sessionId: String? = null + private set + + private var sessionIdGenerator: (() -> String)? = { Uuid.random().toString() } + private var onSessionInitialized: ((sessionId: String) -> Unit)? = null + private var onSessionClosed: ((sessionId: String) -> Unit)? = null + + private val started: AtomicBoolean = AtomicBoolean(false) + private val initialized: AtomicBoolean = AtomicBoolean(false) + + private val streamsMapping: ConcurrentMap = ConcurrentMap() + private val requestToStreamMapping: ConcurrentMap = ConcurrentMap() + private val requestToResponseMapping: ConcurrentMap = ConcurrentMap() + + private val sessionMutex = Mutex() + private val streamMutex = Mutex() + + private companion object { + const val STANDALONE_SSE_STREAM_ID = "_GET_stream" + } + + /** + * Function that generates a session ID for the transport. + * The session ID SHOULD be globally unique and cryptographically secure + * (e.g., a securely generated UUID, a JWT, or a cryptographic hash) + * + * Set undefined to disable session management. + */ + public fun setSessionIdGenerator(block: (() -> String)?) { + sessionIdGenerator = block + } + + /** + * A callback for session initialization events + * This is called when the server initializes a new session. + * Useful in cases when you need to register multiple mcp sessions + * and need to keep track of them. + */ + public fun setOnSessionInitialized(block: ((String) -> Unit)?) { + onSessionInitialized = block + } + + /** + * A callback for session close events + * This is called when the server closes a session due to a DELETE request. + * Useful in cases when you need to clean up resources associated with the session. + * Note that this is different from the transport closing, if you are handling + * HTTP requests from multiple nodes you might want to close each + * StreamableHTTPServerTransport after a request is completed while still keeping the + * session open/running. + */ + public fun setOnSessionClosed(block: ((String) -> Unit)?) { + onSessionClosed = block + } + + override suspend fun start() { + check(started.compareAndSet(expectedValue = false, newValue = true)) { + "StreamableHttpServerTransport already started! If using Server class, note that connect() calls start() automatically." + } + } + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + val responseRequestId: RequestId? = when (message) { + is JSONRPCResponse -> message.id + is JSONRPCError -> message.id + else -> null + } + val routingRequestId = responseRequestId ?: options?.relatedRequestId + + // Standalone SSE stream + if (routingRequestId == null) { + require(message !is JSONRPCResponse && message !is JSONRPCError) { + "Cannot send a response on a standalone SSE stream unless resuming a previous client request" + } + val standaloneStream = streamsMapping[STANDALONE_SSE_STREAM_ID] ?: return + emitOnStream(STANDALONE_SSE_STREAM_ID, standaloneStream.session, message) + return + } + + val streamId = requestToStreamMapping[routingRequestId] + ?: error("No connection established for request id $routingRequestId") + val activeStream = streamsMapping[streamId] + + if (!enableJsonResponse) { + activeStream?.let { stream -> + emitOnStream(streamId, stream.session, message) + } + } + + val isTerminated = message is JSONRPCResponse || message is JSONRPCError + if (!isTerminated) return + + requestToResponseMapping[responseRequestId!!] = message + val relatedIds = requestToStreamMapping.filterValues { it == streamId }.keys + + if (relatedIds.any { it !in requestToResponseMapping }) return + + streamMutex.withLock { + if (activeStream == null) error("No connection established for request ID: $routingRequestId") + + if (enableJsonResponse) { + activeStream.call.response.header(HttpHeaders.ContentType, ContentType.Application.Json.toString()) + sessionId?.let { activeStream.call.response.header(MCP_SESSION_ID_HEADER, it) } + val responses = relatedIds.mapNotNull { requestToResponseMapping[it] } + val payload = if (responses.size == 1) { + responses.first() + } else { + responses + } + activeStream.call.respond(payload) + } else { + activeStream.session?.close() + } + + // Clean up + relatedIds.forEach { requestId -> + requestToResponseMapping.remove(requestId) + requestToStreamMapping.remove(requestId) + } + } + } + + override suspend fun close() { + streamMutex.withLock { + streamsMapping.values.forEach { + try { + it.session?.close() + } catch (_: Exception) { + } + } + streamsMapping.clear() + requestToStreamMapping.clear() + requestToResponseMapping.clear() + _onClose() + } + } + + /** + * Handles an incoming HTTP request, whether GET, POST or DELETE + */ + public suspend fun handleRequest(session: ServerSSESession?, call: ApplicationCall) { + validateHeaders(call)?.let { reason -> + call.reject(HttpStatusCode.Forbidden, RPCError.ErrorCode.CONNECTION_CLOSED, reason) + _onError(Error(reason)) + return + } + + when (call.request.httpMethod) { + HttpMethod.Post -> handlePostRequest(session, call) + + HttpMethod.Get -> handleGetRequest(session, call) + + HttpMethod.Delete -> handleDeleteRequest(call) + + else -> call.run { + response.header(HttpHeaders.Allow, "GET, POST, DELETE") + reject(HttpStatusCode.MethodNotAllowed, RPCError.ErrorCode.CONNECTION_CLOSED, "Method not allowed.") + } + } + } + + /** + * Handles POST requests containing JSON-RPC messages + */ + public suspend fun handlePostRequest(session: ServerSSESession?, call: ApplicationCall) { + try { + if (!enableJsonResponse && session == null) error("Server session can't be null with json response") + + val acceptHeader = call.request.header(HttpHeaders.Accept) + val isAcceptEventStream = acceptHeader.accepts(ContentType.Text.EventStream) + val isAcceptJson = acceptHeader.accepts(ContentType.Application.Json) + + if (!isAcceptEventStream || !isAcceptJson) { + call.reject( + HttpStatusCode.NotAcceptable, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Not Acceptable: Client must accept both application/json and text/event-stream", + ) + return + } + + if (!call.request.contentType().match(ContentType.Application.Json)) { + call.reject( + HttpStatusCode.UnsupportedMediaType, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Unsupported Media Type: Content-Type must be application/json", + ) + return + } + + val messages = parseBody(call) ?: return + val isInitializationRequest = messages.any { + it is JSONRPCRequest && it.method == Method.Defined.Initialize.value + } + + if (isInitializationRequest) { + if (initialized.load() && sessionId != null) { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.INVALID_REQUEST, + "Invalid Request: Server already initialized", + ) + return + } + if (messages.size > 1) { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.INVALID_REQUEST, + "Invalid Request: Only one initialization request is allowed", + ) + return + } + + sessionMutex.withLock { + if (sessionId != null) return@withLock + sessionId = sessionIdGenerator?.invoke() + initialized.store(true) + sessionId?.let { onSessionInitialized?.invoke(it) } + } + } else { + if (!validateSession(call) || !validateProtocolVersion(call)) return + } + + val hasRequest = messages.any { it is JSONRPCRequest } + if (!hasRequest) { + call.respondNullable(status = HttpStatusCode.Accepted, message = null) + messages.forEach { message -> _onMessage(message) } + return + } + + val streamId = Uuid.random().toString() + if (!enableJsonResponse) { + call.appendSseHeaders() + flushSse(session) // flush headers immediately + maybeSendPrimingEvent(streamId, session) + } + + streamMutex.withLock { + streamsMapping[streamId] = SessionContext(session, call) + messages.filterIsInstance().forEach { requestToStreamMapping[it.id] = streamId } + } + call.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(streamId) } + + messages.forEach { message -> _onMessage(message) } + } catch (e: Exception) { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.PARSE_ERROR, + "Parse error: ${e.message}", + ) + _onError(e) + } + } + + public suspend fun handleGetRequest(session: ServerSSESession?, call: ApplicationCall) { + if (enableJsonResponse) { + call.reject( + HttpStatusCode.MethodNotAllowed, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Method not allowed.", + ) + return + } + val sseSession = session ?: error("Server session can't be null for streaming GET requests") + + val acceptHeader = call.request.header(HttpHeaders.Accept) + if (!acceptHeader.accepts(ContentType.Text.EventStream)) { + call.reject( + HttpStatusCode.NotAcceptable, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Not Acceptable: Client must accept text/event-stream", + ) + return + } + + if (!validateSession(call) || !validateProtocolVersion(call)) return + + eventStore?.let { store -> + call.request.header(MCP_RESUMPTION_TOKEN_HEADER)?.let { lastEventId -> + replayEvents(store, lastEventId, sseSession) + return + } + } + + if (STANDALONE_SSE_STREAM_ID in streamsMapping) { + call.reject( + HttpStatusCode.Conflict, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Conflict: Only one SSE stream is allowed per session", + ) + return + } + + call.appendSseHeaders() + flushSse(sseSession) // flush headers immediately + streamsMapping[STANDALONE_SSE_STREAM_ID] = SessionContext(sseSession, call) + maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, sseSession) + sseSession.coroutineContext.job.invokeOnCompletion { + streamsMapping.remove(STANDALONE_SSE_STREAM_ID) + } + } + + public suspend fun handleDeleteRequest(call: ApplicationCall) { + if (!validateSession(call) || !validateProtocolVersion(call)) return + sessionId?.let { onSessionClosed?.invoke(it) } + close() + call.respondNullable(status = HttpStatusCode.OK, message = null) + } + + /** + * Closes the SSE stream associated with the given [requestId], prompting the client to reconnect. + * Useful for implementing polling behavior for long-running operations. + */ + public suspend fun closeSseStream(requestId: RequestId) { + if (enableJsonResponse) return + val streamId = requestToStreamMapping[requestId] ?: return + val sessionContext = streamsMapping[streamId] ?: return + + try { + sessionContext.session?.close() + } catch (e: Exception) { + _onError(e) + } finally { + streamsMapping.remove(streamId) + } + } + + private suspend fun replayEvents(store: EventStore, lastEventId: String, session: ServerSSESession) { + val call: ApplicationCall = session.call + + try { + var lookupSupported = true + val lookupStreamId = try { + store.getStreamIdForEventId(lastEventId) + } catch (_: NotImplementedError) { + lookupSupported = false + null + } catch (_: UnsupportedOperationException) { + lookupSupported = false + null + } + + if (lookupSupported) { + val streamId = lookupStreamId + ?: run { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Invalid event ID format", + ) + return + } + + if (streamId in streamsMapping) { + call.reject( + HttpStatusCode.Conflict, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Conflict: Stream already has an active connection", + ) + return + } + } + + call.appendSseHeaders() + flushSse(session) // flush headers immediately + + val streamId = store.replayEventsAfter(lastEventId) { eventId, message -> + try { + session.send( + event = "message", + id = eventId, + data = McpJson.encodeToString(message), + ) + } catch (e: Exception) { + _onError(IllegalStateException("Failed to replay event: ${e.message}", e)) + } + } + + streamsMapping[streamId] = SessionContext(session, call) + + session.coroutineContext.job.invokeOnCompletion { throwable -> + streamsMapping.remove(streamId) + throwable?.let { _onError(it) } + } + } catch (e: Exception) { + _onError(e) + } + } + + private suspend fun validateSession(call: ApplicationCall): Boolean { + if (sessionIdGenerator == null) return true + + if (!initialized.load()) { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Bad Request: Server not initialized", + ) + return false + } + + val sessionHeaderValues = call.request.headers.getAll(MCP_SESSION_ID_HEADER) + + if (sessionHeaderValues.isNullOrEmpty()) { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Bad Request: Mcp-Session-Id header is required", + ) + return false + } + + if (sessionHeaderValues.size > 1) { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Bad Request: Mcp-Session-Id header must be a single value", + ) + return false + } + + val headerId = sessionHeaderValues.single() + + return when (headerId) { + sessionId -> true + + else -> { + call.reject( + HttpStatusCode.NotFound, + -32001, + "Session not found", + ) + false + } + } + } + + private suspend fun validateProtocolVersion(call: ApplicationCall): Boolean { + val protocolVersions = call.request.headers.getAll(MCP_PROTOCOL_VERSION_HEADER) + val version = protocolVersions?.lastOrNull() ?: DEFAULT_NEGOTIATED_PROTOCOL_VERSION + + return when (version) { + !in SUPPORTED_PROTOCOL_VERSIONS -> { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.CONNECTION_CLOSED, + "Bad Request: Unsupported protocol version (supported versions: ${ + SUPPORTED_PROTOCOL_VERSIONS.joinToString( + ", ", + ) + })", + ) + false + } + + else -> true + } + } + + private fun validateHeaders(call: ApplicationCall): String? { + if (!enableDnsRebindingProtection) return null + + allowedHosts?.let { hosts -> + val hostHeader = call.request.headers[HttpHeaders.Host]?.lowercase() + val allowedHostsLowercase = hosts.map { it.lowercase() } + + if (hostHeader == null || hostHeader !in allowedHostsLowercase) { + return "Invalid Host header: $hostHeader" + } + } + + allowedOrigins?.let { origins -> + val originHeader = call.request.headers[HttpHeaders.Origin]?.lowercase() + val allowedOriginsLowercase = origins.map { it.lowercase() } + + if (originHeader == null || originHeader !in allowedOriginsLowercase) { + return "Invalid Origin header: $originHeader" + } + } + + return null + } + + private suspend fun flushSse(session: ServerSSESession?) { + try { + session?.send(data = "") + } catch (e: Exception) { + _onError(e) + } + } + + private suspend fun parseBody(call: ApplicationCall): List? { + val contentLength = call.request.header(HttpHeaders.ContentLength)?.toIntOrNull() ?: 0 + if (contentLength > MAXIMUM_MESSAGE_SIZE) { + call.reject( + HttpStatusCode.PayloadTooLarge, + RPCError.ErrorCode.INVALID_REQUEST, + "Invalid Request: message size exceeds maximum of ${MAXIMUM_MESSAGE_SIZE / (1024 * 1024)} MB", + ) + return null + } + + val body = call.receiveText() + if (body.length > MAXIMUM_MESSAGE_SIZE) { + call.reject( + HttpStatusCode.PayloadTooLarge, + RPCError.ErrorCode.INVALID_REQUEST, + "Invalid Request: message size exceeds maximum of ${MAXIMUM_MESSAGE_SIZE / (1024 * 1024)} MB", + ) + return null + } + + return when (val element = McpJson.parseToJsonElement(body)) { + is JsonObject -> listOf(McpJson.decodeFromJsonElement(element)) + + is JsonArray -> McpJson.decodeFromJsonElement>(element) + + else -> { + call.reject( + HttpStatusCode.BadRequest, + RPCError.ErrorCode.INVALID_REQUEST, + "Invalid Request: unable to parse JSON body", + ) + null + } + } + } + + private fun String?.accepts(mime: ContentType): Boolean = + this?.lowercase()?.contains(mime.toString().lowercase()) == true + + private suspend fun emitOnStream(streamId: String, session: ServerSSESession?, message: JSONRPCMessage) { + val eventId = eventStore?.storeEvent(streamId, message) + try { + session?.send(event = "message", id = eventId, data = McpJson.encodeToString(message)) + } catch (_: Exception) { + streamsMapping.remove(streamId) + } + } + + private suspend fun maybeSendPrimingEvent(streamId: String, session: ServerSSESession?) { + val store = eventStore ?: return + val sseSession = session ?: return + try { + val primingEventId = store.storeEvent(streamId, PrimingEventMessage) + sseSession.send(id = primingEventId, retry = retryIntervalMillis, data = "") + } catch (e: Exception) { + _onError(e) + } + } + + private fun ApplicationCall.appendSseHeaders() { + this.response.headers.append(HttpHeaders.ContentType, ContentType.Text.EventStream.toString()) + this.response.headers.append(HttpHeaders.CacheControl, "no-cache, no-transform") + this.response.headers.append(HttpHeaders.Connection, "keep-alive") + sessionId?.let { this.response.headers.append(MCP_SESSION_ID_HEADER, it) } + this.response.status(HttpStatusCode.OK) + } +} + +internal suspend fun ApplicationCall.reject(status: HttpStatusCode, code: Int, message: String) { + this.response.status(status) + this.respond(JSONRPCError(id = null, error = RPCError(code = code, message = message))) +} From a40db1deb366799d3ffed248593d6293f71e76aa Mon Sep 17 00:00:00 2001 From: devcrocod Date: Thu, 4 Dec 2025 21:09:36 +0100 Subject: [PATCH 3/7] update api --- kotlin-sdk-server/api/kotlin-sdk-server.api | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index 16739cc9..53c971b1 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -8,10 +8,6 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { public static final fun mcp (Lio/ktor/server/application/Application;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Routing;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)V public static final fun mcp (Lio/ktor/server/routing/Routing;Lkotlin/jvm/functions/Function1;)V - public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V - public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt : io/modelcontextprotocol/kotlin/sdk/server/Feature { @@ -92,6 +88,8 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server { public final fun onConnect (Lkotlin/jvm/functions/Function0;)V public final fun onInitialized (Lkotlin/jvm/functions/Function0;)V public final fun ping (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun removeNotificationHandler (Lio/modelcontextprotocol/kotlin/sdk/types/Method;)V + public final fun removeNotificationHandler (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/Method;)V public final fun removePrompt (Ljava/lang/String;)Z public final fun removePrompts (Ljava/util/List;)I public final fun removeResource (Ljava/lang/String;)Z @@ -103,6 +101,8 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server { public final fun sendResourceListChanged (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun sendResourceUpdated (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/ResourceUpdatedNotification;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun sendToolListChanged (Ljava/lang/String;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun setNotificationHandler (Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function1;)V + public final fun setNotificationHandler (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function1;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/ServerOptions : io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { From e593c2e776e3d3f3686328de272833bb89f4b128 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Fri, 5 Dec 2025 16:00:03 +0100 Subject: [PATCH 4/7] Add tests for StreamableHttpServerTransport with necessary dependencies --- gradle/libs.versions.toml | 3 + kotlin-sdk-server/build.gradle.kts | 7 + .../StreamableHttpServerTransportTest.kt | 269 ++++++++++++++++++ 3 files changed, 279 insertions(+) create mode 100644 kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c1d67730..0cbc5f0a 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -44,6 +44,9 @@ kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx- ktor-client-apache5 = { group = "io.ktor", name = "ktor-client-apache5", version.ref = "ktor" } ktor-client-core = { group = "io.ktor", name = "ktor-client-core", version.ref = "ktor" } ktor-client-logging = { group = "io.ktor", name = "ktor-client-logging", version.ref = "ktor" } +ktor-server-content-negotiation = { group = "io.ktor", name = "ktor-server-content-negotiation", version.ref = "ktor" } +ktor-client-content-negotiation = { group = "io.ktor", name = "ktor-client-content-negotiation", version.ref = "ktor" } +ktor-serialization = { group = "io.ktor", name = "ktor-serialization-kotlinx-json", version.ref = "ktor" } ktor-server-core = { group = "io.ktor", name = "ktor-server-core", version.ref = "ktor" } ktor-server-sse = { group = "io.ktor", name = "ktor-server-sse", version.ref = "ktor" } ktor-server-websockets = { group = "io.ktor", name = "ktor-server-websockets", version.ref = "ktor" } diff --git a/kotlin-sdk-server/build.gradle.kts b/kotlin-sdk-server/build.gradle.kts index 64a27497..e51a6432 100644 --- a/kotlin-sdk-server/build.gradle.kts +++ b/kotlin-sdk-server/build.gradle.kts @@ -26,6 +26,13 @@ kotlin { jvmTest { dependencies { + implementation(libs.ktor.client.logging) + implementation(libs.ktor.server.content.negotiation) + implementation(libs.ktor.client.content.negotiation) + implementation(libs.ktor.serialization) + implementation(libs.ktor.server.test.host) + implementation(libs.kotest.assertions.core) + implementation(libs.kotest.assertions.json) runtimeOnly(libs.slf4j.simple) } } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt new file mode 100644 index 00000000..968c0572 --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -0,0 +1,269 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.equals.shouldBeEqual +import io.kotest.matchers.shouldBe +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logging +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.application.install +import io.ktor.server.routing.post +import io.ktor.server.routing.routing +import io.ktor.server.testing.ApplicationTestBuilder +import io.ktor.server.testing.testApplication +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult +import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.RequestId +import io.modelcontextprotocol.kotlin.sdk.types.Tool +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.serialization.builtins.ListSerializer +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation as ClientContentNegotiation +import io.ktor.server.plugins.contentnegotiation.ContentNegotiation as ServerContentNegotiation + +class StreamableHttpServerTransportTest { + private val path = "/transport" + + @Test + fun `POST without event-stream accept header is rejected`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val onMessageCalled = AtomicBoolean(false) + transport.onMessage { + onMessageCalled.set(true) + } + + configureTransportEndpoint(transport) + + val payload = buildInitializeRequestPayload() + + val response = client.post(path) { + contentType(ContentType.Application.Json) + header(HttpHeaders.Accept, ContentType.Application.Json.toString()) + setBody(payload) + } + + assertEquals(HttpStatusCode.NotAcceptable, response.status) + assertFalse(onMessageCalled.get(), "Transport should not deliver messages when headers are invalid") + } + + @Test + fun `initialization request establishes session and returns json response`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val expectedSessionId = "session-test-id" + transport.setSessionIdGenerator { expectedSessionId } + + var observedRequest: JSONRPCRequest? = null + transport.onMessage { message -> + if (message is JSONRPCRequest) { + observedRequest = message + transport.send(JSONRPCResponse(message.id, EmptyResult()), null) + } + } + + configureTransportEndpoint(transport) + + val payload = buildInitializeRequestPayload() + + val response = client.post(path) { + addStreamableHeaders() + setBody(payload) + } + + assertEquals(HttpStatusCode.OK, response.status) + assertEquals(expectedSessionId, response.headers[MCP_SESSION_ID_HEADER]) + val request = assertNotNull(observedRequest, "Initialization request should be forwarded") + + response.body() shouldBe JSONRPCResponse(request.id) + } + + @Test + fun `notifications only payload responds with 202 Accepted`() = testApplication { + configTestServer() + + val client = createTestClient() + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val receivedMessages = mutableListOf() + transport.onMessage { message -> + if (message is JSONRPCRequest) { + transport.send(JSONRPCResponse(message.id, EmptyResult())) + } + receivedMessages.add(message) + } + + configureTransportEndpoint(transport) + + val initRequest = buildInitializeRequestPayload() + + val responseInit = client.post(path) { + addStreamableHeaders() + setBody(initRequest) + } + + val notificationPayload = encodeMessages( + listOf(InitializedNotification().toJSON()), + ) + + val response = client.post(path) { + addStreamableHeaders() + header("mcp-session-id", responseInit.headers[MCP_SESSION_ID_HEADER]) + setBody(notificationPayload) + } + + assertEquals(HttpStatusCode.Accepted, response.status) + receivedMessages shouldBeEqual listOf(initRequest, InitializedNotification().toJSON()) + } + + @Test + fun `batched requests wait for all responses before replying`() = testApplication { + configTestServer() + + val client = createTestClient(logging = true) + + val transport = StreamableHttpServerTransport(enableJsonResponse = true) + val firstRequest = JSONRPCRequest(id = RequestId("first"), method = Method.Defined.ToolsList.value) + val secondRequest = JSONRPCRequest(id = RequestId("second"), method = Method.Defined.ResourcesList.value) + + val firstResult = ListToolsResult( + tools = listOf( + Tool(name = "tool-1", inputSchema = ToolSchema()), + ), + meta = buildJsonObject { put("label", "first") }, + ) + val secondResult = ListResourcesResult( + resources = emptyList(), + meta = buildJsonObject { put("label", "second") }, + ) + + transport.onMessage { message -> + if (message is JSONRPCRequest) { + val result = when (message.id) { + firstRequest.id -> firstResult + secondRequest.id -> secondResult + else -> EmptyResult() + } + transport.send(JSONRPCResponse(message.id, result), null) + } + } + + configureTransportEndpoint(transport) + + val initRequest = buildInitializeRequestPayload() + + val responseInit = client.post(path) { + addStreamableHeaders() + setBody(initRequest) + } + + val payload = encodeMessages(listOf(firstRequest, secondRequest)) + + val response = client.post(path) { + addStreamableHeaders() + header("mcp-session-id", responseInit.headers[MCP_SESSION_ID_HEADER]) + setBody(payload) + } + + assertEquals(HttpStatusCode.OK, response.status) + + val responses = response.body>().map { it.result } + responses shouldContain (firstResult) + responses shouldContain (secondResult) + // TODO(check order) +// assertEquals(listOf(firstRequest.id, secondRequest.id), responses.map { it.id }) +// val firstMeta = (responses[0].result as EmptyResult).meta +// val secondMeta = (responses[1].result as EmptyResult).meta +// assertEquals("first", firstMeta?.get("label")?.jsonPrimitive?.content) +// assertEquals("second", secondMeta?.get("label")?.jsonPrimitive?.content) + } + + private fun ApplicationTestBuilder.configureTransportEndpoint(transport: StreamableHttpServerTransport) { + application { + routing { + post(path) { + transport.handlePostRequest(null, call) + } + } + } + } + + private fun HttpRequestBuilder.addStreamableHeaders() { + header( + HttpHeaders.Accept, + listOf(ContentType.Application.Json, ContentType.Text.EventStream).joinToString(", ") { + it.toString() + }, + ) + contentType(ContentType.Application.Json) + } + + private fun buildInitializeRequestPayload(): JSONRPCRequest { + val request = InitializeRequest( + InitializeRequestParams( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ClientCapabilities(), + clientInfo = Implementation(name = "test-client", version = "1.0.0"), + ), + ).toJSON() + + return request + } + + private fun encodeMessages(messages: List): String = + McpJson.encodeToString(ListSerializer(JSONRPCMessage.serializer()), messages) + + private fun ApplicationTestBuilder.configTestServer() { + application { + install(ServerContentNegotiation) { + json(McpJson) + } + } + } + + private fun ApplicationTestBuilder.createTestClient(logging: Boolean = false): HttpClient = createClient { + install(ClientContentNegotiation) { + json(McpJson) + } + if (logging) { + install(Logging) { + level = LogLevel.ALL + } + } + } +} From 02c3c15bc85da768c9b1d331f49e91acd34496b5 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Fri, 5 Dec 2025 16:52:59 +0100 Subject: [PATCH 5/7] Replace `PrimingEventMessage` with `JSONRPCEmptyMessage` across SDK and protocols --- .../kotlin/sdk/shared/Protocol.kt | 4 +-- .../kotlin/sdk/types/jsonRpc.kt | 3 +- .../kotlin/sdk/types/serializers.kt | 2 +- .../kotlin/sdk/types/JsonRpcTest.kt | 12 +++++++ .../kotlin/sdk/server/EventStore.kt | 33 ++++++++++++++++++ .../server/StreamableHttpServerTransport.kt | 34 ++----------------- 6 files changed, 51 insertions(+), 37 deletions(-) create mode 100644 kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/EventStore.kt diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 4e2ea29a..358b5659 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.types.CancelledNotification import io.modelcontextprotocol.kotlin.sdk.types.CancelledNotificationParams import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCEmptyMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest @@ -13,7 +14,6 @@ import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.Method import io.modelcontextprotocol.kotlin.sdk.types.Notification import io.modelcontextprotocol.kotlin.sdk.types.PingRequest -import io.modelcontextprotocol.kotlin.sdk.types.PrimingEventMessage import io.modelcontextprotocol.kotlin.sdk.types.Progress import io.modelcontextprotocol.kotlin.sdk.types.ProgressNotification import io.modelcontextprotocol.kotlin.sdk.types.ProgressToken @@ -250,7 +250,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio is JSONRPCRequest -> onRequest(message) is JSONRPCNotification -> onNotification(message) is JSONRPCError -> onResponse(null, message) - is PrimingEventMessage -> Unit + is JSONRPCEmptyMessage -> Unit } } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt index 30bd815f..50fd02fc 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt @@ -86,8 +86,7 @@ public sealed interface JSONRPCMessage { } @Serializable -public data object PrimingEventMessage : JSONRPCMessage { - @EncodeDefault +public data object JSONRPCEmptyMessage : JSONRPCMessage { override val jsonrpc: String = JSONRPC_VERSION } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt index f3000105..a00b26be 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt @@ -384,7 +384,7 @@ internal object JSONRPCMessagePolymorphicSerializer : "result" in jsonObject -> JSONRPCResponse.serializer() "method" in jsonObject && "id" in jsonObject -> JSONRPCRequest.serializer() "method" in jsonObject -> JSONRPCNotification.serializer() - jsonObject.isEmpty() || jsonObject.keys == setOf("jsonrpc") -> PrimingEventMessage.serializer() + jsonObject.isEmpty() || jsonObject.keys == setOf("jsonrpc") -> JSONRPCEmptyMessage.serializer() else -> throw SerializationException("Invalid JSONRPCMessage type: ${jsonObject.keys}") } } diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/JsonRpcTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/JsonRpcTest.kt index d8d20583..ce02cc04 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/JsonRpcTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/JsonRpcTest.kt @@ -430,4 +430,16 @@ class JsonRpcTest { request.method shouldBe "notifications/log" request.params shouldBeSameInstanceAs params } + + @Test + fun `should deserialize JSONRPCEmptyMessage`() { + val json = """ + { + "jsonrpc": "2.0" + } + """.trimIndent() + + val message = McpJson.decodeFromString(json) + message shouldBeSameInstanceAs JSONRPCEmptyMessage + } } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/EventStore.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/EventStore.kt new file mode 100644 index 00000000..ff9abbf6 --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/EventStore.kt @@ -0,0 +1,33 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage + +/** + * Interface for resumability support via event storage + */ +public interface EventStore { + /** + * Stores an event for later retrieval + * @param streamId ID of the stream the event belongs to + * @param message The JSON-RPC message to store + * @returns The generated event ID for the stored event + */ + public suspend fun storeEvent(streamId: String, message: JSONRPCMessage): String + + /** + * Replays events after the specified event ID + * @param lastEventId The last event ID that was received + * @param sender Function to send events + * @return The stream ID for the replayed events + */ + public suspend fun replayEventsAfter( + lastEventId: String, + sender: suspend (eventId: String, message: JSONRPCMessage) -> Unit, + ): String + + /** + * Returns the stream ID associated with [eventId], or null if the event is unknown. + * Default implementation is a no-op which disables extra validation during replay. + */ + public suspend fun getStreamIdForEventId(eventId: String): String? +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index d1b9ce47..56bcead4 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -17,13 +17,13 @@ import io.ktor.util.collections.ConcurrentMap import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions import io.modelcontextprotocol.kotlin.sdk.types.DEFAULT_NEGOTIATED_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCEmptyMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.Method -import io.modelcontextprotocol.kotlin.sdk.types.PrimingEventMessage import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.RequestId import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS @@ -43,36 +43,6 @@ private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID" private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB -/** - * Interface for resumability support via event storage - */ -public interface EventStore { - /** - * Stores an event for later retrieval - * @param streamId ID of the stream the event belongs to - * @param message The JSON-RPC message to store - * @returns The generated event ID for the stored event - */ - public suspend fun storeEvent(streamId: String, message: JSONRPCMessage): String - - /** - * Replays events after the specified event ID - * @param lastEventId The last event ID that was received - * @param sender Function to send events - * @return The stream ID for the replayed events - */ - public suspend fun replayEventsAfter( - lastEventId: String, - sender: suspend (eventId: String, message: JSONRPCMessage) -> Unit, - ): String - - /** - * Returns the stream ID associated with [eventId], or null if the event is unknown. - * Default implementation is a no-op which disables extra validation during replay. - */ - public suspend fun getStreamIdForEventId(eventId: String): String? -} - /** * A holder for an active request call. * If enableJsonResponse is true, session is null. @@ -662,7 +632,7 @@ public class StreamableHttpServerTransport( val store = eventStore ?: return val sseSession = session ?: return try { - val primingEventId = store.storeEvent(streamId, PrimingEventMessage) + val primingEventId = store.storeEvent(streamId, JSONRPCEmptyMessage) sseSession.send(id = primingEventId, retry = retryIntervalMillis, data = "") } catch (e: Exception) { _onError(e) From 04956fb0fb253f3a50d00997a4b1d37999dbb1dd Mon Sep 17 00:00:00 2001 From: devcrocod Date: Fri, 5 Dec 2025 16:53:51 +0100 Subject: [PATCH 6/7] Add JSONRPCEmptyMessage type to the API --- kotlin-sdk-core/api/kotlin-sdk-core.api | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index f503fce6..50208af1 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -2047,6 +2047,15 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/InitializedNotificat public final fun serializer ()Lkotlinx/serialization/KSerializer; } +public final class io/modelcontextprotocol/kotlin/sdk/types/JSONRPCEmptyMessage : io/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage { + public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCEmptyMessage; + public fun equals (Ljava/lang/Object;)Z + public fun getJsonrpc ()Ljava/lang/String; + public fun hashCode ()I + public final fun serializer ()Lkotlinx/serialization/KSerializer; + public fun toString ()Ljava/lang/String; +} + public final class io/modelcontextprotocol/kotlin/sdk/types/JSONRPCError : io/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCError$Companion; public fun (Lio/modelcontextprotocol/kotlin/sdk/types/RequestId;Lio/modelcontextprotocol/kotlin/sdk/types/RPCError;)V From 600850e48569c11e9728b1aa088ffc81734df028 Mon Sep 17 00:00:00 2001 From: devcrocod Date: Fri, 5 Dec 2025 17:19:43 +0100 Subject: [PATCH 7/7] Remove `PrimingEventMessage` type from API --- kotlin-sdk-core/api/kotlin-sdk-core.api | 9 --------- 1 file changed, 9 deletions(-) diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index 50208af1..8fff88b7 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -2957,15 +2957,6 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/PingRequestBuilder : public synthetic fun build$kotlin_sdk_core ()Lio/modelcontextprotocol/kotlin/sdk/types/Request; } -public final class io/modelcontextprotocol/kotlin/sdk/types/PrimingEventMessage : io/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage { - public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/types/PrimingEventMessage; - public fun equals (Ljava/lang/Object;)Z - public fun getJsonrpc ()Ljava/lang/String; - public fun hashCode ()I - public final fun serializer ()Lkotlinx/serialization/KSerializer; - public fun toString ()Ljava/lang/String; -} - public final class io/modelcontextprotocol/kotlin/sdk/types/Progress { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/Progress$Companion; public fun (DLjava/lang/Double;Ljava/lang/String;)V