Skip to content

Commit 107496e

Browse files
committed
Enforced message size limits, improved error handling in event replay and request validation, and optimized DNS rebind protection logic. Added safeguards for headers and session cleanups.
1 parent ecaeb03 commit 107496e

File tree

1 file changed

+38
-12
lines changed

1 file changed

+38
-12
lines changed

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import io.ktor.http.HttpStatusCode
77
import io.ktor.server.application.ApplicationCall
88
import io.ktor.server.request.contentType
99
import io.ktor.server.request.header
10-
import io.ktor.server.request.host
1110
import io.ktor.server.request.httpMethod
1211
import io.ktor.server.request.receiveText
1312
import io.ktor.server.response.header
@@ -41,6 +40,7 @@ import kotlin.uuid.Uuid
4140
internal const val MCP_SESSION_ID_HEADER = "mcp-session-id"
4241
private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
4342
private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID"
43+
private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB
4444

4545
/**
4646
* Interface for resumability support via event storage
@@ -415,6 +415,8 @@ public class StreamableHttpServerTransport(
415415

416416
try {
417417
call.appendSseHeaders()
418+
session.send(data = "") // flush headers immediately
419+
418420
val streamId = store.replayEventsAfter(lastEventId) { eventId, message ->
419421
try {
420422
session.send(
@@ -423,10 +425,16 @@ public class StreamableHttpServerTransport(
423425
data = McpJson.encodeToString(message),
424426
)
425427
} catch (e: Exception) {
426-
_onError(e)
428+
_onError(IllegalStateException("Failed to replay event: ${e.message}", e))
427429
}
428430
}
431+
429432
streamsMapping[streamId] = SessionContext(session, call)
433+
434+
session.coroutineContext.job.invokeOnCompletion { throwable ->
435+
streamsMapping.remove(streamId)
436+
throwable?.let { _onError(it) }
437+
}
430438
} catch (e: Exception) {
431439
_onError(e)
432440
}
@@ -494,15 +502,19 @@ public class StreamableHttpServerTransport(
494502
if (!enableDnsRebindingProtection) return null
495503

496504
allowedHosts?.let { hosts ->
497-
val hostHeader = call.request.host().substringBefore(':').lowercase()
498-
if (hostHeader !in hosts.map { it.substringBefore(':').lowercase() }) {
505+
val hostHeader = call.request.headers[HttpHeaders.Host]?.lowercase()
506+
val allowedHostsLowercase = hosts.map { it.lowercase() }
507+
508+
if (hostHeader == null || hostHeader !in allowedHostsLowercase) {
499509
return "Invalid Host header: $hostHeader"
500510
}
501511
}
502512

503513
allowedOrigins?.let { origins ->
504-
val originHeader = call.request.headers[HttpHeaders.Origin]?.removeSuffix("/")?.lowercase()
505-
if (originHeader !in origins.map { it.removeSuffix("/").lowercase() }) {
514+
val originHeader = call.request.headers[HttpHeaders.Origin]?.lowercase()
515+
val allowedOriginsLowercase = origins.map { it.lowercase() }
516+
517+
if (originHeader == null || originHeader !in allowedOriginsLowercase) {
506518
return "Invalid Origin header: $originHeader"
507519
}
508520
}
@@ -511,7 +523,26 @@ public class StreamableHttpServerTransport(
511523
}
512524

513525
private suspend fun parseBody(call: ApplicationCall): List<JSONRPCMessage>? {
526+
val contentLength = call.request.header(HttpHeaders.ContentLength)?.toIntOrNull() ?: 0
527+
if (contentLength > MAXIMUM_MESSAGE_SIZE) {
528+
call.reject(
529+
HttpStatusCode.PayloadTooLarge,
530+
RPCError.ErrorCode.INVALID_REQUEST,
531+
"Invalid Request: message size exceeds maximum of ${MAXIMUM_MESSAGE_SIZE / (1024 * 1024)} MB",
532+
)
533+
return null
534+
}
535+
514536
val body = call.receiveText()
537+
if (body.length > MAXIMUM_MESSAGE_SIZE) {
538+
call.reject(
539+
HttpStatusCode.PayloadTooLarge,
540+
RPCError.ErrorCode.INVALID_REQUEST,
541+
"Invalid Request: message size exceeds maximum of ${MAXIMUM_MESSAGE_SIZE / (1024 * 1024)} MB",
542+
)
543+
return null
544+
}
545+
515546
return when (val element = McpJson.parseToJsonElement(body)) {
516547
is JsonObject -> listOf(McpJson.decodeFromJsonElement(element))
517548

@@ -556,10 +587,5 @@ public class StreamableHttpServerTransport(
556587

557588
internal suspend fun ApplicationCall.reject(status: HttpStatusCode, code: Int, message: String) {
558589
this.response.status(status)
559-
this.respond(
560-
JSONRPCError(
561-
id = null,
562-
error = RPCError(code = code, message = message),
563-
),
564-
)
590+
this.respond(JSONRPCError(id = null, error = RPCError(code = code, message = message)))
565591
}

0 commit comments

Comments
 (0)