Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1491,3 +1491,144 @@ This is the task prompt for resume mode.
t.Errorf("stderr should NOT contain 'Rules written' message in resume mode")
}
}

// TestLocalDirectoryNotDeleted verifies that local directories passed via -d flag
// are not deleted after the command completes.
func TestLocalDirectoryNotDeleted(t *testing.T) {
// Create a local directory with a rule file and a marker file
localDir := t.TempDir()
rulesDir := filepath.Join(localDir, ".agents", "rules")

if err := os.MkdirAll(rulesDir, 0o755); err != nil {
t.Fatalf("failed to create rules dir: %v", err)
}

// Create a rule file
ruleFile := filepath.Join(rulesDir, "local-rule.md")
ruleContent := `---
language: go
---
# Local Rule

This is a rule from a local directory.
`
if err := os.WriteFile(ruleFile, []byte(ruleContent), 0o644); err != nil {
t.Fatalf("failed to write rule file: %v", err)
}

// Create a marker file to verify the directory is not deleted
markerFile := filepath.Join(localDir, "marker.txt")
if err := os.WriteFile(markerFile, []byte("marker"), 0o644); err != nil {
t.Fatalf("failed to write marker file: %v", err)
}

// Create a temporary directory for the task
tmpDir := t.TempDir()
tasksDir := filepath.Join(tmpDir, ".agents", "tasks")

if err := os.MkdirAll(tasksDir, 0o755); err != nil {
t.Fatalf("failed to create tasks dir: %v", err)
}

createStandardTask(t, tasksDir, "test-task")

// Run the program with local directory using file:// URL
localURL := "file://" + localDir
output := runTool(t, "-C", tmpDir, "-d", localURL, "test-task")

// Check that local rule content is present
if !strings.Contains(output, "# Local Rule") {
t.Errorf("local rule content not found in stdout")
}
if !strings.Contains(output, "This is a rule from a local directory") {
t.Errorf("local rule description not found in stdout")
}

// Verify the marker file still exists (directory was not deleted)
if _, err := os.Stat(markerFile); err != nil {
if os.IsNotExist(err) {
t.Errorf("marker file was deleted, indicating local directory was deleted")
} else {
t.Fatalf("unexpected error checking marker file: %v", err)
}
}

// Verify the rule file still exists
if _, err := os.Stat(ruleFile); err != nil {
if os.IsNotExist(err) {
t.Errorf("rule file was deleted, indicating local directory was deleted")
} else {
t.Fatalf("unexpected error checking rule file: %v", err)
}
}
}

// TestLocalDirectoryWithoutProtocol verifies that local directories passed
// without the file:// protocol are not deleted.
func TestLocalDirectoryWithoutProtocol(t *testing.T) {
// Create a local directory with a rule file and a marker file
localDir := t.TempDir()
rulesDir := filepath.Join(localDir, ".agents", "rules")

if err := os.MkdirAll(rulesDir, 0o755); err != nil {
t.Fatalf("failed to create rules dir: %v", err)
}

// Create a rule file
ruleFile := filepath.Join(rulesDir, "local-rule.md")
ruleContent := `---
language: go
---
# Local Rule

This is a rule from a local directory without protocol.
`
if err := os.WriteFile(ruleFile, []byte(ruleContent), 0o644); err != nil {
t.Fatalf("failed to write rule file: %v", err)
}

// Create a marker file to verify the directory is not deleted
markerFile := filepath.Join(localDir, "marker.txt")
if err := os.WriteFile(markerFile, []byte("marker"), 0o644); err != nil {
t.Fatalf("failed to write marker file: %v", err)
}

// Create a temporary directory for the task
tmpDir := t.TempDir()
tasksDir := filepath.Join(tmpDir, ".agents", "tasks")

if err := os.MkdirAll(tasksDir, 0o755); err != nil {
t.Fatalf("failed to create tasks dir: %v", err)
}

createStandardTask(t, tasksDir, "test-task")

// Run the program with local directory using absolute path (no protocol)
output := runTool(t, "-C", tmpDir, "-d", localDir, "test-task")

// Check that local rule content is present
if !strings.Contains(output, "# Local Rule") {
t.Errorf("local rule content not found in stdout")
}
if !strings.Contains(output, "This is a rule from a local directory without protocol") {
t.Errorf("local rule description not found in stdout")
}

// Verify the marker file still exists (directory was not deleted)
if _, err := os.Stat(markerFile); err != nil {
if os.IsNotExist(err) {
t.Errorf("marker file was deleted, indicating local directory was deleted")
} else {
t.Fatalf("unexpected error checking marker file: %v", err)
}
}

// Verify the rule file still exists
if _, err := os.Stat(ruleFile); err != nil {
if os.IsNotExist(err) {
t.Errorf("rule file was deleted, indicating local directory was deleted")
} else {
t.Fatalf("unexpected error checking rule file: %v", err)
}
}
}
46 changes: 46 additions & 0 deletions pkg/codingcontext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,37 @@ func (cc *Context) Run(ctx context.Context, taskName string) (*Result, error) {
return result, nil
}

