diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0c3f82b --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Binary output +coding-agent-context-cli + +# Build artifacts +/dist/ +/tmp/ diff --git a/bootstrap b/bootstrap new file mode 100644 index 0000000..691079f --- /dev/null +++ b/bootstrap @@ -0,0 +1,4 @@ +#!/bin/bash +set -euo pipefail + +find bootstrap.d -type f -exec {} \; diff --git a/go.mod b/go.mod index 7d36fac..7c069ca 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/kitproj/coding-agent-context-cli go 1.24.4 + +require go.yaml.in/yaml/v2 v2.4.2 diff --git a/go.sum b/go.sum index e69de29..2b30a96 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,4 @@ +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index 0f372b4..20ef0e3 100644 --- a/main.go +++ b/main.go @@ -1,43 +1,150 @@ package main import ( - "context" + "crypto/sha256" + _ "embed" "flag" "fmt" + "log/slog" "os" - "os/signal" - "syscall" + "path/filepath" + "text/template" ) -var () +//go:embed bootstrap +var bootstrap string + +var ( + dirs stringSlice + outputDir = "." + params = make(paramMap) +) func main() { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer cancel() + userConfigDir, err := os.UserConfigDir() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + dirs = []string{ + ".coding-agent-context", + filepath.Join(userConfigDir, "coding-agent-context"), + "/var/local/coding-agent-context", + } + + flag.Var(&dirs, "d", "Directory to include in the context. Can be specified multiple times.") + flag.StringVar(&outputDir, "o", ".", "Directory to write the context files to.") + flag.Var(¶ms, "p", "Parameter to substitute in the prompt. Can be specified multiple times as key=value.") flag.Usage = func() { w := flag.CommandLine.Output() fmt.Fprintf(w, "Usage:") fmt.Fprintln(w) - fmt.Fprintln(w, " coding-agent-context ") fmt.Fprintln(w) fmt.Fprintln(w, "Options:") flag.PrintDefaults() } flag.Parse() - if err := run(ctx, flag.Args()); err != nil { + if err := run(flag.Args()); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) flag.Usage() os.Exit(1) } } -func run(ctx context.Context, args []string) error { +func run(args []string) error { if len(args) < 1 { return fmt.Errorf("invalid usage") } - return nil + if err := os.MkdirAll(outputDir, 0755); err != nil { + return fmt.Errorf("failed to create output dir: %w", err) + } + + bootstrapDir := filepath.Join(outputDir, "bootstrap.d") + if err := os.MkdirAll(bootstrapDir, 0755); err != nil { + return fmt.Errorf("failed to create bootstrap dir: %w", err) + } + + output, err := os.Create(filepath.Join(outputDir, "prompt.md")) + if err != nil { + return fmt.Errorf("failed to create prompt file: %w", err) + } + defer output.Close() + + for _, dir := range dirs { + memoryDir := filepath.Join(dir, "memories") + err := filepath.Walk(memoryDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + + slog.Info("Including memory file", "path", path) + + var frontmatter struct { + Bootstrap string `yaml:"bootstrap"` + } + + content, err := parseMarkdownFile(path, &frontmatter) + if err != nil { + return fmt.Errorf("failed to parse markdown file: %w", err) + } + + if bootstrap := frontmatter.Bootstrap; bootstrap != "" { + hash := sha256.Sum256([]byte(bootstrap)) + bootstrapPath := filepath.Join(bootstrapDir, fmt.Sprintf("%x", hash)) + if err := os.WriteFile(bootstrapPath, []byte(bootstrap), 0700); err != nil { + return fmt.Errorf("failed to write bootstrap file: %w", err) + } + } + + if _, err := output.WriteString(content + "\n\n"); err != nil { + return fmt.Errorf("failed to write to output file: %w", err) + } + + return nil + + }) + if err != nil { + return fmt.Errorf("failed to walk memory dir: %w", err) + } + } + + if err := os.WriteFile(filepath.Join(outputDir, "bootstrap"), []byte(bootstrap), 0755); err != nil { + return fmt.Errorf("failed to write bootstrap file: %w", err) + } + + taskName := args[0] + for _, dir := range dirs { + promptFile := filepath.Join(dir, "prompts", taskName+".md") + + if _, err := os.Stat(promptFile); err == nil { + slog.Info("Using prompt file", "path", promptFile) + + content, err := parseMarkdownFile(promptFile, &struct{}{}) + if err != nil { + return fmt.Errorf("failed to parse prompt file: %w", err) + } + + t, err := template.New("prompt").Parse(content) + if err != nil { + return fmt.Errorf("failed to parse prompt template: %w", err) + } + + if err := t.Execute(output, params); err != nil { + return fmt.Errorf("failed to execute prompt template: %w", err) + } + + return nil + + } + } + return fmt.Errorf("prompt file not found for task: %s", taskName) } diff --git a/markdown.go b/markdown.go new file mode 100644 index 0000000..159afbb --- /dev/null +++ b/markdown.go @@ -0,0 +1,51 @@ +package main + +import ( + "bufio" + "bytes" + "fmt" + "os" + + yaml "go.yaml.in/yaml/v2" +) + +// parseMarkdownFile parses the file into frontmatter and content +func parseMarkdownFile(path string, frontmatter any) (string, error) { + + fh, err := os.Open(path) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer fh.Close() + + s := bufio.NewScanner(fh) + + if s.Scan() && s.Text() == "---" { + var frontMatterBytes bytes.Buffer + for s.Scan() { + line := s.Text() + if line == "---" { + break + } + + if _, err := frontMatterBytes.WriteString(line + "\n"); err != nil { + return "", fmt.Errorf("failed to write frontmatter: %w", err) + } + } + + if err := yaml.Unmarshal(frontMatterBytes.Bytes(), frontmatter); err != nil { + return "", fmt.Errorf("failed to unmarshal frontmatter: %w", err) + } + } + + var content bytes.Buffer + for s.Scan() { + if _, err := content.WriteString(s.Text() + "\n"); err != nil { + return "", fmt.Errorf("failed to write content: %w", err) + } + } + if err := s.Err(); err != nil { + return "", fmt.Errorf("failed to scan file: %w", err) + } + return content.String(), nil +} diff --git a/markdown_test.go b/markdown_test.go new file mode 100644 index 0000000..bc98f0e --- /dev/null +++ b/markdown_test.go @@ -0,0 +1,94 @@ +package main + +import ( + "os" + "path/filepath" + "testing" +) + +func TestParseMarkdownFile(t *testing.T) { + tests := []struct { + name string + content string + wantContent string + wantFrontmatter map[string]string + wantErr bool + }{ + { + name: "markdown with frontmatter", + content: `--- +title: Test Title +author: Test Author +--- +This is the content +of the markdown file. +`, + wantContent: "This is the content\nof the markdown file.\n", + wantFrontmatter: map[string]string{ + "title": "Test Title", + "author": "Test Author", + }, + wantErr: false, + }, + { + name: "markdown without frontmatter", + content: `This is a simple markdown file +without any frontmatter. +`, + wantContent: "without any frontmatter.\n", + wantFrontmatter: map[string]string{}, + wantErr: false, + }, + { + name: "empty file", + content: "", + wantContent: "", + wantFrontmatter: map[string]string{}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a temporary file + tmpDir := t.TempDir() + tmpFile := filepath.Join(tmpDir, "test.md") + if err := os.WriteFile(tmpFile, []byte(tt.content), 0644); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + + // Parse the file + var frontmatter map[string]string + content, err := parseMarkdownFile(tmpFile, &frontmatter) + + // Check error + if (err != nil) != tt.wantErr { + t.Errorf("parseMarkdownFile() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Check content + if content != tt.wantContent { + t.Errorf("parseMarkdownFile() content = %q, want %q", content, tt.wantContent) + } + + // Check frontmatter + if len(frontmatter) != len(tt.wantFrontmatter) { + t.Errorf("parseMarkdownFile() frontmatter length = %d, want %d", len(frontmatter), len(tt.wantFrontmatter)) + } + for k, v := range tt.wantFrontmatter { + if frontmatter[k] != v { + t.Errorf("parseMarkdownFile() frontmatter[%q] = %q, want %q", k, frontmatter[k], v) + } + } + }) + } +} + +func TestParseMarkdownFile_FileNotFound(t *testing.T) { + var frontmatter map[string]string + _, err := parseMarkdownFile("/nonexistent/file.md", &frontmatter) + if err == nil { + t.Error("parseMarkdownFile() expected error for non-existent file, got nil") + } +} diff --git a/param_map.go b/param_map.go new file mode 100644 index 0000000..3b4f836 --- /dev/null +++ b/param_map.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "strings" +) + +type paramMap map[string]string + +func (p *paramMap) String() string { + return fmt.Sprint(*p) +} + +func (p *paramMap) Set(value string) error { + kv := strings.SplitN(value, "=", 2) + if len(kv) != 2 { + return fmt.Errorf("invalid parameter format: %s", value) + } + if *p == nil { + *p = make(map[string]string) + } + (*p)[kv[0]] = kv[1] + return nil +} diff --git a/param_map_test.go b/param_map_test.go new file mode 100644 index 0000000..ffa1ea6 --- /dev/null +++ b/param_map_test.go @@ -0,0 +1,97 @@ +package main + +import ( + "testing" +) + +func TestParamMap_Set(t *testing.T) { + tests := []struct { + name string + value string + wantKey string + wantVal string + wantErr bool + }{ + { + name: "valid key=value", + value: "key=value", + wantKey: "key", + wantVal: "value", + wantErr: false, + }, + { + name: "key=value with equals in value", + value: "key=value=with=equals", + wantKey: "key", + wantVal: "value=with=equals", + wantErr: false, + }, + { + name: "empty value", + value: "key=", + wantKey: "key", + wantVal: "", + wantErr: false, + }, + { + name: "invalid format - no equals", + value: "keyvalue", + wantErr: true, + }, + { + name: "invalid format - only key", + value: "key", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := paramMap{} + err := p.Set(tt.value) + + if (err != nil) != tt.wantErr { + t.Errorf("paramMap.Set() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + if p[tt.wantKey] != tt.wantVal { + t.Errorf("paramMap[%q] = %q, want %q", tt.wantKey, p[tt.wantKey], tt.wantVal) + } + } + }) + } +} + +func TestParamMap_String(t *testing.T) { + p := paramMap{ + "key1": "value1", + "key2": "value2", + } + s := p.String() + if s == "" { + t.Error("paramMap.String() returned empty string") + } +} + +func TestParamMap_SetMultiple(t *testing.T) { + p := paramMap{} + + if err := p.Set("key1=value1"); err != nil { + t.Fatalf("paramMap.Set() failed: %v", err) + } + if err := p.Set("key2=value2"); err != nil { + t.Fatalf("paramMap.Set() failed: %v", err) + } + + if len(p) != 2 { + t.Errorf("paramMap length = %d, want 2", len(p)) + } + if p["key1"] != "value1" { + t.Errorf("paramMap[key1] = %q, want %q", p["key1"], "value1") + } + if p["key2"] != "value2" { + t.Errorf("paramMap[key2] = %q, want %q", p["key2"], "value2") + } +} diff --git a/string_slice.go b/string_slice.go new file mode 100644 index 0000000..edc9db3 --- /dev/null +++ b/string_slice.go @@ -0,0 +1,14 @@ +package main + +import "fmt" + +type stringSlice []string + +func (s *stringSlice) String() string { + return fmt.Sprint(*s) +} + +func (s *stringSlice) Set(value string) error { + *s = append(*s, value) + return nil +} diff --git a/string_slice_test.go b/string_slice_test.go new file mode 100644 index 0000000..704de9b --- /dev/null +++ b/string_slice_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "testing" +) + +func TestStringSlice_Set(t *testing.T) { + s := stringSlice{} + + values := []string{"first", "second", "third"} + for _, v := range values { + if err := s.Set(v); err != nil { + t.Errorf("stringSlice.Set(%q) error = %v", v, err) + } + } + + if len(s) != len(values) { + t.Errorf("stringSlice length = %d, want %d", len(s), len(values)) + } + + for i, want := range values { + if s[i] != want { + t.Errorf("stringSlice[%d] = %q, want %q", i, s[i], want) + } + } +} + +func TestStringSlice_String(t *testing.T) { + s := stringSlice{"value1", "value2", "value3"} + str := s.String() + if str == "" { + t.Error("stringSlice.String() returned empty string") + } +} + +func TestStringSlice_SetEmpty(t *testing.T) { + s := stringSlice{} + + if err := s.Set(""); err != nil { + t.Errorf("stringSlice.Set(\"\") error = %v", err) + } + + if len(s) != 1 { + t.Errorf("stringSlice length = %d, want 1", len(s)) + } + if s[0] != "" { + t.Errorf("stringSlice[0] = %q, want empty string", s[0]) + } +}