Skip to content

Commit 3886187

Browse files
zsheadevcrocod
authored andcommitted
update from comments
1 parent a18889a commit 3886187

File tree

3 files changed

+28
-29
lines changed

3 files changed

+28
-29
lines changed

kotlin-sdk-server/api/kotlin-sdk-server.api

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
public final class io/modelcontextprotocol/kotlin/sdk/LibVersionKt {
2-
public static final field LIB_VERSION Ljava/lang/String;
3-
}
4-
51
public abstract interface class io/modelcontextprotocol/kotlin/sdk/server/EventStore {
62
public abstract fun replayEventsAfter (Ljava/lang/String;Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
73
public abstract fun storeEvent (Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import io.ktor.server.routing.routing
1414
import io.ktor.server.sse.SSE
1515
import io.ktor.server.sse.ServerSSESession
1616
import io.ktor.server.sse.sse
17-
import io.ktor.util.collections.ConcurrentMap
1817
import io.ktor.utils.io.KtorDsl
1918
import kotlinx.atomicfu.AtomicRef
2019
import kotlinx.atomicfu.atomic
@@ -23,16 +22,19 @@ import kotlinx.collections.immutable.PersistentMap
2322
import kotlinx.collections.immutable.toPersistentMap
2423
import kotlinx.coroutines.awaitCancellation
2524
import io.modelcontextprotocol.kotlin.sdk.ErrorCode
25+
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
2626

2727
private val logger = KotlinLogging.logger {}
2828

29-
internal class SseTransportManager(transports: Map<String, SseServerTransport> = emptyMap()) {
30-
private val transports: AtomicRef<PersistentMap<String, SseServerTransport>> = atomic(transports.toPersistentMap())
29+
internal class TransportManager(transports: Map<String, AbstractTransport> = emptyMap()) {
30+
private val transports: AtomicRef<PersistentMap<String, AbstractTransport>> = atomic(transports.toPersistentMap())
3131

32-
fun getTransport(sessionId: String): SseServerTransport? = transports.value[sessionId]
32+
fun hasTransport(sessionId: String): Boolean = transports.value.containsKey(sessionId)
3333

34-
fun addTransport(transport: SseServerTransport) {
35-
transports.update { it.put(transport.sessionId, transport) }
34+
fun getTransport(sessionId: String): AbstractTransport? = transports.value[sessionId]
35+
36+
fun addTransport(sessionId: String, transport: AbstractTransport) {
37+
transports.update { it.put(sessionId, transport) }
3638
}
3739

3840
fun removeTransport(sessionId: String) {
@@ -52,14 +54,14 @@ public fun Routing.mcp(path: String, block: ServerSSESession.() -> Server) {
5254
*/
5355
@KtorDsl
5456
public fun Routing.mcp(block: ServerSSESession.() -> Server) {
55-
val sseTransportManager = SseTransportManager()
57+
val transportManager = TransportManager()
5658

5759
sse {
58-
mcpSseEndpoint("", sseTransportManager, block)
60+
mcpSseEndpoint("", transportManager, block)
5961
}
6062

6163
post {
62-
mcpPostEndpoint(sseTransportManager)
64+
mcpPostEndpoint(transportManager)
6365
}
6466
}
6567

@@ -80,12 +82,12 @@ public fun Application.mcpStreamableHttp(
8082
eventStore: EventStore? = null,
8183
block: RoutingContext.() -> Server,
8284
) {
83-
val transports = ConcurrentMap<String, StreamableHttpServerTransport>()
85+
val transportManager = TransportManager()
8486

8587
routing {
8688
post("/mcp") {
8789
mcpStreamableHttpEndpoint(
88-
transports,
90+
transportManager,
8991
enableDnsRebindingProtection,
9092
allowedHosts,
9193
allowedOrigins,
@@ -119,16 +121,16 @@ public fun Application.mcpStatelessStreamableHttp(
119121

120122
internal suspend fun ServerSSESession.mcpSseEndpoint(
121123
postEndpoint: String,
122-
sseTransportManager: SseTransportManager,
124+
transportManager: TransportManager,
123125
block: ServerSSESession.() -> Server,
124126
) {
125-
val transport = mcpSseTransport(postEndpoint, sseTransportManager)
127+
val transport = mcpSseTransport(postEndpoint, transportManager)
126128

127129
val server = block()
128130

129131
server.onClose {
130132
logger.info { "Server connection closed for sessionId: ${transport.sessionId}" }
131-
sseTransportManager.removeTransport(transport.sessionId)
133+
transportManager.removeTransport(transport.sessionId)
132134
}
133135

134136
server.createSession(transport)
@@ -140,26 +142,26 @@ internal suspend fun ServerSSESession.mcpSseEndpoint(
140142

141143
internal fun ServerSSESession.mcpSseTransport(
142144
postEndpoint: String,
143-
sseTransportManager: SseTransportManager,
145+
transportManager: TransportManager,
144146
): SseServerTransport {
145147
val transport = SseServerTransport(postEndpoint, this)
146-
sseTransportManager.addTransport(transport)
148+
transportManager.addTransport(transport.sessionId, transport)
147149
logger.info { "New SSE connection established and stored with sessionId: ${transport.sessionId}" }
148150

149151
return transport
150152
}
151153

152154
internal suspend fun RoutingContext.mcpStreamableHttpEndpoint(
153-
transports: ConcurrentMap<String, StreamableHttpServerTransport>,
155+
transportManager: TransportManager,
154156
enableDnsRebindingProtection: Boolean = false,
155157
allowedHosts: List<String>? = null,
156158
allowedOrigins: List<String>? = null,
157159
eventStore: EventStore? = null,
158160
block: RoutingContext.() -> Server,
159161
) {
160162
val sessionId = this.call.request.header(MCP_SESSION_ID_HEADER)
161-
val transport = if (sessionId != null && transports.containsKey(sessionId)) {
162-
transports[sessionId]!!
163+
val transport = if (sessionId != null && transportManager.hasTransport(sessionId)) {
164+
transportManager.getTransport(sessionId)
163165
} else if (sessionId == null) {
164166
val transport = StreamableHttpServerTransport(
165167
enableDnsRebindingProtection = enableDnsRebindingProtection,
@@ -170,7 +172,7 @@ internal suspend fun RoutingContext.mcpStreamableHttpEndpoint(
170172
)
171173

172174
transport.setOnSessionInitialized { sessionId ->
173-
transports[sessionId] = transport
175+
transportManager.addTransport(sessionId, transport)
174176

175177
logger.info { "New StreamableHttp connection established and stored with sessionId: $sessionId" }
176178
}
@@ -196,7 +198,7 @@ internal suspend fun RoutingContext.mcpStreamableHttpEndpoint(
196198
return
197199
}
198200

199-
transport.handleRequest(null, this.call)
201+
(transport as StreamableHttpServerTransport).handleRequest(null, this.call)
200202
logger.debug { "Server connected to transport for sessionId: ${transport.sessionId}" }
201203
}
202204

@@ -231,15 +233,15 @@ internal suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint(
231233
logger.debug { "Server connected to transport without sessionId" }
232234
}
233235

234-
internal suspend fun RoutingContext.mcpPostEndpoint(sseTransportManager: SseTransportManager) {
236+
internal suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager) {
235237
val sessionId: String = call.request.queryParameters["sessionId"] ?: run {
236238
call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided")
237239
return
238240
}
239241

240242
logger.debug { "Received message for sessionId: $sessionId" }
241243

242-
val transport = sseTransportManager.getTransport(sessionId)
244+
val transport = transportManager.getTransport(sessionId) as SseServerTransport?
243245
if (transport == null) {
244246
logger.warn { "Session not found for sessionId: $sessionId" }
245247
call.respond(HttpStatusCode.NotFound, "Session not found")

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import io.ktor.server.request.httpMethod
1212
import io.ktor.server.request.receiveText
1313
import io.ktor.server.response.header
1414
import io.ktor.server.response.respond
15+
import io.ktor.server.response.respondBytes
1516
import io.ktor.server.response.respondNullable
1617
import io.ktor.server.sse.ServerSSESession
1718
import io.ktor.util.collections.ConcurrentMap
@@ -332,7 +333,7 @@ public class StreamableHttpServerTransport(
332333

333334
val hasRequest = messages.any { it is JSONRPCRequest }
334335
if (!hasRequest) {
335-
call.respondNullable(status = HttpStatusCode.Accepted, message = null)
336+
call.respondBytes(status = HttpStatusCode.Accepted, bytes = ByteArray(0))
336337
messages.forEach { message -> _onMessage(message) }
337338
return
338339
}
@@ -568,7 +569,7 @@ internal suspend fun ApplicationCall.reject(status: HttpStatusCode, code: ErrorC
568569
this.response.status(status)
569570
this.respond(
570571
JSONRPCResponse(
571-
id = null,
572+
id = RequestId.StringId("server-error"),
572573
error = JSONRPCError(message = message, code = code),
573574
),
574575
)

0 commit comments

Comments
 (0)