// isLocalPath checks if a path is a local file system path.
// Returns true for:
// - file:// URLs (e.g., file:///path/to/dir)
// - Absolute paths (e.g., /path/to/dir)
// - Relative paths (e.g., ./path or ../path)
// Returns false for remote protocols like git::, https://, s3::, etc.
func isLocalPath(path string) bool {
// Check if path starts with file:// protocol
if strings.HasPrefix(path, "file://") {
return true
}

// Check if it's an absolute or relative local path
// (no protocol prefix like git::, https://, s3::, etc.)
if !strings.Contains(path, "://") && !strings.Contains(path, "::") {
return true
}

return false
}

// normalizeLocalPath converts a local path to a usable file system path.
// For file:// URLs, it strips the protocol prefix.
// For other local paths, it returns them as-is.
func normalizeLocalPath(path string) string {
if strings.HasPrefix(path, "file://") {
return strings.TrimPrefix(path, "file://")
}
return path
}

func downloadDir(path string) string {
// hash the path and prepend it with a temporary directory
hash := sha256.Sum256([]byte(path))
Expand Down Expand Up @@ -397,6 +428,15 @@ func (cc *Context) parseManifestFile(ctx context.Context) ([]string, error) {

func (cc *Context) downloadRemoteDirectories(ctx context.Context) error {
for _, path := range cc.searchPaths {
// If the path is local, use it directly without downloading
if isLocalPath(path) {
localPath := normalizeLocalPath(path)
cc.logger.Info("Using local directory", "path", localPath)
cc.downloadedPaths = append(cc.downloadedPaths, localPath)
continue
}

// Download remote directories
cc.logger.Info("Downloading remote directory", "path", path)
dst := downloadDir(path)
if _, err := getter.Get(ctx, dst, path); err != nil {
Expand All @@ -411,6 +451,12 @@ func (cc *Context) downloadRemoteDirectories(ctx context.Context) error {

func (cc *Context) cleanupDownloadedDirectories() {
for _, path := range cc.searchPaths {
// Skip cleanup for local paths - they should not be deleted
if isLocalPath(path) {
continue
}

// Only clean up downloaded remote directories
dst := downloadDir(path)
if err := os.RemoveAll(dst); err != nil {
cc.logger.Error("Error cleaning up downloaded directory", "path", dst, "error", err)
Expand Down
113 changes: 113 additions & 0 deletions pkg/codingcontext/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1695,3 +1695,116 @@ func TestUserPrompt(t *testing.T) {
})
}
}

// TestIsLocalPath tests the isLocalPath helper function
func TestIsLocalPath(t *testing.T) {
tests := []struct {
name string
path string
expected bool
}{
{
name: "file:// protocol",
path: "file:///path/to/local",
expected: true,
},
{
name: "absolute path",
path: "/path/to/local",
expected: true,
},
{
name: "relative path - ./",
path: "./relative/path",
expected: true,
},
{
name: "relative path - ../",
path: "../relative/path",
expected: true,
},
{
name: "relative path - no prefix",
path: "relative/path",
expected: true,
},
{
name: "git protocol",
path: "git::https://github.com/user/repo.git",
expected: false,
},
{
name: "https protocol",
path: "https://example.com/file.tar.gz",
expected: false,
},
{
name: "http protocol",
path: "http://example.com/file.tar.gz",
expected: false,
},
{
name: "s3 protocol",
path: "s3::https://s3.amazonaws.com/bucket/key",
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isLocalPath(tt.path)
if result != tt.expected {
t.Errorf("isLocalPath(%q) = %v, expected %v", tt.path, result, tt.expected)
}
})
}
}

// TestNormalizeLocalPath tests the normalizeLocalPath helper function
func TestNormalizeLocalPath(t *testing.T) {
tests := []struct {
name string
path string
expected string
}{
{
name: "file:// protocol - absolute path",
path: "file:///path/to/local",
expected: "/path/to/local",
},
{
name: "file:// protocol - relative path",
path: "file://./relative/path",
expected: "./relative/path",
},
{
name: "absolute path without protocol",
path: "/path/to/local",
expected: "/path/to/local",
},
{
name: "relative path - ./",
path: "./relative/path",
expected: "./relative/path",
},
{
name: "relative path - ../",
path: "../relative/path",
expected: "../relative/path",
},
{
name: "relative path - no prefix",
path: "relative/path",
expected: "relative/path",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeLocalPath(tt.path)
if result != tt.expected {
t.Errorf("normalizeLocalPath(%q) = %q, expected %q", tt.path, result, tt.expected)
}
})
}
}