-
Notifications
You must be signed in to change notification settings - Fork 183
Add Server Streamable Http Transport #235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5f88b1d
c41c227
ee3e680
f62a6de
4d01980
66f88a1
db8894a
68da842
13ea9ac
24905f0
99800c4
bacf6e8
eda4337
613d3fb
c1ad3d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging | |
| import io.ktor.http.HttpStatusCode | ||
| import io.ktor.server.application.Application | ||
| import io.ktor.server.application.install | ||
| import io.ktor.server.request.header | ||
| import io.ktor.server.response.respond | ||
| import io.ktor.server.routing.Routing | ||
| import io.ktor.server.routing.RoutingContext | ||
|
|
@@ -14,6 +15,8 @@ import io.ktor.server.sse.SSE | |
| import io.ktor.server.sse.ServerSSESession | ||
| import io.ktor.server.sse.sse | ||
| import io.ktor.utils.io.KtorDsl | ||
| import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport | ||
| import io.modelcontextprotocol.kotlin.sdk.types.RPCError | ||
| import kotlinx.atomicfu.AtomicRef | ||
| import kotlinx.atomicfu.atomic | ||
| import kotlinx.atomicfu.update | ||
|
|
@@ -23,13 +26,15 @@ import kotlinx.coroutines.awaitCancellation | |
|
|
||
| private val logger = KotlinLogging.logger {} | ||
|
|
||
| internal class SseTransportManager(transports: Map<String, SseServerTransport> = emptyMap()) { | ||
| private val transports: AtomicRef<PersistentMap<String, SseServerTransport>> = atomic(transports.toPersistentMap()) | ||
| internal class TransportManager(transports: Map<String, AbstractTransport> = emptyMap()) { | ||
| private val transports: AtomicRef<PersistentMap<String, AbstractTransport>> = atomic(transports.toPersistentMap()) | ||
|
|
||
| fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId] | ||
| fun hasTransport(sessionId: String): Boolean = transports.value.containsKey(sessionId) | ||
|
|
||
| fun addTransport(transport: SseServerTransport) { | ||
| transports.update { it.put(transport.sessionId, transport) } | ||
| fun getTransport(sessionId: String): AbstractTransport? = transports.value[sessionId] | ||
|
|
||
| fun addTransport(sessionId: String, transport: AbstractTransport) { | ||
| transports.update { it.put(sessionId, transport) } | ||
| } | ||
|
|
||
| fun removeTransport(sessionId: String) { | ||
|
|
@@ -49,14 +54,14 @@ public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) { | |
| */ | ||
| @KtorDsl | ||
| public fun Routing.mcp(block: ServerSSESession.() -> Server) { | ||
| val sseTransportManager = SseTransportManager() | ||
| val transportManager = TransportManager() | ||
|
|
||
| sse { | ||
| mcpSseEndpoint("", sseTransportManager, block) | ||
| mcpSseEndpoint("", transportManager, block) | ||
| } | ||
|
|
||
| post { | ||
| mcpPostEndpoint(sseTransportManager) | ||
| mcpPostEndpoint(transportManager) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -69,18 +74,71 @@ public fun Application.mcp(block: ServerSSESession.() -> Server) { | |
| } | ||
| } | ||
|
|
||
| /* | ||
| * Configures the Ktor Application to handle Model Context Protocol (MCP) over Streamable Http. | ||
| * It currently only works with JSON response. | ||
| */ | ||
| @KtorDsl | ||
| public fun Application.mcpStreamableHttp( | ||
| enableDnsRebindingProtection: Boolean = false, | ||
| allowedHosts: List<String>? = null, | ||
| allowedOrigins: List<String>? = null, | ||
| eventStore: EventStore? = null, | ||
| block: RoutingContext.() -> Server, | ||
| ) { | ||
| val transportManager = TransportManager() | ||
|
|
||
| routing { | ||
| post("/mcp") { | ||
| mcpStreamableHttpEndpoint( | ||
| transportManager, | ||
| enableDnsRebindingProtection, | ||
| allowedHosts, | ||
| allowedOrigins, | ||
| eventStore, | ||
| block, | ||
| ) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /* | ||
| * Configures the Ktor Application to handle Model Context Protocol (MCP) over stateless Streamable Http. | ||
| * It currently only works with JSON response. | ||
| */ | ||
| @KtorDsl | ||
zshea marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| public fun Application.mcpStatelessStreamableHttp( | ||
| enableDnsRebindingProtection: Boolean = false, | ||
| allowedHosts: List<String>? = null, | ||
| allowedOrigins: List<String>? = null, | ||
| eventStore: EventStore? = null, | ||
| block: RoutingContext.() -> Server, | ||
| ) { | ||
| routing { | ||
| post("/mcp") { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't implement because it's not needed for json response. Once we add SSE back, we can do it then. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http The server MUST either return Content-Type: text/event-stream in response to this HTTP GET, or else return HTTP 405 Method Not Allowed, indicating that the server does not offer an SSE stream at this endpoint Without processing (which responds with code 405), it seems the inspector was spamming errors. I replaced There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm drafting a PR to return 405 for the stateless extension on GET requests, as per spec
|
||
| mcpStatelessStreamableHttpEndpoint( | ||
| enableDnsRebindingProtection, | ||
| allowedHosts, | ||
| allowedOrigins, | ||
| eventStore, | ||
| block, | ||
| ) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| internal suspend fun ServerSSESession.mcpSseEndpoint( | ||
| postEndpoint: String, | ||
| sseTransportManager: SseTransportManager, | ||
| transportManager: TransportManager, | ||
| block: ServerSSESession.() -> Server, | ||
| ) { | ||
| val transport = mcpSseTransport(postEndpoint, sseTransportManager) | ||
| val transport = mcpSseTransport(postEndpoint, transportManager) | ||
|
|
||
| val server = block() | ||
|
|
||
| server.onClose { | ||
| logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } | ||
| sseTransportManager.removeTransport(transport.sessionId) | ||
| transportManager.removeTransport(transport.sessionId) | ||
| } | ||
|
|
||
| server.createSession(transport) | ||
|
|
@@ -92,24 +150,106 @@ internal suspend fun ServerSSESession.mcpSseEndpoint( | |
|
|
||
| internal fun ServerSSESession.mcpSseTransport( | ||
| postEndpoint: String, | ||
| sseTransportManager: SseTransportManager, | ||
| transportManager: TransportManager, | ||
| ): SseServerTransport { | ||
| val transport = SseServerTransport(postEndpoint, this) | ||
| sseTransportManager.addTransport(transport) | ||
| transportManager.addTransport(transport.sessionId, transport) | ||
| logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } | ||
|
|
||
| return transport | ||
| } | ||
|
|
||
| internal suspend fun RoutingContext.mcpPostEndpoint(sseTransportManager: SseTransportManager) { | ||
| internal suspend fun RoutingContext.mcpStreamableHttpEndpoint( | ||
| transportManager: TransportManager, | ||
| enableDnsRebindingProtection: Boolean = false, | ||
| allowedHosts: List<String>? = null, | ||
| allowedOrigins: List<String>? = null, | ||
| eventStore: EventStore? = null, | ||
| block: RoutingContext.() -> Server, | ||
| ) { | ||
| val sessionId = this.call.request.header(MCP_SESSION_ID_HEADER) | ||
| val transport = if (sessionId != null && transportManager.hasTransport(sessionId)) { | ||
| transportManager.getTransport(sessionId) | ||
| } else if (sessionId == null) { | ||
| val transport = StreamableHttpServerTransport( | ||
| enableDnsRebindingProtection = enableDnsRebindingProtection, | ||
| allowedHosts = allowedHosts, | ||
| allowedOrigins = allowedOrigins, | ||
| eventStore = eventStore, | ||
| enableJsonResponse = true, | ||
| ) | ||
|
|
||
| transport.setOnSessionInitialized { sessionId -> | ||
| transportManager.addTransport(sessionId, transport) | ||
|
|
||
| logger.info { "New StreamableHttp connection established and stored with sessionId: $sessionId" } | ||
| } | ||
|
|
||
| val server = block() | ||
| server.onClose { | ||
| logger.info { "Server connection closed for sessionId: ${transport.sessionId}" } | ||
| } | ||
|
|
||
| server.createSession(transport) | ||
|
|
||
| transport | ||
| } else { | ||
| null | ||
| } | ||
|
|
||
| if (transport == null) { | ||
| this.call.reject( | ||
| HttpStatusCode.BadRequest, | ||
| RPCError.ErrorCode.CONNECTION_CLOSED, | ||
| "Bad Request: No valid session ID provided", | ||
| ) | ||
| return | ||
| } | ||
|
|
||
| (transport as StreamableHttpServerTransport).handleRequest(null, this.call) | ||
| logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" } | ||
| } | ||
|
|
||
| internal suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint( | ||
| enableDnsRebindingProtection: Boolean = false, | ||
| allowedHosts: List<String>? = null, | ||
| allowedOrigins: List<String>? = null, | ||
| eventStore: EventStore? = null, | ||
| block: RoutingContext.() -> Server, | ||
| ) { | ||
| val transport = StreamableHttpServerTransport( | ||
| enableDnsRebindingProtection = enableDnsRebindingProtection, | ||
| allowedHosts = allowedHosts, | ||
| allowedOrigins = allowedOrigins, | ||
| eventStore = eventStore, | ||
| enableJsonResponse = true, | ||
| ) | ||
| transport.setSessionIdGenerator(null) | ||
|
|
||
| logger.info { "New stateless StreamableHttp connection established without sessionId" } | ||
|
|
||
| val server = block() | ||
|
|
||
| server.onClose { | ||
| logger.info { "Server connection closed without sessionId" } | ||
| } | ||
|
|
||
| server.createSession(transport) | ||
|
|
||
| transport.handleRequest(null, this.call) | ||
|
|
||
| logger.debug { "Server connected to transport without sessionId" } | ||
| } | ||
|
|
||
| internal suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager) { | ||
| val sessionId: String = call.request.queryParameters["sessionId"] ?: run { | ||
| call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") | ||
| return | ||
| } | ||
|
|
||
| logger.debug { "Received message for sessionId: $sessionId" } | ||
|
|
||
| val transport = sseTransportManager.getTransport(sessionId) | ||
| val transport = transportManager.getTransport(sessionId) as SseServerTransport? | ||
| if (transport == null) { | ||
| logger.warn { "Session not found for sessionId: $sessionId" } | ||
| call.respond(HttpStatusCode.NotFound, "Session not found") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is it for?