Skip to content

Commit e6e38ac

Browse files
authored
Fix providing progressToken with request (#405)
## Motivation and Context fix #220 ## How Has This Been Tested? locally ## Breaking Changes None ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update ## Checklist - [x] I have read the [MCP Documentation](https://modelcontextprotocol.io) - [x] My code follows the repository's style guidelines - [x] New and existing tests pass locally - [x] I have added appropriate error handling - [x] I have added or updated documentation as needed
1 parent 0f9209a commit e6e38ac

File tree

3 files changed

+234
-16
lines changed

3 files changed

+234
-16
lines changed

kotlin-sdk-core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ kotlin {
122122
commonTest {
123123
dependencies {
124124
implementation(kotlin("test"))
125+
implementation(libs.kotlinx.coroutines.test)
125126
implementation(libs.kotest.assertions.core)
126127
implementation(libs.kotest.assertions.json)
127128
}

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -344,20 +344,24 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
344344
if (handler != null) {
345345
messageId?.let { msg -> _progressHandlers.update { it.remove(msg) } }
346346
} else {
347-
onError(Error("Received a response for an unknown message ID: ${McpJson.encodeToString(response)}"))
347+
onError(
348+
IllegalStateException(
349+
"Received a response for an unknown message ID: ${McpJson.encodeToString(error ?: response)}",
350+
),
351+
)
348352
return
349353
}
350354

351355
if (response != null) {
352356
handler(response, null)
353357
} else {
354358
check(error != null)
355-
val error = McpException(
359+
val mcpException = McpException(
356360
code = error.error.code,
357361
message = error.error.message,
358362
data = error.error.data,
359363
)
360-
handler(null, error)
364+
handler(null, mcpException)
361365
}
362366
}
363367

@@ -403,18 +407,30 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
403407
assertCapabilityForMethod(request.method)
404408
}
405409

406-
val message = request.toJSON()
407-
val messageId = message.id
410+
val jsonRpcRequest = request.toJSON().run {
411+
options?.onProgress?.let { progressHandler ->
412+
logger.trace { "Registering progress handler for request id: $id" }
413+
_progressHandlers.update { current ->
414+
current.put(id, progressHandler)
415+
}
408416

409-
if (options?.onProgress != null) {
410-
logger.trace { "Registering progress handler for request id: $messageId" }
411-
_progressHandlers.update { current ->
412-
current.put(messageId, options.onProgress)
413-
}
417+
val paramsObject = (this.params as? JsonObject) ?: JsonObject(emptyMap())
418+
val metaObject = request.params?.meta?.json ?: JsonObject(emptyMap())
419+
420+
val updatedMeta = JsonObject(
421+
metaObject + ("progressToken" to McpJson.encodeToJsonElement(id)),
422+
)
423+
val updatedParams = JsonObject(
424+
paramsObject + ("_meta" to updatedMeta),
425+
)
426+
427+
this.copy(params = updatedParams)
428+
} ?: this
414429
}
430+
val jsonRpcRequestId = jsonRpcRequest.id
415431

416432
_responseHandlers.update { current ->
417-
current.put(messageId) { response, error ->
433+
current.put(jsonRpcRequestId) { response, error ->
418434
if (error != null) {
419435
result.completeExceptionally(error)
420436
return@put
@@ -430,12 +446,12 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
430446
}
431447

432448
val cancel: suspend (Throwable) -> Unit = { reason: Throwable ->
433-
_responseHandlers.update { current -> current.remove(messageId) }
434-
_progressHandlers.update { current -> current.remove(messageId) }
449+
_responseHandlers.update { current -> current.remove(jsonRpcRequestId) }
450+
_progressHandlers.update { current -> current.remove(jsonRpcRequestId) }
435451

436452
val notification = CancelledNotification(
437453
params = CancelledNotificationParams(
438-
requestId = messageId,
454+
requestId = jsonRpcRequestId,
439455
reason = reason.message ?: "Unknown",
440456
),
441457
)
@@ -452,8 +468,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
452468
val timeout = options?.timeout ?: DEFAULT_REQUEST_TIMEOUT
453469
try {
454470
withTimeout(timeout) {
455-
logger.trace { "Sending request message with id: $messageId" }
456-
this@Protocol.transport?.send(message)
471+
logger.trace { "Sending request message with id: $jsonRpcRequestId" }
472+
this@Protocol.transport?.send(jsonRpcRequest)
457473
}
458474
return result.await()
459475
} catch (cause: TimeoutCancellationException) {
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
package io.modelcontextprotocol.kotlin.sdk.shared
2+
3+
import io.kotest.matchers.collections.shouldContainExactly
4+
import io.kotest.matchers.nulls.shouldNotBeNull
5+
import io.kotest.matchers.shouldBe
6+
import io.modelcontextprotocol.kotlin.sdk.types.CustomRequest
7+
import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult
8+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
9+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest
10+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse
11+
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
12+
import io.modelcontextprotocol.kotlin.sdk.types.Method
13+
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
14+
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams
15+
import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta
16+
import kotlinx.coroutines.async
17+
import kotlinx.coroutines.channels.Channel
18+
import kotlinx.coroutines.test.runTest
19+
import kotlinx.serialization.json.JsonObject
20+
import kotlinx.serialization.json.JsonObjectBuilder
21+
import kotlinx.serialization.json.JsonPrimitive
22+
import kotlinx.serialization.json.buildJsonObject
23+
import kotlinx.serialization.json.encodeToJsonElement
24+
import kotlinx.serialization.json.int
25+
import kotlinx.serialization.json.jsonObject
26+
import kotlinx.serialization.json.jsonPrimitive
27+
import kotlin.test.BeforeTest
28+
import kotlin.test.Test
29+
30+
class ProtocolTest {
31+
private lateinit var protocol: TestProtocol
32+
private lateinit var transport: RecordingTransport
33+
34+
@BeforeTest
35+
fun setUp() {
36+
protocol = TestProtocol()
37+
transport = RecordingTransport()
38+
}
39+
40+
@Test
41+
fun `should preserve existing meta when adding progress token`() = runTest {
42+
protocol.connect(transport)
43+
val request = ReadResourceRequest(
44+
ReadResourceRequestParams(
45+
uri = "test://resource",
46+
meta = metaOf {
47+
put("customField", JsonPrimitive("customValue"))
48+
put("anotherField", JsonPrimitive(123))
49+
},
50+
),
51+
)
52+
53+
val inFlight = async {
54+
protocol.request<EmptyResult>(
55+
request = request,
56+
options = RequestOptions(onProgress = {}),
57+
)
58+
}
59+
60+
val sent = transport.awaitRequest()
61+
val params = sent.params?.jsonObject.shouldNotBeNull()
62+
val meta = params["_meta"]?.jsonObject.shouldNotBeNull()
63+
64+
params["uri"]?.jsonPrimitive?.content shouldBe "test://resource"
65+
meta["customField"]?.jsonPrimitive?.content shouldBe "customValue"
66+
meta["anotherField"]?.jsonPrimitive?.int shouldBe 123
67+
meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id)
68+
69+
transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
70+
inFlight.await()
71+
}
72+
73+
@Test
74+
fun `should create meta with progress token when none exists`() = runTest {
75+
protocol.connect(transport)
76+
val request = ReadResourceRequest(
77+
ReadResourceRequestParams(
78+
uri = "test://resource",
79+
meta = null,
80+
),
81+
)
82+
83+
val inFlight = async {
84+
protocol.request<EmptyResult>(
85+
request = request,
86+
options = RequestOptions(onProgress = {}),
87+
)
88+
}
89+
90+
val sent = transport.awaitRequest()
91+
val params = sent.params?.jsonObject.shouldNotBeNull()
92+
val meta = params["_meta"]?.jsonObject.shouldNotBeNull()
93+
94+
params["uri"]?.jsonPrimitive?.content shouldBe "test://resource"
95+
meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id)
96+
97+
transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
98+
inFlight.await()
99+
}
100+
101+
@Test
102+
fun `should not modify meta when onProgress is absent`() = runTest {
103+
protocol.connect(transport)
104+
val originalMeta = metaJson {
105+
put("customField", JsonPrimitive("customValue"))
106+
}
107+
val request = ReadResourceRequest(
108+
ReadResourceRequestParams(
109+
uri = "test://resource",
110+
meta = RequestMeta(originalMeta),
111+
),
112+
)
113+
114+
val inFlight = async {
115+
protocol.request<EmptyResult>(request)
116+
}
117+
118+
val sent = transport.awaitRequest()
119+
val params = sent.params?.jsonObject.shouldNotBeNull()
120+
val meta = params["_meta"]?.jsonObject.shouldNotBeNull()
121+
122+
meta shouldBe originalMeta
123+
params["uri"]?.jsonPrimitive?.content shouldBe "test://resource"
124+
125+
transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
126+
inFlight.await()
127+
}
128+
129+
@Test
130+
fun `should create params object when request params are null`() = runTest {
131+
protocol.connect(transport)
132+
val request = CustomRequest(
133+
method = Method.Custom("example"),
134+
params = null,
135+
)
136+
137+
val inFlight = async {
138+
protocol.request<EmptyResult>(
139+
request = request,
140+
options = RequestOptions(onProgress = {}),
141+
)
142+
}
143+
144+
val sent = transport.awaitRequest()
145+
val params = sent.params?.jsonObject.shouldNotBeNull()
146+
val meta = params["_meta"]?.jsonObject.shouldNotBeNull()
147+
148+
params.keys shouldContainExactly setOf("_meta")
149+
meta["progressToken"] shouldBe McpJson.encodeToJsonElement(sent.id)
150+
151+
transport.deliver(JSONRPCResponse(sent.id, EmptyResult()))
152+
inFlight.await()
153+
}
154+
}
155+
156+
private class TestProtocol : Protocol(null) {
157+
override fun assertCapabilityForMethod(method: Method) {}
158+
override fun assertNotificationCapability(method: Method) {}
159+
override fun assertRequestHandlerCapability(method: Method) {}
160+
}
161+
162+
private class RecordingTransport : Transport {
163+
private val sentMessages = Channel<JSONRPCMessage>(Channel.UNLIMITED)
164+
private var onMessageCallback: (suspend (JSONRPCMessage) -> Unit)? = null
165+
private var onCloseCallback: (() -> Unit)? = null
166+
167+
override suspend fun start() {}
168+
169+
override suspend fun send(message: JSONRPCMessage) {
170+
sentMessages.send(message)
171+
}
172+
173+
override suspend fun close() {
174+
onCloseCallback?.invoke()
175+
}
176+
177+
override fun onClose(block: () -> Unit) {
178+
onCloseCallback = block
179+
}
180+
181+
override fun onError(block: (Throwable) -> Unit) {}
182+
183+
override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
184+
onMessageCallback = block
185+
}
186+
187+
suspend fun awaitRequest(): JSONRPCRequest {
188+
val message = sentMessages.receive()
189+
return message as? JSONRPCRequest
190+
?: error("Expected JSONRPCRequest but received ${message::class.simpleName}")
191+
}
192+
193+
suspend fun deliver(message: JSONRPCMessage) {
194+
val callback = onMessageCallback ?: error("onMessage callback not registered")
195+
callback(message)
196+
}
197+
}
198+
199+
private fun metaOf(builderAction: JsonObjectBuilder.() -> Unit): RequestMeta = RequestMeta(metaJson(builderAction))
200+
201+
private fun metaJson(builderAction: JsonObjectBuilder.() -> Unit): JsonObject = buildJsonObject(builderAction)

0 commit comments

Comments
 (0)