Skip to content

Commit 669bd34

Browse files
committed
mcp: don't treat pointers differently given an OutputSchema override
In the fix for #199, I was overly cautious: I assumed that if the user provided their own schema, we shouldn't fix up typed nils for them. But this was wrong: consider that: 1. If the user provides a schema, it must be an "object" schema. 2. One way that we recommend customizing schemas is by calling jsonschema.For and modifying. It's very confusing that doing this would change the treatment of the zero value. The good news is that any tool that was returning a typed nil is provably wrong, so this is just a bug that we can and should fix. Fixes #691
1 parent 4c52f37 commit 669bd34

File tree

2 files changed

+50
-46
lines changed

2 files changed

+50
-46
lines changed

mcp/mcp_test.go

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1902,16 +1902,20 @@ func TestPointerArgEquivalence(t *testing.T) {
19021902
type input struct {
19031903
In string `json:",omitempty"`
19041904
}
1905+
inputSchema := json.RawMessage(`{"type":"object","properties":{"In":{"type":"string"}},"additionalProperties":false}`)
19051906
type output struct {
19061907
Out string
19071908
}
1909+
outputSchema := json.RawMessage(`{"type":"object","required":["Out"],"properties":{"Out":{"type":"string"}},"additionalProperties":false}`)
19081910
cs, _, cleanup := basicConnection(t, func(s *Server) {
1909-
// Add two equivalent tools, one of which operates in the 'pointer' realm,
1910-
// the other of which does not.
1911+
// Add three equivalent tools:
1912+
// - one operates on pointers, with inferred schemas
1913+
// - one operates on pointers, with user-provided schemas
1914+
// - one operates on non-pointers
19111915
//
19121916
// We handle a few different types of results, to assert they behave the
19131917
// same in all cases.
1914-
AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) {
1918+
handlePointers := func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) {
19151919
switch in.In {
19161920
case "":
19171921
return nil, nil, fmt.Errorf("must provide input")
@@ -1924,7 +1928,13 @@ func TestPointerArgEquivalence(t *testing.T) {
19241928
default:
19251929
panic("unreachable")
19261930
}
1927-
})
1931+
}
1932+
AddTool(s, &Tool{Name: "pointer-inferred"}, handlePointers)
1933+
AddTool(s, &Tool{
1934+
Name: "pointer-provided",
1935+
InputSchema: inputSchema,
1936+
OutputSchema: outputSchema,
1937+
}, handlePointers)
19281938
AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *CallToolRequest, in input) (*CallToolResult, output, error) {
19291939
switch in.In {
19301940
case "":
@@ -1947,50 +1957,51 @@ func TestPointerArgEquivalence(t *testing.T) {
19471957
if err != nil {
19481958
t.Fatal(err)
19491959
}
1950-
if got, want := len(tools.Tools), 2; got != want {
1960+
if got, want := len(tools.Tools), 3; got != want {
19511961
t.Fatalf("got %d tools, want %d", got, want)
19521962
}
1953-
t0 := tools.Tools[0]
1954-
t1 := tools.Tools[1]
1955-
1956-
// First, check that the tool schemas don't differ.
1957-
if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" {
1958-
t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
1959-
}
1960-
if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" {
1961-
t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
1962-
}
19631963

1964-
// Then, check that we handle empty input equivalently.
1965-
for _, args := range []any{nil, struct{}{}} {
1966-
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
1967-
if err != nil {
1968-
t.Fatal(err)
1969-
}
1970-
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
1971-
if err != nil {
1972-
t.Fatal(err)
1964+
t0 := tools.Tools[0]
1965+
for _, t1 := range tools.Tools[1:] {
1966+
// First, check that the tool schemas don't differ.
1967+
if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" {
1968+
t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
19731969
}
1974-
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
1975-
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
1970+
if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" {
1971+
t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
19761972
}
1977-
}
19781973

1979-
// Then, check that we handle different types of output equivalently.
1980-
for _, in := range []string{"nil", "empty", "ok"} {
1981-
t.Run(in, func(t *testing.T) {
1982-
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
1974+
// Then, check that we handle empty input equivalently.
1975+
for _, args := range []any{nil, struct{}{}} {
1976+
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
19831977
if err != nil {
19841978
t.Fatal(err)
19851979
}
1986-
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
1980+
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
19871981
if err != nil {
19881982
t.Fatal(err)
19891983
}
19901984
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
1991-
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
1985+
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
19921986
}
1993-
})
1987+
}
1988+
1989+
// Then, check that we handle different types of output equivalently.
1990+
for _, in := range []string{"nil", "empty", "ok"} {
1991+
t.Run(in, func(t *testing.T) {
1992+
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
1993+
if err != nil {
1994+
t.Fatal(err)
1995+
}
1996+
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
1997+
if err != nil {
1998+
t.Fatal(err)
1999+
}
2000+
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
2001+
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
2002+
}
2003+
})
2004+
}
19942005
}
19952006
}
19962007

mcp/server.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -359,21 +359,14 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
359359
// Pointers are treated equivalently to non-pointers when deriving the schema.
360360
// If an indirection occurred to derive the schema, a non-nil zero value is
361361
// returned to be used in place of the typed nil zero value.
362-
//
363-
// Note that if sfield already holds a schema, zero will be nil even if T is a
364-
// pointer: if the user provided the schema, they may have intentionally
365-
// derived it from the pointer type, and handling of zero values is up to them.
366-
//
367-
// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we
368-
// should have a jsonschema.Zero(schema) helper?
369362
func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) {
363+
rt := reflect.TypeFor[T]()
364+
if rt.Kind() == reflect.Pointer {
365+
rt = rt.Elem()
366+
zero = reflect.Zero(rt).Interface()
367+
}
370368
var internalSchema *jsonschema.Schema
371369
if *sfield == nil {
372-
rt := reflect.TypeFor[T]()
373-
if rt.Kind() == reflect.Pointer {
374-
rt = rt.Elem()
375-
zero = reflect.Zero(rt).Interface()
376-
}
377370
// TODO: we should be able to pass nil opts here.
378371
internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{})
379372
if err == nil {

0 commit comments

Comments
 (0)