diff --git a/internal/api/api.go b/internal/api/api.go index 2c4fa079ac..41a478074d 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -32,13 +32,14 @@ type API struct { logger logging.Logger session *project.Session - projects map[Handle[project.Project]]tspath.Path - filesMu sync.Mutex - files handleMap[ast.SourceFile] - symbolsMu sync.Mutex - symbols handleMap[ast.Symbol] - typesMu sync.Mutex - types handleMap[checker.Type] + projectsMu sync.RWMutex + projects map[Handle[project.Project]]tspath.Path + filesMu sync.Mutex + files handleMap[ast.SourceFile] + symbolsMu sync.Mutex + symbols handleMap[ast.Symbol] + typesMu sync.Mutex + types handleMap[checker.Type] } func NewAPI(init *APIInit) *API { @@ -145,12 +146,18 @@ func (api *API) LoadProject(ctx context.Context, configFileName string) (*Projec return nil, err } data := NewProjectResponse(project) + // Acquire write lock to safely add project to the map + api.projectsMu.Lock() api.projects[data.Id] = project.ConfigFilePath() + api.projectsMu.Unlock() return data, nil } func (api *API) GetSymbolAtPosition(ctx context.Context, projectId Handle[project.Project], fileName string, position int) (*SymbolResponse, error) { + // Acquire read lock to safely access projects map + api.projectsMu.RLock() projectPath, ok := api.projects[projectId] + api.projectsMu.RUnlock() if !ok { return nil, errors.New("project ID not found") } @@ -174,7 +181,10 @@ func (api *API) GetSymbolAtPosition(ctx context.Context, projectId Handle[projec } func (api *API) GetSymbolAtLocation(ctx context.Context, projectId Handle[project.Project], location Handle[ast.Node]) (*SymbolResponse, error) { + // Acquire read lock to safely access projects map + api.projectsMu.RLock() projectPath, ok := api.projects[projectId] + api.projectsMu.RUnlock() if !ok { return nil, errors.New("project ID not found") } @@ -216,7 +226,10 @@ func (api *API) GetSymbolAtLocation(ctx context.Context, projectId Handle[projec } func (api *API) GetTypeOfSymbol(ctx context.Context, projectId Handle[project.Project], symbolHandle Handle[ast.Symbol]) (*TypeResponse, error) { + // Acquire read lock to safely access projects map + api.projectsMu.RLock() projectPath, ok := api.projects[projectId] + api.projectsMu.RUnlock() if !ok { return nil, errors.New("project ID not found") } @@ -242,7 +255,10 @@ func (api *API) GetTypeOfSymbol(ctx context.Context, projectId Handle[project.Pr } func (api *API) GetSourceFile(projectId Handle[project.Project], fileName string) (*ast.SourceFile, error) { + // Acquire read lock to safely access projects map + api.projectsMu.RLock() projectPath, ok := api.projects[projectId] + api.projectsMu.RUnlock() if !ok { return nil, errors.New("project ID not found") } @@ -267,11 +283,15 @@ func (api *API) releaseHandle(handle string) error { switch handle[0] { case handlePrefixProject: projectId := Handle[project.Project](handle) + // Acquire write lock to safely delete project from the map + api.projectsMu.Lock() _, ok := api.projects[projectId] if !ok { + api.projectsMu.Unlock() return fmt.Errorf("project %q not found", handle) } delete(api.projects, projectId) + api.projectsMu.Unlock() case handlePrefixFile: fileId := Handle[ast.SourceFile](handle) api.filesMu.Lock()