diff --git a/cmd/plug/gen.go b/cmd/plug/gen.go index 7b4a351..b9b0fca 100644 --- a/cmd/plug/gen.go +++ b/cmd/plug/gen.go @@ -2,17 +2,22 @@ package main import ( "bytes" + "errors" "fmt" "go/format" "go/types" "io" + "io/fs" + "log" "os" "path/filepath" "reflect" + "runtime" "strings" - "github.com/lufia/plug/plugcore" "golang.org/x/tools/go/ast/astutil" + + "github.com/lufia/plug/plugcore" ) type Stub struct { // Plug? @@ -23,11 +28,25 @@ type Stub struct { // Plug? func Rewrite(stub *Stub) (string, error) { filePath := stub.f.path name := filepath.Base(filePath) - dir := filepath.Join("plug", stub.f.pkg.path) + cacheDir, err := os.UserCacheDir() + if err != nil { + return "", fmt.Errorf("failed to get cachedir: %w", err) + } + dir := filepath.Join(cacheDir, "plug", runtime.Version(), stub.f.pkg.PathVersion()) if err := os.MkdirAll(dir, 0755); err != nil && !os.IsExist(err) { return "", fmt.Errorf("failed to create %s: %w", dir, err) } + if verbose { + log.Printf("cachedir: %s\n", dir) + } file := filepath.Join(dir, name) + _, err = os.Stat(file) + if err == nil { + return file, nil + } + if !errors.Is(err, fs.ErrNotExist) { + return "", fmt.Errorf("failed to stat %s: %w", file, err) + } w, err := os.Create(file) if err != nil { return "", fmt.Errorf("failed to create %s: %w", file, err) @@ -43,13 +62,9 @@ func Rewrite(stub *Stub) (string, error) { return file, nil } -func pkgPath(v any) string { - return reflect.TypeOf(v).PkgPath() -} - func rewriteFile(w io.Writer, stub *Stub) error { fset := stub.f.pkg.c.Fset - path := pkgPath(plugcore.Object{}) + path := reflect.TypeOf(plugcore.Object{}).PkgPath() astutil.AddImport(fset, stub.f.f, path) var buf bytes.Buffer @@ -57,7 +72,7 @@ func rewriteFile(w io.Writer, stub *Stub) error { rewriteFunc(&buf, fn) } if verbose { - fmt.Printf("====\n%s\n====\n", buf.Bytes()) + log.Printf("====\n%s\n====\n", buf.Bytes()) } s, err := format.Source(buf.Bytes()) if err != nil { diff --git a/cmd/plug/main.go b/cmd/plug/main.go index 03c70d0..faa23ce 100644 --- a/cmd/plug/main.go +++ b/cmd/plug/main.go @@ -33,7 +33,7 @@ func main() { flag.BoolVar(&verbose, "v", false, "enable verbose log") flag.Parse() - pkgPath, err := loadPackagePath(".") + pkgPath, modVers, err := loadPackagePath(".") if err != nil { log.Fatal(err) } @@ -41,7 +41,7 @@ func main() { if err != nil { log.Fatal(err) } - stubs := Group(syms) + stubs := Group(syms, modVers) var o Overlay for filePath, stub := range stubs { @@ -56,11 +56,11 @@ func main() { } } -func loadPackagePath(dir string) (string, error) { +func loadPackagePath(dir string) (string, map[string]string, error) { // loader.Import does not handle "." notation that means current package. dir, err := filepath.Abs(dir) if err != nil { - return "", err + return "", nil, err } s := dir file := filepath.Join(s, "go.mod") @@ -70,36 +70,47 @@ func loadPackagePath(dir string) (string, error) { break } if !os.IsNotExist(err) { - return "", err + return "", nil, err } up := filepath.Dir(s) if up == s { - return "", fmt.Errorf("go.mod is not exist") + return "", nil, fmt.Errorf("go.mod is not exist") } s = up file = filepath.Join(s, "go.mod") } data, err := os.ReadFile(file) if err != nil { - return "", err + return "", nil, err } - modPath := modfile.ModulePath(data) + + f, err := modfile.Parse(file, data, nil) + if err != nil { + return "", nil, err + } + modPath := f.Module.Mod.Path if modPath == "" { - return "", fmt.Errorf("%s: invalid go.mod syntax", file) + return "", nil, fmt.Errorf("%s: invalid go.mod syntax", file) } slug, err := filepath.Rel(s, dir) if err != nil { - return "", err + return "", nil, err + } + pkgPath := path.Join(modPath, filepath.ToSlash(slug)) + + modVers := make(map[string]string) + for _, r := range f.Require { + modVers[r.Mod.Path] = r.Mod.Version } - return path.Join(modPath, filepath.ToSlash(slug)), nil + return pkgPath, modVers, nil } // Group returns a map of Stub indexed by filePath. -func Group(syms []*Sym) map[string]*Stub { +func Group(syms []*Sym, modVers map[string]string) map[string]*Stub { stubs := make(map[string]*Stub) for _, sym := range syms { pkgPath := sym.PkgPath() - pkg, err := LoadPackage(pkgPath) + pkg, err := LoadPackage(pkgPath, modVers[pkgPath]) if err != nil { log.Fatalf("failed to load package %s: %v\n", pkgPath, err) } diff --git a/cmd/plug/pkg.go b/cmd/plug/pkg.go index f7f45f2..a171ec7 100644 --- a/cmd/plug/pkg.go +++ b/cmd/plug/pkg.go @@ -11,8 +11,17 @@ import ( type Pkg struct { *loader.PackageInfo - c *loader.Config - path string + c *loader.Config + path string + version string // If it is empty, maybe it is the stdlib +} + +func (pkg *Pkg) PathVersion() string { + s := pkg.path + if v := pkg.version; v != "" { + s += "@" + v + } + return s } type File struct { @@ -30,7 +39,7 @@ type Func struct { var pkgCache = make(map[string]*Pkg) -func LoadPackage(pkgPath string) (*Pkg, error) { +func LoadPackage(pkgPath, modVersion string) (*Pkg, error) { if pkg, ok := pkgCache[pkgPath]; ok { return pkg, nil } @@ -42,7 +51,7 @@ func LoadPackage(pkgPath string) (*Pkg, error) { if err != nil { return nil, err } - pkg := &Pkg{p.Package(pkgPath), &c, pkgPath} + pkg := &Pkg{p.Package(pkgPath), &c, pkgPath, modVersion} pkgCache[pkgPath] = pkg return pkg, nil }