Skip to content

Commit fa2d802

Browse files
Add resource completion for GitHub repository resources (#1493)
Port resource completion from the remote GitHub MCP Server --- Co-authored-by: Ksenia Bobrova <almaleksia@github.com>
1 parent 9b34211 commit fa2d802

File tree

4 files changed

+735
-12
lines changed

4 files changed

+735
-12
lines changed

internal/ghmcp/server.go

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,18 +124,6 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) {
124124
// Generate instructions based on enabled toolsets
125125
instructions := github.GenerateInstructions(enabledToolsets)
126126

127-
ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{
128-
Instructions: instructions,
129-
HasTools: true,
130-
HasResources: true,
131-
HasPrompts: true,
132-
Logger: cfg.Logger,
133-
})
134-
135-
// Add middlewares
136-
ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext)
137-
ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, restClient, gqlHTTPClient))
138-
139127
getClient := func(_ context.Context) (*gogithub.Client, error) {
140128
return restClient, nil // closing over client
141129
}
@@ -152,6 +140,16 @@ func NewMCPServer(cfg MCPServerConfig) (*mcp.Server, error) {
152140
return raw.NewClient(client, apiHost.rawURL), nil // closing over client
153141
}
154142

143+
ghServer := github.NewServer(cfg.Version, &mcp.ServerOptions{
144+
Instructions: instructions,
145+
Logger: cfg.Logger,
146+
CompletionHandler: github.CompletionsHandler(getClient),
147+
})
148+
149+
// Add middlewares
150+
ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext)
151+
ghServer.AddReceivingMiddleware(addUserAgentsMiddleware(cfg, restClient, gqlHTTPClient))
152+
155153
// Create default toolsets
156154
tsg := github.DefaultToolsetGroup(
157155
cfg.ReadOnly,
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"strings"
8+
9+
"github.com/google/go-github/v79/github"
10+
"github.com/modelcontextprotocol/go-sdk/mcp"
11+
)
12+
13+
// CompleteHandler defines function signature for completion handlers
14+
type CompleteHandler func(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error)
15+
16+
// RepositoryResourceArgumentResolvers is a map of argument names to their completion handlers
17+
var RepositoryResourceArgumentResolvers = map[string]CompleteHandler{
18+
"owner": completeOwner,
19+
"repo": completeRepo,
20+
"branch": completeBranch,
21+
"sha": completeSHA,
22+
"tag": completeTag,
23+
"prNumber": completePRNumber,
24+
"path": completePath,
25+
}
26+
27+
// RepositoryResourceCompletionHandler returns a CompletionHandlerFunc for repository resource completions.
28+
func RepositoryResourceCompletionHandler(getClient GetClientFn) func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) {
29+
return func(ctx context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) {
30+
if req.Params.Ref.Type != "ref/resource" {
31+
return nil, nil // Not a resource completion
32+
}
33+
34+
argName := req.Params.Argument.Name
35+
argValue := req.Params.Argument.Value
36+
resolved := req.Params.Context.Arguments
37+
if resolved == nil {
38+
resolved = map[string]string{}
39+
}
40+
41+
client, err := getClient(ctx)
42+
if err != nil {
43+
return nil, err
44+
}
45+
46+
// Argument resolver functions
47+
resolvers := RepositoryResourceArgumentResolvers
48+
49+
resolver, ok := resolvers[argName]
50+
if !ok {
51+
return nil, errors.New("no resolver for argument: " + argName)
52+
}
53+
54+
values, err := resolver(ctx, client, resolved, argValue)
55+
if err != nil {
56+
return nil, err
57+
}
58+
if len(values) > 100 {
59+
values = values[:100]
60+
}
61+
62+
return &mcp.CompleteResult{
63+
Completion: mcp.CompletionResultDetails{
64+
Values: values,
65+
Total: len(values),
66+
HasMore: false,
67+
},
68+
}, nil
69+
}
70+
}
71+
72+
// --- Per-argument resolver functions ---
73+
74+
func completeOwner(ctx context.Context, client *github.Client, _ map[string]string, argValue string) ([]string, error) {
75+
var values []string
76+
user, _, err := client.Users.Get(ctx, "")
77+
if err == nil && user.GetLogin() != "" {
78+
values = append(values, user.GetLogin())
79+
}
80+
81+
orgs, _, err := client.Organizations.List(ctx, "", &github.ListOptions{PerPage: 100})
82+
if err != nil {
83+
return nil, err
84+
}
85+
for _, org := range orgs {
86+
values = append(values, org.GetLogin())
87+
}
88+
89+
// filter values based on argValue and replace values slice
90+
if argValue != "" {
91+
var filteredValues []string
92+
for _, value := range values {
93+
if strings.Contains(value, argValue) {
94+
filteredValues = append(filteredValues, value)
95+
}
96+
}
97+
values = filteredValues
98+
}
99+
if len(values) > 100 {
100+
values = values[:100]
101+
return values, nil // Limit to 100 results
102+
}
103+
// Else also do a client.Search.Users()
104+
if argValue == "" {
105+
return values, nil // No need to search if no argValue
106+
}
107+
users, _, err := client.Search.Users(ctx, argValue, &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100 - len(values)}})
108+
if err != nil || users == nil {
109+
return nil, err
110+
}
111+
for _, user := range users.Users {
112+
values = append(values, user.GetLogin())
113+
}
114+
115+
if len(values) > 100 {
116+
values = values[:100]
117+
}
118+
return values, nil
119+
}
120+
121+
func completeRepo(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
122+
var values []string
123+
owner := resolved["owner"]
124+
if owner == "" {
125+
return values, errors.New("owner not specified")
126+
}
127+
128+
query := fmt.Sprintf("org:%s", owner)
129+
130+
if argValue != "" {
131+
query = fmt.Sprintf("%s %s", query, argValue)
132+
}
133+
repos, _, err := client.Search.Repositories(ctx, query, &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}})
134+
if err != nil || repos == nil {
135+
return values, errors.New("failed to get repositories")
136+
}
137+
// filter repos based on argValue
138+
for _, repo := range repos.Repositories {
139+
name := repo.GetName()
140+
if argValue == "" || strings.HasPrefix(name, argValue) {
141+
values = append(values, name)
142+
}
143+
}
144+
145+
return values, nil
146+
}
147+
148+
func completeBranch(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
149+
var values []string
150+
owner := resolved["owner"]
151+
repo := resolved["repo"]
152+
if owner == "" || repo == "" {
153+
return values, errors.New("owner or repo not specified")
154+
}
155+
branches, _, _ := client.Repositories.ListBranches(ctx, owner, repo, nil)
156+
157+
for _, branch := range branches {
158+
if argValue == "" || strings.HasPrefix(branch.GetName(), argValue) {
159+
values = append(values, branch.GetName())
160+
}
161+
}
162+
if len(values) > 100 {
163+
values = values[:100]
164+
}
165+
return values, nil
166+
}
167+
168+
func completeSHA(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
169+
var values []string
170+
owner := resolved["owner"]
171+
repo := resolved["repo"]
172+
if owner == "" || repo == "" {
173+
return values, errors.New("owner or repo not specified")
174+
}
175+
commits, _, _ := client.Repositories.ListCommits(ctx, owner, repo, nil)
176+
177+
for _, commit := range commits {
178+
sha := commit.GetSHA()
179+
if argValue == "" || strings.HasPrefix(sha, argValue) {
180+
values = append(values, sha)
181+
}
182+
}
183+
if len(values) > 100 {
184+
values = values[:100]
185+
}
186+
return values, nil
187+
}
188+
189+
func completeTag(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
190+
owner := resolved["owner"]
191+
repo := resolved["repo"]
192+
if owner == "" || repo == "" {
193+
return nil, errors.New("owner or repo not specified")
194+
}
195+
tags, _, _ := client.Repositories.ListTags(ctx, owner, repo, nil)
196+
var values []string
197+
for _, tag := range tags {
198+
if argValue == "" || strings.Contains(tag.GetName(), argValue) {
199+
values = append(values, tag.GetName())
200+
}
201+
}
202+
if len(values) > 100 {
203+
values = values[:100]
204+
}
205+
return values, nil
206+
}
207+
208+
func completePRNumber(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
209+
var values []string
210+
owner := resolved["owner"]
211+
repo := resolved["repo"]
212+
if owner == "" || repo == "" {
213+
return values, errors.New("owner or repo not specified")
214+
}
215+
216+
prs, _, err := client.Search.Issues(ctx, fmt.Sprintf("repo:%s/%s is:open is:pr", owner, repo), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}})
217+
if err != nil {
218+
return values, err
219+
}
220+
for _, pr := range prs.Issues {
221+
num := fmt.Sprintf("%d", pr.GetNumber())
222+
if argValue == "" || strings.HasPrefix(num, argValue) {
223+
values = append(values, num)
224+
}
225+
}
226+
if len(values) > 100 {
227+
values = values[:100]
228+
}
229+
return values, nil
230+
}
231+
232+
func completePath(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) {
233+
owner := resolved["owner"]
234+
repo := resolved["repo"]
235+
if owner == "" || repo == "" {
236+
return nil, errors.New("owner or repo not specified")
237+
}
238+
refVal := resolved["branch"]
239+
if refVal == "" {
240+
refVal = resolved["sha"]
241+
}
242+
if refVal == "" {
243+
refVal = resolved["tag"]
244+
}
245+
if refVal == "" {
246+
refVal = "HEAD"
247+
}
248+
249+
// Determine the prefix to complete (directory path or file path)
250+
prefix := argValue
251+
if prefix != "" && !strings.HasSuffix(prefix, "/") {
252+
lastSlash := strings.LastIndex(prefix, "/")
253+
if lastSlash >= 0 {
254+
prefix = prefix[:lastSlash+1]
255+
} else {
256+
prefix = ""
257+
}
258+
}
259+
260+
// Get the tree for the ref (recursive)
261+
tree, _, err := client.Git.GetTree(ctx, owner, repo, refVal, true)
262+
if err != nil || tree == nil {
263+
return nil, errors.New("failed to get file tree")
264+
}
265+
266+
// Collect immediate children of the prefix (files and directories, no duplicates)
267+
dirs := map[string]struct{}{}
268+
files := map[string]struct{}{}
269+
prefixLen := len(prefix)
270+
for _, entry := range tree.Entries {
271+
if !strings.HasPrefix(entry.GetPath(), prefix) {
272+
continue
273+
}
274+
rel := entry.GetPath()[prefixLen:]
275+
if rel == "" {
276+
continue
277+
}
278+
// Only immediate children
279+
slashIdx := strings.Index(rel, "/")
280+
if slashIdx >= 0 {
281+
// Directory: only add the directory name (with trailing slash), prefixed with full path
282+
dirName := prefix + rel[:slashIdx+1]
283+
dirs[dirName] = struct{}{}
284+
} else if entry.GetType() == "blob" {
285+
// File: add as-is, prefixed with full path
286+
fileName := prefix + rel
287+
files[fileName] = struct{}{}
288+
}
289+
}
290+
291+
// Optionally filter by argValue (if user is typing after last slash)
292+
var filter string
293+
if argValue != "" {
294+
if lastSlash := strings.LastIndex(argValue, "/"); lastSlash >= 0 {
295+
filter = argValue[lastSlash+1:]
296+
} else {
297+
filter = argValue
298+
}
299+
}
300+
301+
var values []string
302+
// Add directories first, then files, both filtered
303+
for dir := range dirs {
304+
// Only filter on the last segment after the last slash
305+
if filter == "" {
306+
values = append(values, dir)
307+
} else {
308+
last := dir
309+
if idx := strings.LastIndex(strings.TrimRight(dir, "/"), "/"); idx >= 0 {
310+
last = dir[idx+1:]
311+
}
312+
if strings.HasPrefix(last, filter) {
313+
values = append(values, dir)
314+
}
315+
}
316+
}
317+
for file := range files {
318+
if filter == "" {
319+
values = append(values, file)
320+
} else {
321+
last := file
322+
if idx := strings.LastIndex(file, "/"); idx >= 0 {
323+
last = file[idx+1:]
324+
}
325+
if strings.HasPrefix(last, filter) {
326+
values = append(values, file)
327+
}
328+
}
329+
}
330+
331+
if len(values) > 100 {
332+
values = values[:100]
333+
}
334+
return values, nil
335+
}

0 commit comments

Comments
 (0)