Skip to content

Commit c1ad3d2

Browse files
committed
Refactored StreamableHttpServerTransport to improve SSE session handling, validate header values, and simplify MIME type matching logic.
1 parent 613d3fb commit c1ad3d2

File tree

1 file changed

+37
-27
lines changed

1 file changed

+37
-27
lines changed

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ public class StreamableHttpServerTransport(
249249
}
250250
}
251251
streamsMapping.clear()
252+
requestToStreamMapping.clear()
252253
requestToResponseMapping.clear()
253254
_onClose()
254255
}
@@ -380,7 +381,7 @@ public class StreamableHttpServerTransport(
380381
)
381382
return
382383
}
383-
session!!
384+
val sseSession = session ?: error("Server session can't be null for streaming GET requests")
384385

385386
val acceptHeader = call.request.header(HttpHeaders.Accept)
386387
if (!acceptHeader.accepts(ContentType.Text.EventStream)) {
@@ -396,7 +397,7 @@ public class StreamableHttpServerTransport(
396397

397398
eventStore?.let { store ->
398399
call.request.header(MCP_RESUMPTION_TOKEN_HEADER)?.let { lastEventId ->
399-
replayEvents(store, lastEventId, session)
400+
replayEvents(store, lastEventId, sseSession)
400401
return
401402
}
402403
}
@@ -411,10 +412,12 @@ public class StreamableHttpServerTransport(
411412
}
412413

413414
call.appendSseHeaders()
414-
flushSse(session) // flush headers immediately
415-
streamsMapping[STANDALONE_SSE_STREAM_ID] = SessionContext(session, call)
416-
maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, session)
417-
session.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(STANDALONE_SSE_STREAM_ID) }
415+
flushSse(sseSession) // flush headers immediately
416+
streamsMapping[STANDALONE_SSE_STREAM_ID] = SessionContext(sseSession, call)
417+
maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, sseSession)
418+
sseSession.coroutineContext.job.invokeOnCompletion {
419+
streamsMapping.remove(STANDALONE_SSE_STREAM_ID)
420+
}
418421
}
419422

420423
public suspend fun handleDeleteRequest(call: ApplicationCall) {
@@ -516,33 +519,45 @@ public class StreamableHttpServerTransport(
516519
return false
517520
}
518521

519-
val headerId = call.request.header(MCP_SESSION_ID_HEADER)
522+
val sessionHeaderValues = call.request.headers.getAll(MCP_SESSION_ID_HEADER)
520523

521-
return when {
522-
headerId == null -> {
523-
call.reject(
524-
HttpStatusCode.BadRequest,
525-
RPCError.ErrorCode.CONNECTION_CLOSED,
526-
"Bad Request: Mcp-Session-Id header is required",
527-
)
528-
false
529-
}
524+
if (sessionHeaderValues.isNullOrEmpty()) {
525+
call.reject(
526+
HttpStatusCode.BadRequest,
527+
RPCError.ErrorCode.CONNECTION_CLOSED,
528+
"Bad Request: Mcp-Session-Id header is required",
529+
)
530+
return false
531+
}
530532

531-
headerId != sessionId -> {
533+
if (sessionHeaderValues.size > 1) {
534+
call.reject(
535+
HttpStatusCode.BadRequest,
536+
RPCError.ErrorCode.CONNECTION_CLOSED,
537+
"Bad Request: Mcp-Session-Id header must be a single value",
538+
)
539+
return false
540+
}
541+
542+
val headerId = sessionHeaderValues.single()
543+
544+
return when (headerId) {
545+
sessionId -> true
546+
547+
else -> {
532548
call.reject(
533549
HttpStatusCode.NotFound,
534550
-32001,
535551
"Session not found",
536552
)
537553
false
538554
}
539-
540-
else -> true
541555
}
542556
}
543557

544558
private suspend fun validateProtocolVersion(call: ApplicationCall): Boolean {
545-
val version = call.request.header(MCP_PROTOCOL_VERSION_HEADER) ?: DEFAULT_NEGOTIATED_PROTOCOL_VERSION
559+
val protocolVersions = call.request.headers.getAll(MCP_PROTOCOL_VERSION_HEADER)
560+
val version = protocolVersions?.lastOrNull() ?: DEFAULT_NEGOTIATED_PROTOCOL_VERSION
546561

547562
return when (version) {
548563
!in SUPPORTED_PROTOCOL_VERSIONS -> {
@@ -631,13 +646,8 @@ public class StreamableHttpServerTransport(
631646
}
632647
}
633648

634-
private fun String?.accepts(mime: ContentType): Boolean {
635-
if (this == null) return false
636-
637-
val escaped = Regex.escape(mime.toString())
638-
val pattern = Regex("""(^|,\s*)$escaped(\s*(;|,|$))""", RegexOption.IGNORE_CASE)
639-
return pattern.containsMatchIn(this)
640-
}
649+
private fun String?.accepts(mime: ContentType): Boolean =
650+
this?.lowercase()?.contains(mime.toString().lowercase()) == true
641651

642652
private suspend fun emitOnStream(streamId: String, session: ServerSSESession?, message: JSONRPCMessage) {
643653
val eventId = eventStore?.storeEvent(streamId, message)

0 commit comments

Comments
 (0)