Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
func() bool { s.tools.add(st); return true })
}

func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) {
func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out], schemaOpts *jsonschema.ForOptions) (*Tool, ToolHandler, error) {
tt := *t

// Special handling for an "any" input: treat as an empty object.
Expand All @@ -248,7 +248,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
}

var inputResolved *jsonschema.Resolved
if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil {
if _, err := setSchema[In](&tt.InputSchema, &inputResolved, schemaOpts); err != nil {
return nil, nil, fmt.Errorf("input schema: %w", err)
}

Expand All @@ -263,7 +263,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
)
if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
var err error
elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved)
elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved, schemaOpts)
if err != nil {
return nil, nil, fmt.Errorf("output schema: %v", err)
}
Expand Down Expand Up @@ -366,16 +366,15 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
//
// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we
// should have a jsonschema.Zero(schema) helper?
func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) {
func setSchema[T any](sfield *any, rfield **jsonschema.Resolved, schemaOpts *jsonschema.ForOptions) (zero any, err error) {
var internalSchema *jsonschema.Schema
if *sfield == nil {
rt := reflect.TypeFor[T]()
if rt.Kind() == reflect.Pointer {
rt = rt.Elem()
zero = reflect.Zero(rt).Interface()
}
// TODO: we should be able to pass nil opts here.
internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{})
internalSchema, err = jsonschema.ForType(rt, schemaOpts)
if err == nil {
*sfield = internalSchema
}
Expand All @@ -389,6 +388,20 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err
return zero, err
}

// AddToolOption is an option for the AddTool function.
type AddToolOption func(*addToolOptions)

type addToolOptions struct {
schemaOpts *jsonschema.ForOptions
}

// WithSchemaOptions returns an AddToolOption that sets options for schema inference.
func WithSchemaOptions(opts *jsonschema.ForOptions) AddToolOption {
return func(ato *addToolOptions) {
ato.schemaOpts = opts
}
}

// AddTool adds a tool and typed tool handler to the server.
//
// If the tool's input schema is nil, it is set to the schema inferred from the
Expand All @@ -408,8 +421,14 @@ func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err
// Unlike [Server.AddTool], AddTool does a lot automatically, and forces
// tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed
// description of this automatic behavior.
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
tt, hh, err := toolForErr(t, h)
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out], opts ...AddToolOption) {
o := addToolOptions{
schemaOpts: &jsonschema.ForOptions{},
}
for _, opt := range opts {
opt(&o)
}
tt, hh, err := toolForErr(t, h, o.schemaOpts)
if err != nil {
panic(fmt.Sprintf("AddTool: tool %q: %v", t.Name, err))
}
Expand Down
2 changes: 1 addition & 1 deletion mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out
th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) {
return nil, out, nil
}
gott, goth, err := toolForErr(tool, th)
gott, goth, err := toolForErr(tool, th, &jsonschema.ForOptions{})
if err != nil {
t.Fatal(err)
}
Expand Down