diff --git a/_examples/status/job_status_enum.go b/_examples/status/job_status_enum.go index 482ed8b..3bb872d 100644 --- a/_examples/status/job_status_enum.go +++ b/_examples/status/job_status_enum.go @@ -10,7 +10,7 @@ import ( // JobStatus is the exported type for the enum type JobStatus struct { name string - value int + value uint8 } func (e JobStatus) String() string { return e.name } @@ -85,7 +85,7 @@ func MustJobStatus(v string) JobStatus { } // GetJobStatusByID gets the correspondent jobStatus enum value by its ID (raw integer value) -func GetJobStatusByID(v int) (JobStatus, error) { +func GetJobStatusByID(v uint8) (JobStatus, error) { switch v { case 1: return JobStatusActive, nil diff --git a/_examples/status/status_enum.go b/_examples/status/status_enum.go index 12b053e..60ec1e3 100644 --- a/_examples/status/status_enum.go +++ b/_examples/status/status_enum.go @@ -10,7 +10,7 @@ import ( // Status is the exported type for the enum type Status struct { name string - value int + value uint8 } func (e Status) String() string { return e.name } diff --git a/internal/generator/enum.go.tmpl b/internal/generator/enum.go.tmpl index b342859..befece0 100644 --- a/internal/generator/enum.go.tmpl +++ b/internal/generator/enum.go.tmpl @@ -13,7 +13,7 @@ import ( // {{.Type | title}} is the exported type for the enum type {{.Type | title}} struct { name string - value int + value {{.OriginalType}} } func (e {{.Type | title}}) String() string { return e.name } @@ -91,7 +91,7 @@ func Must{{.Type | title}}(v string) {{.Type | title}} { {{if .GenerateGetter -}} // Get{{.Type | title}}ByID gets the correspondent {{.Type}} enum value by its ID (raw integer value) -func Get{{.Type | title}}ByID(v int) ({{.Type | title}}, error) { +func Get{{.Type | title}}ByID(v {{.OriginalType}}) ({{.Type | title}}, error) { switch v { {{range .Values -}} case {{.Index}}: diff --git a/internal/generator/generator.go b/internal/generator/generator.go index 1e65752..c0923a0 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -33,6 +33,7 @@ type Generator struct { pkgName string // package name from source file lowerCase bool // use lower case for marshal/unmarshal generateGetter bool // generate getter methods for enum values + originalType string // original type name (e.g., "uint8") } // Value represents a single enum value @@ -88,6 +89,9 @@ func (g *Generator) Parse(dir string) error { } } + if g.originalType == "" { + return fmt.Errorf("type %s not found", g.Type) + } if len(g.values) == 0 { return fmt.Errorf("no const values found for type %s", g.Type) } @@ -97,7 +101,22 @@ func (g *Generator) Parse(dir string) error { // parseFile processes a single file for enum declarations func (g *Generator) parseFile(file *ast.File) { - + parseTypeBlock := func(decl *ast.GenDecl) { + // extracts the type name from a const block + for _, spec := range decl.Specs { + vspec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + if vspec.Name.Name == g.Type { + tspec, ok := vspec.Type.(*ast.Ident) + if !ok { + continue + } + g.originalType = tspec.Name + } + } + } parseConstBlock := func(decl *ast.GenDecl) { // extracts enum values from a const block var iotaVal int @@ -155,8 +174,13 @@ func (g *Generator) parseFile(file *ast.File) { } ast.Inspect(file, func(n ast.Node) bool { - if decl, ok := n.(*ast.GenDecl); ok && decl.Tok == token.CONST { - parseConstBlock(decl) + if decl, ok := n.(*ast.GenDecl); ok { + switch decl.Tok { + case token.CONST: + parseConstBlock(decl) + case token.TYPE: + parseTypeBlock(decl) + } } return true }) @@ -250,12 +274,14 @@ func (g *Generator) Generate() error { Package string LowerCase bool GenerateGetter bool + OriginalType string }{ Type: g.Type, Values: values, Package: pkgName, LowerCase: g.lowerCase, GenerateGetter: g.generateGetter, + OriginalType: g.originalType, } // execute template diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go index 2518f35..0d4b179 100644 --- a/internal/generator/generator_test.go +++ b/internal/generator/generator_test.go @@ -238,7 +238,7 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) // check content - assert.Contains(t, string(content), "func GetJobStatusByID(v int) (JobStatus, error)") + assert.Contains(t, string(content), "func GetJobStatusByID(v uint8) (JobStatus, error)") assert.Contains(t, string(content), "case 0:\n\t\treturn JobStatusUnknown, nil") assert.Contains(t, string(content), "case 1:\n\t\treturn JobStatusActive, nil") assert.Contains(t, string(content), "case 2:\n\t\treturn JobStatusInactive, nil") @@ -266,7 +266,7 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) // check content - assert.Contains(t, string(content), "func GetExplicitValuesByID(v int) (ExplicitValues, error)") + assert.Contains(t, string(content), "func GetExplicitValuesByID(v uint8) (ExplicitValues, error)") assert.Contains(t, string(content), "case 10:\n\t\treturn ExplicitValuesFirst, nil") assert.Contains(t, string(content), "case 20:\n\t\treturn ExplicitValuesSecond, nil") assert.Contains(t, string(content), "case 30:\n\t\treturn ExplicitValuesThird, nil") @@ -442,6 +442,7 @@ func TestPermissions(t *testing.T) { // create a sample status file sampleFile := `package source +type status uint8 const ( statusUnknown = iota statusActive @@ -540,6 +541,7 @@ func TestParseSpecialCases(t *testing.T) { tmpDir := t.TempDir() err := os.WriteFile(filepath.Join(tmpDir, "empty.go"), []byte(` package test +type status uint8 const ( ) `), 0o644) @@ -557,6 +559,7 @@ const ( tmpDir := t.TempDir() err := os.WriteFile(filepath.Join(tmpDir, "no_values.go"), []byte(` package test +type status uint8 const name string `), 0o644) require.NoError(t, err) @@ -568,6 +571,26 @@ const name string require.Error(t, err) assert.Contains(t, err.Error(), "no const values found for type status") }) + + t.Run("no status type", func(t *testing.T) { + tmpDir := t.TempDir() + err := os.WriteFile(filepath.Join(tmpDir, "no_type.go"), []byte(` +package test +const ( + statusUnknown = iota + statusActive + statusInactive +) +`), 0o644) + require.NoError(t, err) + + gen, err := New("status", "") + require.NoError(t, err) + + err = gen.Parse(tmpDir) + require.Error(t, err) + assert.Contains(t, err.Error(), "type status not found") + }) } func TestSplitCamelCase(t *testing.T) {