@@ -7,7 +7,6 @@ import io.ktor.http.HttpStatusCode
77import io.ktor.server.application.ApplicationCall
88import io.ktor.server.request.contentType
99import io.ktor.server.request.header
10- import io.ktor.server.request.host
1110import io.ktor.server.request.httpMethod
1211import io.ktor.server.request.receiveText
1312import io.ktor.server.response.header
@@ -41,6 +40,7 @@ import kotlin.uuid.Uuid
4140internal const val MCP_SESSION_ID_HEADER = " mcp-session-id"
4241private const val MCP_PROTOCOL_VERSION_HEADER = " mcp-protocol-version"
4342private const val MCP_RESUMPTION_TOKEN_HEADER = " Last-Event-ID"
43+ private const val MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 // 4 MB
4444
4545/* *
4646 * Interface for resumability support via event storage
@@ -415,6 +415,8 @@ public class StreamableHttpServerTransport(
415415
416416 try {
417417 call.appendSseHeaders()
418+ session.send(data = " " ) // flush headers immediately
419+
418420 val streamId = store.replayEventsAfter(lastEventId) { eventId, message ->
419421 try {
420422 session.send(
@@ -423,10 +425,16 @@ public class StreamableHttpServerTransport(
423425 data = McpJson .encodeToString(message),
424426 )
425427 } catch (e: Exception ) {
426- _onError (e )
428+ _onError (IllegalStateException ( " Failed to replay event: ${e.message} " , e) )
427429 }
428430 }
431+
429432 streamsMapping[streamId] = SessionContext (session, call)
433+
434+ session.coroutineContext.job.invokeOnCompletion { throwable ->
435+ streamsMapping.remove(streamId)
436+ throwable?.let { _onError (it) }
437+ }
430438 } catch (e: Exception ) {
431439 _onError (e)
432440 }
@@ -494,15 +502,19 @@ public class StreamableHttpServerTransport(
494502 if (! enableDnsRebindingProtection) return null
495503
496504 allowedHosts?.let { hosts ->
497- val hostHeader = call.request.host().substringBefore(' :' ).lowercase()
498- if (hostHeader !in hosts.map { it.substringBefore(' :' ).lowercase() }) {
505+ val hostHeader = call.request.headers[HttpHeaders .Host ]?.lowercase()
506+ val allowedHostsLowercase = hosts.map { it.lowercase() }
507+
508+ if (hostHeader == null || hostHeader !in allowedHostsLowercase) {
499509 return " Invalid Host header: $hostHeader "
500510 }
501511 }
502512
503513 allowedOrigins?.let { origins ->
504- val originHeader = call.request.headers[HttpHeaders .Origin ]?.removeSuffix(" /" )?.lowercase()
505- if (originHeader !in origins.map { it.removeSuffix(" /" ).lowercase() }) {
514+ val originHeader = call.request.headers[HttpHeaders .Origin ]?.lowercase()
515+ val allowedOriginsLowercase = origins.map { it.lowercase() }
516+
517+ if (originHeader == null || originHeader !in allowedOriginsLowercase) {
506518 return " Invalid Origin header: $originHeader "
507519 }
508520 }
@@ -511,7 +523,26 @@ public class StreamableHttpServerTransport(
511523 }
512524
513525 private suspend fun parseBody (call : ApplicationCall ): List <JSONRPCMessage >? {
526+ val contentLength = call.request.header(HttpHeaders .ContentLength )?.toIntOrNull() ? : 0
527+ if (contentLength > MAXIMUM_MESSAGE_SIZE ) {
528+ call.reject(
529+ HttpStatusCode .PayloadTooLarge ,
530+ RPCError .ErrorCode .INVALID_REQUEST ,
531+ " Invalid Request: message size exceeds maximum of ${MAXIMUM_MESSAGE_SIZE / (1024 * 1024 )} MB" ,
532+ )
533+ return null
534+ }
535+
514536 val body = call.receiveText()
537+ if (body.length > MAXIMUM_MESSAGE_SIZE ) {
538+ call.reject(
539+ HttpStatusCode .PayloadTooLarge ,
540+ RPCError .ErrorCode .INVALID_REQUEST ,
541+ " Invalid Request: message size exceeds maximum of ${MAXIMUM_MESSAGE_SIZE / (1024 * 1024 )} MB" ,
542+ )
543+ return null
544+ }
545+
515546 return when (val element = McpJson .parseToJsonElement(body)) {
516547 is JsonObject -> listOf (McpJson .decodeFromJsonElement(element))
517548
@@ -556,10 +587,5 @@ public class StreamableHttpServerTransport(
556587
557588internal suspend fun ApplicationCall.reject (status : HttpStatusCode , code : Int , message : String ) {
558589 this .response.status(status)
559- this .respond(
560- JSONRPCError (
561- id = null ,
562- error = RPCError (code = code, message = message),
563- ),
564- )
590+ this .respond(JSONRPCError (id = null , error = RPCError (code = code, message = message)))
565591}
0 commit comments