diff --git a/integration_test.go b/integration_test.go index 13fe1d0..64c7a45 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1491,3 +1491,144 @@ This is the task prompt for resume mode. t.Errorf("stderr should NOT contain 'Rules written' message in resume mode") } } + +// TestLocalDirectoryNotDeleted verifies that local directories passed via -d flag +// are not deleted after the command completes. +func TestLocalDirectoryNotDeleted(t *testing.T) { + // Create a local directory with a rule file and a marker file + localDir := t.TempDir() + rulesDir := filepath.Join(localDir, ".agents", "rules") + + if err := os.MkdirAll(rulesDir, 0o755); err != nil { + t.Fatalf("failed to create rules dir: %v", err) + } + + // Create a rule file + ruleFile := filepath.Join(rulesDir, "local-rule.md") + ruleContent := `--- +language: go +--- +# Local Rule + +This is a rule from a local directory. +` + if err := os.WriteFile(ruleFile, []byte(ruleContent), 0o644); err != nil { + t.Fatalf("failed to write rule file: %v", err) + } + + // Create a marker file to verify the directory is not deleted + markerFile := filepath.Join(localDir, "marker.txt") + if err := os.WriteFile(markerFile, []byte("marker"), 0o644); err != nil { + t.Fatalf("failed to write marker file: %v", err) + } + + // Create a temporary directory for the task + tmpDir := t.TempDir() + tasksDir := filepath.Join(tmpDir, ".agents", "tasks") + + if err := os.MkdirAll(tasksDir, 0o755); err != nil { + t.Fatalf("failed to create tasks dir: %v", err) + } + + createStandardTask(t, tasksDir, "test-task") + + // Run the program with local directory using file:// URL + localURL := "file://" + localDir + output := runTool(t, "-C", tmpDir, "-d", localURL, "test-task") + + // Check that local rule content is present + if !strings.Contains(output, "# Local Rule") { + t.Errorf("local rule content not found in stdout") + } + if !strings.Contains(output, "This is a rule from a local directory") { + t.Errorf("local rule description not found in stdout") + } + + // Verify the marker file still exists (directory was not deleted) + if _, err := os.Stat(markerFile); err != nil { + if os.IsNotExist(err) { + t.Errorf("marker file was deleted, indicating local directory was deleted") + } else { + t.Fatalf("unexpected error checking marker file: %v", err) + } + } + + // Verify the rule file still exists + if _, err := os.Stat(ruleFile); err != nil { + if os.IsNotExist(err) { + t.Errorf("rule file was deleted, indicating local directory was deleted") + } else { + t.Fatalf("unexpected error checking rule file: %v", err) + } + } +} + +// TestLocalDirectoryWithoutProtocol verifies that local directories passed +// without the file:// protocol are not deleted. +func TestLocalDirectoryWithoutProtocol(t *testing.T) { + // Create a local directory with a rule file and a marker file + localDir := t.TempDir() + rulesDir := filepath.Join(localDir, ".agents", "rules") + + if err := os.MkdirAll(rulesDir, 0o755); err != nil { + t.Fatalf("failed to create rules dir: %v", err) + } + + // Create a rule file + ruleFile := filepath.Join(rulesDir, "local-rule.md") + ruleContent := `--- +language: go +--- +# Local Rule + +This is a rule from a local directory without protocol. +` + if err := os.WriteFile(ruleFile, []byte(ruleContent), 0o644); err != nil { + t.Fatalf("failed to write rule file: %v", err) + } + + // Create a marker file to verify the directory is not deleted + markerFile := filepath.Join(localDir, "marker.txt") + if err := os.WriteFile(markerFile, []byte("marker"), 0o644); err != nil { + t.Fatalf("failed to write marker file: %v", err) + } + + // Create a temporary directory for the task + tmpDir := t.TempDir() + tasksDir := filepath.Join(tmpDir, ".agents", "tasks") + + if err := os.MkdirAll(tasksDir, 0o755); err != nil { + t.Fatalf("failed to create tasks dir: %v", err) + } + + createStandardTask(t, tasksDir, "test-task") + + // Run the program with local directory using absolute path (no protocol) + output := runTool(t, "-C", tmpDir, "-d", localDir, "test-task") + + // Check that local rule content is present + if !strings.Contains(output, "# Local Rule") { + t.Errorf("local rule content not found in stdout") + } + if !strings.Contains(output, "This is a rule from a local directory without protocol") { + t.Errorf("local rule description not found in stdout") + } + + // Verify the marker file still exists (directory was not deleted) + if _, err := os.Stat(markerFile); err != nil { + if os.IsNotExist(err) { + t.Errorf("marker file was deleted, indicating local directory was deleted") + } else { + t.Fatalf("unexpected error checking marker file: %v", err) + } + } + + // Verify the rule file still exists + if _, err := os.Stat(ruleFile); err != nil { + if os.IsNotExist(err) { + t.Errorf("rule file was deleted, indicating local directory was deleted") + } else { + t.Fatalf("unexpected error checking rule file: %v", err) + } + } +} diff --git a/pkg/codingcontext/context.go b/pkg/codingcontext/context.go index dc38531..e83dc4e 100644 --- a/pkg/codingcontext/context.go +++ b/pkg/codingcontext/context.go @@ -348,6 +348,37 @@ func (cc *Context) Run(ctx context.Context, taskName string) (*Result, error) { return result, nil } +// isLocalPath checks if a path is a local file system path. +// Returns true for: +// - file:// URLs (e.g., file:///path/to/dir) +// - Absolute paths (e.g., /path/to/dir) +// - Relative paths (e.g., ./path or ../path) +// Returns false for remote protocols like git::, https://, s3::, etc. +func isLocalPath(path string) bool { + // Check if path starts with file:// protocol + if strings.HasPrefix(path, "file://") { + return true + } + + // Check if it's an absolute or relative local path + // (no protocol prefix like git::, https://, s3::, etc.) + if !strings.Contains(path, "://") && !strings.Contains(path, "::") { + return true + } + + return false +} + +// normalizeLocalPath converts a local path to a usable file system path. +// For file:// URLs, it strips the protocol prefix. +// For other local paths, it returns them as-is. +func normalizeLocalPath(path string) string { + if strings.HasPrefix(path, "file://") { + return strings.TrimPrefix(path, "file://") + } + return path +} + func downloadDir(path string) string { // hash the path and prepend it with a temporary directory hash := sha256.Sum256([]byte(path)) @@ -397,6 +428,15 @@ func (cc *Context) parseManifestFile(ctx context.Context) ([]string, error) { func (cc *Context) downloadRemoteDirectories(ctx context.Context) error { for _, path := range cc.searchPaths { + // If the path is local, use it directly without downloading + if isLocalPath(path) { + localPath := normalizeLocalPath(path) + cc.logger.Info("Using local directory", "path", localPath) + cc.downloadedPaths = append(cc.downloadedPaths, localPath) + continue + } + + // Download remote directories cc.logger.Info("Downloading remote directory", "path", path) dst := downloadDir(path) if _, err := getter.Get(ctx, dst, path); err != nil { @@ -411,6 +451,12 @@ func (cc *Context) downloadRemoteDirectories(ctx context.Context) error { func (cc *Context) cleanupDownloadedDirectories() { for _, path := range cc.searchPaths { + // Skip cleanup for local paths - they should not be deleted + if isLocalPath(path) { + continue + } + + // Only clean up downloaded remote directories dst := downloadDir(path) if err := os.RemoveAll(dst); err != nil { cc.logger.Error("Error cleaning up downloaded directory", "path", dst, "error", err) diff --git a/pkg/codingcontext/context_test.go b/pkg/codingcontext/context_test.go index 8d921f4..4ca3f84 100644 --- a/pkg/codingcontext/context_test.go +++ b/pkg/codingcontext/context_test.go @@ -1695,3 +1695,116 @@ func TestUserPrompt(t *testing.T) { }) } } + +// TestIsLocalPath tests the isLocalPath helper function +func TestIsLocalPath(t *testing.T) { + tests := []struct { + name string + path string + expected bool + }{ + { + name: "file:// protocol", + path: "file:///path/to/local", + expected: true, + }, + { + name: "absolute path", + path: "/path/to/local", + expected: true, + }, + { + name: "relative path - ./", + path: "./relative/path", + expected: true, + }, + { + name: "relative path - ../", + path: "../relative/path", + expected: true, + }, + { + name: "relative path - no prefix", + path: "relative/path", + expected: true, + }, + { + name: "git protocol", + path: "git::https://github.com/user/repo.git", + expected: false, + }, + { + name: "https protocol", + path: "https://example.com/file.tar.gz", + expected: false, + }, + { + name: "http protocol", + path: "http://example.com/file.tar.gz", + expected: false, + }, + { + name: "s3 protocol", + path: "s3::https://s3.amazonaws.com/bucket/key", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isLocalPath(tt.path) + if result != tt.expected { + t.Errorf("isLocalPath(%q) = %v, expected %v", tt.path, result, tt.expected) + } + }) + } +} + +// TestNormalizeLocalPath tests the normalizeLocalPath helper function +func TestNormalizeLocalPath(t *testing.T) { + tests := []struct { + name string + path string + expected string + }{ + { + name: "file:// protocol - absolute path", + path: "file:///path/to/local", + expected: "/path/to/local", + }, + { + name: "file:// protocol - relative path", + path: "file://./relative/path", + expected: "./relative/path", + }, + { + name: "absolute path without protocol", + path: "/path/to/local", + expected: "/path/to/local", + }, + { + name: "relative path - ./", + path: "./relative/path", + expected: "./relative/path", + }, + { + name: "relative path - ../", + path: "../relative/path", + expected: "../relative/path", + }, + { + name: "relative path - no prefix", + path: "relative/path", + expected: "relative/path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeLocalPath(tt.path) + if result != tt.expected { + t.Errorf("normalizeLocalPath(%q) = %q, expected %q", tt.path, result, tt.expected) + } + }) + } +}