@@ -14,7 +14,6 @@ import io.ktor.server.routing.routing
1414import io.ktor.server.sse.SSE
1515import io.ktor.server.sse.ServerSSESession
1616import io.ktor.server.sse.sse
17- import io.ktor.util.collections.ConcurrentMap
1817import io.ktor.utils.io.KtorDsl
1918import kotlinx.atomicfu.AtomicRef
2019import kotlinx.atomicfu.atomic
@@ -23,16 +22,19 @@ import kotlinx.collections.immutable.PersistentMap
2322import kotlinx.collections.immutable.toPersistentMap
2423import kotlinx.coroutines.awaitCancellation
2524import io.modelcontextprotocol.kotlin.sdk.ErrorCode
25+ import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
2626
2727private val logger = KotlinLogging .logger {}
2828
29- internal class SseTransportManager (transports : Map <String , SseServerTransport > = emptyMap()) {
30- private val transports: AtomicRef <PersistentMap <String , SseServerTransport >> = atomic(transports.toPersistentMap())
29+ internal class TransportManager (transports : Map <String , AbstractTransport > = emptyMap()) {
30+ private val transports: AtomicRef <PersistentMap <String , AbstractTransport >> = atomic(transports.toPersistentMap())
3131
32- fun getTransport (sessionId : String ): SseServerTransport ? = transports.value[ sessionId]
32+ fun hasTransport (sessionId : String ): Boolean = transports.value.containsKey( sessionId)
3333
34- fun addTransport (transport : SseServerTransport ) {
35- transports.update { it.put(transport.sessionId, transport) }
34+ fun getTransport (sessionId : String ): AbstractTransport ? = transports.value[sessionId]
35+
36+ fun addTransport (sessionId : String , transport : AbstractTransport ) {
37+ transports.update { it.put(sessionId, transport) }
3638 }
3739
3840 fun removeTransport (sessionId : String ) {
@@ -52,14 +54,14 @@ public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) {
5254*/
5355@KtorDsl
5456public fun Routing.mcp (block : ServerSSESession .() -> Server ) {
55- val sseTransportManager = SseTransportManager ()
57+ val transportManager = TransportManager ()
5658
5759 sse {
58- mcpSseEndpoint(" " , sseTransportManager , block)
60+ mcpSseEndpoint(" " , transportManager , block)
5961 }
6062
6163 post {
62- mcpPostEndpoint(sseTransportManager )
64+ mcpPostEndpoint(transportManager )
6365 }
6466}
6567
@@ -80,12 +82,12 @@ public fun Application.mcpStreamableHttp(
8082 eventStore : EventStore ? = null,
8183 block : RoutingContext .() -> Server ,
8284) {
83- val transports = ConcurrentMap < String , StreamableHttpServerTransport > ()
85+ val transportManager = TransportManager ()
8486
8587 routing {
8688 post(" /mcp" ) {
8789 mcpStreamableHttpEndpoint(
88- transports ,
90+ transportManager ,
8991 enableDnsRebindingProtection,
9092 allowedHosts,
9193 allowedOrigins,
@@ -119,16 +121,16 @@ public fun Application.mcpStatelessStreamableHttp(
119121
120122internal suspend fun ServerSSESession.mcpSseEndpoint (
121123 postEndpoint : String ,
122- sseTransportManager : SseTransportManager ,
124+ transportManager : TransportManager ,
123125 block : ServerSSESession .() -> Server ,
124126) {
125- val transport = mcpSseTransport(postEndpoint, sseTransportManager )
127+ val transport = mcpSseTransport(postEndpoint, transportManager )
126128
127129 val server = block()
128130
129131 server.onClose {
130132 logger.info { " Server connection closed for sessionId: ${transport.sessionId} " }
131- sseTransportManager .removeTransport(transport.sessionId)
133+ transportManager .removeTransport(transport.sessionId)
132134 }
133135
134136 server.createSession(transport)
@@ -140,26 +142,26 @@ internal suspend fun ServerSSESession.mcpSseEndpoint(
140142
141143internal fun ServerSSESession.mcpSseTransport (
142144 postEndpoint : String ,
143- sseTransportManager : SseTransportManager ,
145+ transportManager : TransportManager ,
144146): SseServerTransport {
145147 val transport = SseServerTransport (postEndpoint, this )
146- sseTransportManager .addTransport(transport)
148+ transportManager .addTransport(transport.sessionId, transport)
147149 logger.info { " New SSE connection established and stored with sessionId: ${transport.sessionId} " }
148150
149151 return transport
150152}
151153
152154internal suspend fun RoutingContext.mcpStreamableHttpEndpoint (
153- transports : ConcurrentMap < String , StreamableHttpServerTransport > ,
155+ transportManager : TransportManager ,
154156 enableDnsRebindingProtection : Boolean = false,
155157 allowedHosts : List <String >? = null,
156158 allowedOrigins : List <String >? = null,
157159 eventStore : EventStore ? = null,
158160 block : RoutingContext .() -> Server ,
159161) {
160162 val sessionId = this .call.request.header(MCP_SESSION_ID_HEADER )
161- val transport = if (sessionId != null && transports.containsKey (sessionId)) {
162- transports[ sessionId] !!
163+ val transport = if (sessionId != null && transportManager.hasTransport (sessionId)) {
164+ transportManager.getTransport( sessionId)
163165 } else if (sessionId == null ) {
164166 val transport = StreamableHttpServerTransport (
165167 enableDnsRebindingProtection = enableDnsRebindingProtection,
@@ -170,7 +172,7 @@ internal suspend fun RoutingContext.mcpStreamableHttpEndpoint(
170172 )
171173
172174 transport.setOnSessionInitialized { sessionId ->
173- transports[ sessionId] = transport
175+ transportManager.addTransport( sessionId, transport)
174176
175177 logger.info { " New StreamableHttp connection established and stored with sessionId: $sessionId " }
176178 }
@@ -196,7 +198,7 @@ internal suspend fun RoutingContext.mcpStreamableHttpEndpoint(
196198 return
197199 }
198200
199- transport.handleRequest(null , this .call)
201+ ( transport as StreamableHttpServerTransport ) .handleRequest(null , this .call)
200202 logger.debug { " Server connected to transport for sessionId: ${transport.sessionId} " }
201203}
202204
@@ -231,15 +233,15 @@ internal suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint(
231233 logger.debug { " Server connected to transport without sessionId" }
232234}
233235
234- internal suspend fun RoutingContext.mcpPostEndpoint (sseTransportManager : SseTransportManager ) {
236+ internal suspend fun RoutingContext.mcpPostEndpoint (transportManager : TransportManager ) {
235237 val sessionId: String = call.request.queryParameters[" sessionId" ] ? : run {
236238 call.respond(HttpStatusCode .BadRequest , " sessionId query parameter is not provided" )
237239 return
238240 }
239241
240242 logger.debug { " Received message for sessionId: $sessionId " }
241243
242- val transport = sseTransportManager .getTransport(sessionId)
244+ val transport = transportManager .getTransport(sessionId) as SseServerTransport ?
243245 if (transport == null ) {
244246 logger.warn { " Session not found for sessionId: $sessionId " }
245247 call.respond(HttpStatusCode .NotFound , " Session not found" )
0 commit comments