From 6830af31fa8ea99bd28e1c61df2c1ed5deeae4c6 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Tue, 25 Oct 2022 22:45:09 +0300 Subject: [PATCH 01/16] Complete CustomType and TypeParameter struct implementations --- argument.go | 2 +- cmd/marker/processor.go | 5 +- collector.go | 22 +- definition.go | 2 +- marker.go | 42 ++- marker_test.go | 26 +- markers.go | 42 +-- markers_test.go | 10 +- registry.go | 12 +- scanner.go | 12 +- target.go | 2 +- target_test.go | 16 +- test/menu/dessert.go | 2 +- util.go | 12 +- util_test.go | 12 +- visitor/collector.go | 19 +- visitor/custom_type.go | 162 +++++++++-- visitor/file.go | 19 +- visitor/file_test.go | 5 + visitor/function.go | 540 +++++++++++++++++++++++++------------ visitor/function_test.go | 47 ++-- visitor/generic_type.go | 46 ++++ visitor/interface.go | 368 ++++++++++++++++--------- visitor/interface_test.go | 10 +- visitor/pointer.go | 25 ++ visitor/position.go | 19 ++ visitor/struct.go | 423 ++++++++++++++++++----------- visitor/struct_test.go | 6 +- visitor/type.go | 273 +++++++++---------- visitor/type_constraint.go | 48 ++++ visitor/type_parameter.go | 64 +++++ visitor/variadic.go | 21 ++ visitor/visitor.go | 10 +- visitor/visitor_test.go | 12 +- 34 files changed, 1534 insertions(+), 802 deletions(-) create mode 100644 visitor/generic_type.go create mode 100644 visitor/pointer.go create mode 100644 visitor/position.go create mode 100644 visitor/type_constraint.go create mode 100644 visitor/type_parameter.go create mode 100644 visitor/variadic.go diff --git a/argument.go b/argument.go index c51f8c1..906fe9d 100644 --- a/argument.go +++ b/argument.go @@ -15,7 +15,7 @@ type Argument struct { } func ExtractArgument(structField reflect.StructField) (Argument, error) { - parameterName := UpperCamelCase(structField.Name) + parameterName := upperCamelCase(structField.Name) parameterTag, parameterTagExists := structField.Tag.Lookup("parameter") if parameterTagExists && parameterTag != "" { diff --git a/cmd/marker/processor.go b/cmd/marker/processor.go index 5d9f7b0..c183d24 100644 --- a/cmd/marker/processor.go +++ b/cmd/marker/processor.go @@ -2,7 +2,6 @@ package main import ( "errors" - "github.com/procyon-projects/marker" "github.com/procyon-projects/marker/packages" "github.com/procyon-projects/marker/processor" "github.com/spf13/cobra" @@ -183,9 +182,7 @@ func addProcessorCommand(processorName string) error { return err } - if !markers.IsLower(processorName) { - return errors.New("processor name must only contain lower case letters") - } + processorName = strings.ToLower(processorName) yamlPath := filepath.FromSlash(path.Join(wd, "marker.processors.yaml")) _, err = os.Stat(yamlPath) diff --git a/collector.go b/collector.go index 4b3a660..b518f2f 100644 --- a/collector.go +++ b/collector.go @@ -19,7 +19,7 @@ func NewCollector(registry *Registry) *Collector { } } -func (collector *Collector) Collect(pkg *packages.Package) (map[ast.Node]MarkerValues, error) { +func (collector *Collector) Collect(pkg *packages.Package) (map[ast.Node]Values, error) { if pkg == nil { return nil, errors.New("pkg(package) cannot be nil") @@ -57,14 +57,14 @@ func (collector *Collector) collectFileMarkerComments(file *ast.File) map[ast.No return visitor.nodeMarkers } -func (collector *Collector) parseMarkerComments(pkg *packages.Package, nodeMarkerComments map[ast.Node][]markerComment) (map[ast.Node]MarkerValues, error) { +func (collector *Collector) parseMarkerComments(pkg *packages.Package, nodeMarkerComments map[ast.Node][]markerComment) (map[ast.Node]Values, error) { importNodeMarkers, err := collector.parseImportMarkerComments(pkg, nodeMarkerComments) if err != nil { return nil, err } - nodeMarkerValues := make(map[ast.Node]MarkerValues) + nodeMarkerValues := make(map[ast.Node]Values) if importNodeMarkers != nil { for importNode, importMarker := range importNodeMarkers { @@ -82,14 +82,14 @@ func (collector *Collector) parseMarkerComments(pkg *packages.Package, nodeMarke var errs []error for node, markerComments := range nodeMarkerComments { - markerValues := make(MarkerValues) + markerValues := make(Values) file := pkg.Fset.File(node.Pos()) importAliases := fileImportAliases[file] for _, markerComment := range markerComments { markerText := markerComment.Text() markerName, _, _ := splitMarker(markerText) - targetLevel := FindTargetLevelFromNode(node) + targetLevel := FindTargetLevel(node) alias := strings.SplitN(markerName, ":", 2)[0] var definition *Definition @@ -138,13 +138,13 @@ func (collector *Collector) parseMarkerComments(pkg *packages.Package, nodeMarke return nodeMarkerValues, NewErrorList(errs) } -func (collector *Collector) parseImportMarkerComments(pkg *packages.Package, nodeMarkerComments map[ast.Node][]markerComment) (map[ast.Node]MarkerValues, error) { +func (collector *Collector) parseImportMarkerComments(pkg *packages.Package, nodeMarkerComments map[ast.Node][]markerComment) (map[ast.Node]Values, error) { var errs []error - importNodeMarkers := make(map[ast.Node]MarkerValues) + importNodeMarkers := make(map[ast.Node]Values) for node, markerComments := range nodeMarkerComments { - markerValues := make(MarkerValues) + markerValues := make(Values) for _, markerComment := range markerComments { markerText := markerComment.Text() @@ -194,9 +194,9 @@ func (collector *Collector) parseImportMarkerComments(pkg *packages.Package, nod return importNodeMarkers, NewErrorList(errs) } -type AliasMap map[string]ImportMarker +type AliasMap map[string]Import -func (collector *Collector) extractFileImportAliases(pkg *packages.Package, importNodeMarkers map[ast.Node]MarkerValues) (map[*token.File]AliasMap, error) { +func (collector *Collector) extractFileImportAliases(pkg *packages.Package, importNodeMarkers map[ast.Node]Values) (map[*token.File]AliasMap, error) { var errs []error var fileImportAliases = make(map[*token.File]AliasMap, 0) @@ -221,7 +221,7 @@ func (collector *Collector) extractFileImportAliases(pkg *packages.Package, impo pkgIdMap := make(map[string]bool, 0) for _, marker := range markers { - importMarker := marker.(ImportMarker) + importMarker := marker.(Import) if _, ok := pkgIdMap[importMarker.Pkg]; ok { position := pkg.Fset.Position(node.Pos()) diff --git a/definition.go b/definition.go index 9a47165..7b05f89 100644 --- a/definition.go +++ b/definition.go @@ -73,7 +73,7 @@ func (definition *Definition) validate() error { return fmt.Errorf("specify target levels for the definition: %v", definition.Name) } - if !IsLower(definition.Name) { + if !isLower(definition.Name) { return fmt.Errorf("marker '%s' should only contain lower case characters", definition.Name) } diff --git a/marker.go b/marker.go index fa2bdd8..a1a9a06 100644 --- a/marker.go +++ b/marker.go @@ -10,40 +10,56 @@ type Validate interface { Validate() error } -type MarkerValues map[string][]any +type Values map[string][]any + +func (v Values) Count() int { + if len(v) == 0 { + return 0 + } -func (markerValues MarkerValues) Count() int { count := 0 - for _, markers := range markerValues { + for _, markers := range v { count = count + len(markers) } return count } -func (markerValues MarkerValues) AllMarkers(name string) []any { - result := markerValues[name] +func (v Values) FindByName(name string) ([]any, bool) { + if len(v) == 0 { + return nil, false + } + + result := v[name] if len(result) == 0 { - return nil + return nil, false } - return result + return result, true } -func (markerValues MarkerValues) First(name string) any { - result := markerValues[name] +func (v Values) First(name string) (any, bool) { + if len(v) == 0 { + return nil, false + } + + result := v[name] if len(result) == 0 { - return nil + return nil, false } - return result[0] + return result[0], true } -func (markerValues MarkerValues) CountByName(name string) int { - result := markerValues[name] +func (v Values) CountByName(name string) int { + if len(v) == 0 { + return 0 + } + + result := v[name] return len(result) } diff --git a/marker_test.go b/marker_test.go index 714eb22..bc75c45 100644 --- a/marker_test.go +++ b/marker_test.go @@ -6,17 +6,22 @@ import ( ) func TestMarkerValues_AllMarkers(t *testing.T) { - markerValues := make(MarkerValues) + markerValues := make(Values) markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest1") markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest2") markerValues["anyMarker2"] = append(markerValues["anyMarker2"], "anyTest3") - assert.Equal(t, []interface{}{"anyTest1", "anyTest2"}, markerValues.AllMarkers("anyMarker1")) - assert.Equal(t, []interface{}{"anyTest3"}, markerValues.AllMarkers("anyMarker2")) + markers, exists := markerValues.FindByName("anyMarker1") + assert.True(t, exists) + assert.Equal(t, []interface{}{"anyTest1", "anyTest2"}, markers) + + markers, exists = markerValues.FindByName("anyMarker2") + assert.True(t, exists) + assert.Equal(t, []interface{}{"anyTest3"}, markers) } func TestMarkerValues_Count(t *testing.T) { - markerValues := make(MarkerValues) + markerValues := make(Values) markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest1") markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest2") markerValues["anyMarker2"] = append(markerValues["anyMarker2"], "anyTest3") @@ -25,7 +30,7 @@ func TestMarkerValues_Count(t *testing.T) { } func TestMarkerValues_CountByName(t *testing.T) { - markerValues := make(MarkerValues) + markerValues := make(Values) markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest1") markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest2") markerValues["anyMarker2"] = append(markerValues["anyMarker2"], "anyTest3") @@ -35,11 +40,16 @@ func TestMarkerValues_CountByName(t *testing.T) { } func TestMarkerValues_First(t *testing.T) { - markerValues := make(MarkerValues) + markerValues := make(Values) markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest1") markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest2") markerValues["anyMarker2"] = append(markerValues["anyMarker2"], "anyTest3") - assert.Equal(t, "anyTest1", markerValues.First("anyMarker1")) - assert.Equal(t, "anyTest3", markerValues.First("anyMarker2")) + marker, exists := markerValues.First("anyMarker1") + assert.True(t, exists) + assert.Equal(t, "anyTest1", marker) + + marker, exists = markerValues.First("anyMarker2") + assert.True(t, exists) + assert.Equal(t, "anyTest3", marker) } diff --git a/markers.go b/markers.go index 5b8f1b2..940cc94 100644 --- a/markers.go +++ b/markers.go @@ -7,21 +7,21 @@ import ( // Reserved markers const ( - ImportMarkerName = "import" - DeprecatedMarkerName = "deprecated" - OverrideMarkerName = "override" - DefinitionMarkerName = "marker" - DefinitionParameterMarkerName = "marker:parameter" - DefinitionEnumMarkerName = "marker:enum" + ImportMarkerName = "import" + DeprecatedMarkerName = "deprecated" + OverrideMarkerName = "override" + MarkerName = "marker" + ParameterMarkerName = "marker:parameter" + EnumMarkerName = "marker:enum" ) var reservedMarkerMap = map[string]struct{}{ - ImportMarkerName: {}, - DeprecatedMarkerName: {}, - OverrideMarkerName: {}, - DefinitionMarkerName: {}, - DefinitionParameterMarkerName: {}, - DefinitionEnumMarkerName: {}, + ImportMarkerName: {}, + DeprecatedMarkerName: {}, + OverrideMarkerName: {}, + MarkerName: {}, + ParameterMarkerName: {}, + EnumMarkerName: {}, } func IsReservedMarker(marker string) bool { @@ -32,13 +32,13 @@ func IsReservedMarker(marker string) bool { return false } -type ImportMarker struct { +type Import struct { Value string `parameter:"Value" required:"true"` Alias string `parameter:"Alias" required:"false"` Pkg string `parameter:"Pkg" required:"true"` } -func (m ImportMarker) Validate() error { +func (m Import) Validate() error { var errs []error if strings.Trim(m.Value, " \t") == "" { @@ -56,12 +56,12 @@ func (m ImportMarker) Validate() error { return nil } -func (m ImportMarker) PkgPath() string { +func (m Import) PkgPath() string { pkgParts := strings.Split(m.Pkg, "@") return pkgParts[0] } -func (m ImportMarker) PkgVersion() string { +func (m Import) PkgVersion() string { pkgParts := strings.Split(m.Pkg, "@") if len(pkgParts) > 1 { @@ -72,15 +72,15 @@ func (m ImportMarker) PkgVersion() string { return "latest" } -type DeprecatedMarker struct { +type Deprecated struct { Value string `parameter:"Value"` } -type OverrideMarker struct { +type Override struct { Value string `parameter:"Value"` } -type DefinitionMarker struct { +type Marker struct { Value string `parameter:"Value" required:"true"` Description string `parameter:"Description" required:"true"` Repeatable bool `parameter:"Repeatable" required:"false"` @@ -88,7 +88,7 @@ type DefinitionMarker struct { Targets []string `parameter:"Targets" required:"true" enum:"PACKAGE_LEVEL,STRUCT_TYPE_LEVEL,INTERFACE_TYPE_LEVEL,FIELD_LEVEL,FUNCTION_LEVEL,STRUCT_METHOD_LEVEL,INTERFACE_METHOD_LEVEL"` } -type DefinitionParameterMarker struct { +type Parameter struct { Value string `parameter:"Value" required:"true"` Description string `parameter:"Description" required:"true"` Required bool `parameter:"Required" required:"false"` @@ -96,7 +96,7 @@ type DefinitionParameterMarker struct { Default any `parameter:"Default" required:"false"` } -type DefinitionEnumMarker struct { +type Enum struct { Value string `parameter:"Value" required:"true"` Name string `parameter:"Name" required:"true"` } diff --git a/markers_test.go b/markers_test.go index 358016f..02cd5ab 100644 --- a/markers_test.go +++ b/markers_test.go @@ -13,14 +13,14 @@ func TestIsReservedMarker(t *testing.T) { } func TestImportMarker_PkgPath(t *testing.T) { - importMarker := &ImportMarker{ + importMarker := &Import{ Pkg: "github.com/procyon-projects/marker@v1.2.3", } assert.Equal(t, "github.com/procyon-projects/marker", importMarker.PkgPath()) } func TestImportMarker_PkgVersion(t *testing.T) { - importMarker := &ImportMarker{ + importMarker := &Import{ Pkg: "github.com/procyon-projects/marker@v1.2.3", } assert.Equal(t, "v1.2.3", importMarker.PkgVersion()) @@ -28,7 +28,7 @@ func TestImportMarker_PkgVersion(t *testing.T) { } func TestImportMarker_PkgVersionLatest(t *testing.T) { - importMarker := &ImportMarker{ + importMarker := &Import{ Pkg: "github.com/procyon-projects/marker", } assert.Equal(t, "latest", importMarker.PkgVersion()) @@ -36,14 +36,14 @@ func TestImportMarker_PkgVersionLatest(t *testing.T) { } func TestImportMarker_Validate_IfValueIsMissing(t *testing.T) { - importMarker := &ImportMarker{ + importMarker := &Import{ Pkg: "github.com/procyon-projects/marker", } assert.Error(t, importMarker.Validate()) } func TestImportMarker_Validate_IfPkgIsMissing(t *testing.T) { - importMarker := &ImportMarker{ + importMarker := &Import{ Value: "anyValue", } assert.Error(t, importMarker.Validate()) diff --git a/registry.go b/registry.go index 536426d..c1535ae 100644 --- a/registry.go +++ b/registry.go @@ -27,19 +27,19 @@ func (registry *Registry) initialize() { registry.packageMap = make(map[string]DefinitionMap) registry.packageMap[""] = make(DefinitionMap) } - registry.packageMap[""][ImportMarkerName], _ = MakeDefinition(ImportMarkerName, "", PackageLevel, &ImportMarker{}) + registry.packageMap[""][ImportMarkerName], _ = MakeDefinition(ImportMarkerName, "", PackageLevel, &Import{}) - overrideMarker, _ := MakeDefinition(OverrideMarkerName, "", StructMethodLevel, &OverrideMarker{}) + overrideMarker, _ := MakeDefinition(OverrideMarkerName, "", StructMethodLevel, &Override{}) overrideMarker.Output.SyntaxFree = true registry.packageMap[""][OverrideMarkerName] = overrideMarker - deprecatedDefinitionMarker, _ := MakeDefinition(DeprecatedMarkerName, "", TypeLevel|MethodLevel|FieldLevel|FunctionLevel, &DeprecatedMarker{}) + deprecatedDefinitionMarker, _ := MakeDefinition(DeprecatedMarkerName, "", TypeLevel|MethodLevel|FieldLevel|FunctionLevel, &Deprecated{}) deprecatedDefinitionMarker.Output.SyntaxFree = true registry.packageMap[""][DeprecatedMarkerName] = deprecatedDefinitionMarker - registry.packageMap[""][DefinitionMarkerName], _ = MakeDefinition(DefinitionMarkerName, "", StructTypeLevel, &DefinitionMarker{}) - registry.packageMap[""][DefinitionParameterMarkerName], _ = MakeDefinition(DefinitionParameterMarkerName, "", FieldLevel, &DefinitionParameterMarker{}) - registry.packageMap[""][DefinitionEnumMarkerName], _ = MakeDefinition(DefinitionEnumMarkerName, "", FieldLevel, &DefinitionEnumMarker{}) + registry.packageMap[""][MarkerName], _ = MakeDefinition(MarkerName, "", StructTypeLevel, &Definition{}) + registry.packageMap[""][ParameterMarkerName], _ = MakeDefinition(ParameterMarkerName, "", FieldLevel, &Parameter{}) + registry.packageMap[""][EnumMarkerName], _ = MakeDefinition(EnumMarkerName, "", FieldLevel, &Enum{}) } // Register registers a new marker with the given name, target level, and output type. diff --git a/scanner.go b/scanner.go index 88e50c5..9fefbff 100644 --- a/scanner.go +++ b/scanner.go @@ -116,10 +116,10 @@ func (scanner *Scanner) Scan() rune { character := scanner.SkipWhitespaces() token := character - if IsIdentifier(character, 0) { + if isIdentifier(character, 0) { token = Identifier character = scanner.ScanIdentifier() - } else if IsDecimal(character) { + } else if isDecimal(character) { token = IntegerValue character = scanner.ScanNumber() } else if character == EOF { @@ -143,13 +143,13 @@ func (scanner *Scanner) Scan() rune { } func (scanner *Scanner) ScanNumber() rune { - if IsDecimal(scanner.SkipWhitespaces()) { + if isDecimal(scanner.SkipWhitespaces()) { scanner.tokenStartPosition = scanner.searchIndex } character := scanner.SkipWhitespaces() - for IsDecimal(character) { + for isDecimal(character) { character = scanner.Next() } @@ -159,13 +159,13 @@ func (scanner *Scanner) ScanNumber() rune { } func (scanner *Scanner) ScanIdentifier() rune { - if IsIdentifier(scanner.SkipWhitespaces(), 1) { + if isIdentifier(scanner.SkipWhitespaces(), 1) { scanner.tokenStartPosition = scanner.searchIndex } character := scanner.SkipWhitespaces() - for index := 1; IsIdentifier(character, index); index++ { + for index := 1; isIdentifier(character, index); index++ { character = scanner.Next() } diff --git a/target.go b/target.go index 17da8a0..13061b3 100644 --- a/target.go +++ b/target.go @@ -32,7 +32,7 @@ const ( AllLevels = PackageLevel | TypeLevel | MethodLevel | FieldLevel | FunctionLevel ) -func FindTargetLevelFromNode(node ast.Node) TargetLevel { +func FindTargetLevel(node ast.Node) TargetLevel { switch typedNode := node.(type) { case *ast.TypeSpec: _, isStructType := typedNode.Type.(*ast.StructType) diff --git a/target_test.go b/target_test.go index 0610145..b448aa8 100644 --- a/target_test.go +++ b/target_test.go @@ -7,20 +7,20 @@ import ( ) func TestFindTargetLevelFromNode(t *testing.T) { - assert.Equal(t, StructTypeLevel, FindTargetLevelFromNode(&ast.TypeSpec{ + assert.Equal(t, StructTypeLevel, FindTargetLevel(&ast.TypeSpec{ Type: &ast.StructType{}, })) - assert.Equal(t, InterfaceTypeLevel, FindTargetLevelFromNode(&ast.TypeSpec{ + assert.Equal(t, InterfaceTypeLevel, FindTargetLevel(&ast.TypeSpec{ Type: &ast.InterfaceType{}, })) - assert.Equal(t, FieldLevel, FindTargetLevelFromNode(&ast.Field{})) - assert.Equal(t, InterfaceMethodLevel, FindTargetLevelFromNode(&ast.Field{ + assert.Equal(t, FieldLevel, FindTargetLevel(&ast.Field{})) + assert.Equal(t, InterfaceMethodLevel, FindTargetLevel(&ast.Field{ Type: &ast.FuncType{}, })) - assert.Equal(t, StructMethodLevel, FindTargetLevelFromNode(&ast.FuncDecl{ + assert.Equal(t, StructMethodLevel, FindTargetLevel(&ast.FuncDecl{ Recv: &ast.FieldList{}, })) - assert.Equal(t, FunctionLevel, FindTargetLevelFromNode(&ast.FuncDecl{})) - assert.Equal(t, PackageLevel, FindTargetLevelFromNode(&ast.Package{})) - assert.Equal(t, InvalidLevel, FindTargetLevelFromNode(nil)) + assert.Equal(t, FunctionLevel, FindTargetLevel(&ast.FuncDecl{})) + assert.Equal(t, PackageLevel, FindTargetLevel(&ast.Package{})) + assert.Equal(t, InvalidLevel, FindTargetLevel(nil)) } diff --git a/test/menu/dessert.go b/test/menu/dessert.go index f2e62b1..134f0bb 100644 --- a/test/menu/dessert.go +++ b/test/menu/dessert.go @@ -105,7 +105,7 @@ type Dessert interface { // muffin is a method // +marker:interface-method-level:Name=muffin - muffin() (string, error) + muffin() (*string, error) } // MakeACake is a function diff --git a/util.go b/util.go index cab7546..749c15c 100644 --- a/util.go +++ b/util.go @@ -5,7 +5,7 @@ import ( "unicode" ) -func IsUpper(s string) bool { +func isUpper(s string) bool { for _, r := range s { if !unicode.IsUpper(r) && unicode.IsLetter(r) { return false @@ -15,7 +15,7 @@ func IsUpper(s string) bool { return true } -func IsLower(s string) bool { +func isLower(s string) bool { for _, r := range s { if !unicode.IsLower(r) && unicode.IsLetter(r) { return false @@ -25,15 +25,15 @@ func IsLower(s string) bool { return true } -func IsDecimal(character rune) bool { +func isDecimal(character rune) bool { return '0' <= character && character <= '9' } -func IsIdentifier(character rune, index int) bool { +func isIdentifier(character rune, index int) bool { return character == '_' || unicode.IsLetter(character) || unicode.IsDigit(character) && index > 0 } -func LowerCamelCase(str string) string { +func lowerCamelCase(str string) string { isFirst := true return strings.Map(func(r rune) rune { @@ -47,7 +47,7 @@ func LowerCamelCase(str string) string { } -func UpperCamelCase(str string) string { +func upperCamelCase(str string) string { isFirst := true return strings.Map(func(r rune) rune { diff --git a/util_test.go b/util_test.go index d68d1c8..2488722 100644 --- a/util_test.go +++ b/util_test.go @@ -6,19 +6,19 @@ import ( ) func TestIsLower(t *testing.T) { - assert.True(t, IsLower("any")) - assert.False(t, IsLower("Any")) + assert.True(t, isLower("any")) + assert.False(t, isLower("Any")) } func TestIsUpper(t *testing.T) { - assert.True(t, IsUpper("ANY")) - assert.False(t, IsUpper("Any")) + assert.True(t, isUpper("ANY")) + assert.False(t, isUpper("Any")) } func TestLowerCamelCase(t *testing.T) { - assert.Equal(t, "testAny", LowerCamelCase("TestAny")) + assert.Equal(t, "testAny", lowerCamelCase("TestAny")) } func TestUpperCamelCase(t *testing.T) { - assert.Equal(t, "TestAny", UpperCamelCase("testAny")) + assert.Equal(t, "TestAny", upperCamelCase("testAny")) } diff --git a/visitor/collector.go b/visitor/collector.go index 518d481..409bb15 100644 --- a/visitor/collector.go +++ b/visitor/collector.go @@ -10,7 +10,7 @@ type packageCollector struct { unprocessedTypes map[string]map[string]Type - importTypes map[string]*ImportedType + importTypes map[string]Type } func newPackageCollector() *packageCollector { @@ -20,7 +20,7 @@ func newPackageCollector() *packageCollector { files: make(map[string]*Files), packages: make(map[string]*packages.Package), unprocessedTypes: make(map[string]map[string]Type), - importTypes: make(map[string]*ImportedType), + importTypes: make(map[string]Type), } } @@ -70,7 +70,7 @@ func (collector *packageCollector) addFile(pkgId string, file *File) { collector.files[pkgId].elements = append(collector.files[pkgId].elements, file) } -func (collector *packageCollector) findTypeByImportAndTypeName(importName, typeName string, file *File) *ImportedType { +func (collector *packageCollector) findTypeByImportAndTypeName(importName, typeName string, file *File) Type { if importedType, ok := collector.importTypes[importName+"#"+typeName]; ok { return importedType } @@ -88,19 +88,10 @@ func (collector *packageCollector) findTypeByImportAndTypeName(importName, typeN typ, exists := collector.findTypeByPkgIdAndName(packageImport.path, typeName) if exists { - importedType := &ImportedType{ - collector.packages[packageImport.path], - typ, - } - collector.importTypes[packageImport.path+"#"+typeName] = importedType + collector.importTypes[packageImport.path+"#"+typeName] = typ } - importedType := &ImportedType{ - pkg: collector.packages[packageImport.path], - typ: typ, - } - collector.importTypes[packageImport.path+"#"+typeName] = importedType - return importedType + return typ } func (collector *packageCollector) findTypeByPkgIdAndName(pkgId, typeName string) (Type, bool) { diff --git a/visitor/custom_type.go b/visitor/custom_type.go index 542ffa5..656f664 100644 --- a/visitor/custom_type.go +++ b/visitor/custom_type.go @@ -5,39 +5,64 @@ import ( "github.com/procyon-projects/marker" "github.com/procyon-projects/marker/packages" "go/ast" + "strings" + "sync" ) type CustomType struct { - name string - aliasType Type - isExported bool - position Position - markers markers.MarkerValues - methods []*Function - file *File + name string + underlyingType Type + isExported bool + position Position + markers markers.Values + typeParams *TypeParameters + methods []*Function + file *File + + pkg *packages.Package + specType *ast.TypeSpec isProcessed bool visitor *packageVisitor + + typeParamsOnce sync.Once } -func newCustomType(specType *ast.TypeSpec, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.MarkerValues) *CustomType { +func newCustomType(specType *ast.TypeSpec, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.Values) *CustomType { customType := &CustomType{ - name: specType.Name.Name, - aliasType: getTypeFromExpression(specType.Type, file, visitor), - isExported: ast.IsExported(specType.Name.Name), - position: getPosition(file.Package(), specType.Pos()), - markers: markers, - methods: make([]*Function, 0), - file: file, + markers: markers, + methods: make([]*Function, 0), + typeParams: &TypeParameters{ + []*TypeParameter{}, + }, isProcessed: true, + file: file, visitor: visitor, + pkg: pkg, + specType: specType, } - return customType.initialize() + return customType.initialize(specType, file, pkg) } -func (c *CustomType) initialize() *CustomType { - c.file.customTypes.elements = append(c.file.customTypes.elements, c) +func (c *CustomType) initialize(specType *ast.TypeSpec, file *File, pkg *packages.Package) *CustomType { + c.isProcessed = true + c.specType = specType + c.file = file + c.pkg = pkg + + if specType != nil { + c.name = specType.Name.Name + c.isExported = ast.IsExported(specType.Name.Name) + c.position = getPosition(file.Package(), specType.Pos()) + + c.loadTypeParams() + c.underlyingType = getTypeFromExpression(specType.Type, file, c.visitor, c, c.typeParams) + if _, exists := file.customTypes.FindByName(c.name); !exists { + c.file.customTypes.elements = append(c.file.customTypes.elements, c) + } + } + return c } @@ -49,16 +74,107 @@ func (c *CustomType) IsExported() bool { return c.isExported } -func (c *CustomType) AliasType() Type { - return c.aliasType +func (c *CustomType) Underlying() Type { + return c.underlyingType } -func (c *CustomType) Underlying() Type { - return c +func (c *CustomType) TypeParameters() *TypeParameters { + c.loadTypeParams() + return c.typeParams +} + +func (c *CustomType) NumMethods() int { + return len(c.methods) +} + +func (c *CustomType) Methods() *Functions { + return &Functions{ + elements: c.methods, + } +} + +func (c *CustomType) Markers() markers.Values { + return c.markers +} + +func (c *CustomType) SpecType() *ast.TypeSpec { + return c.specType } func (c *CustomType) String() string { - return fmt.Sprintf("type %s %s", c.name, c.aliasType.Name()) + var builder strings.Builder + + if c.file != nil && c.file.pkg.Name != "builtin" { + builder.WriteString(fmt.Sprintf("%s.%s", c.file.Package().Name, c.name)) + } + + if c.TypeParameters().Len() != 0 { + builder.WriteString("[") + + for index := 0; index < c.TypeParameters().Len(); index++ { + typeParam := c.TypeParameters().At(index) + builder.WriteString(typeParam.String()) + + if index != c.TypeParameters().Len()-1 { + builder.WriteString(",") + } + + builder.WriteString("]") + } + + } + + return builder.String() +} + +func (c *CustomType) loadTypeParams() { + c.typeParamsOnce.Do(func() { + if c.specType == nil || c.specType.TypeParams == nil { + return + } + + for _, field := range c.specType.TypeParams.List { + for _, fieldName := range field.Names { + typeParameter := &TypeParameter{ + name: fieldName.Name, + constraints: &TypeConstraints{ + []*TypeConstraint{}, + }, + } + c.typeParams.elements = append(c.typeParams.elements, typeParameter) + } + } + + for _, field := range c.specType.TypeParams.List { + constraints := make([]*TypeConstraint, 0) + typ := getTypeFromExpression(field.Type, c.file, c.visitor, nil, c.typeParams) + + if typeSets, isTypeSets := typ.(TypeSets); isTypeSets { + for _, item := range typeSets { + if constraint, isConstraint := item.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: item}) + } + } + } else { + if constraint, isConstraint := typ.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: typ}) + } + } + + for _, fieldName := range field.Names { + typeParam, exists := c.typeParams.FindByName(fieldName.Name) + + if exists { + typeParam.constraints.elements = append(typeParam.constraints.elements, constraints...) + } + } + } + + }) } type CustomTypes struct { diff --git a/visitor/file.go b/visitor/file.go index e1ff8f9..8df8fb2 100644 --- a/visitor/file.go +++ b/visitor/file.go @@ -12,11 +12,11 @@ type File struct { path string pkg *packages.Package - allMarkers markers.MarkerValues - fileMarkers markers.MarkerValues + allMarkers markers.Values + fileMarkers markers.Values imports *Imports - importMarkers []markers.ImportMarker + importMarkers []markers.Import functions *Functions structs *Structs @@ -29,7 +29,7 @@ type File struct { visitor *packageVisitor } -func newFile(rawFile *ast.File, pkg *packages.Package, markerValues markers.MarkerValues, visitor *packageVisitor) *File { +func newFile(rawFile *ast.File, pkg *packages.Package, markerValues markers.Values, visitor *packageVisitor) *File { position := pkg.Fset.Position(rawFile.Pos()) path := position.Filename @@ -38,9 +38,9 @@ func newFile(rawFile *ast.File, pkg *packages.Package, markerValues markers.Mark path: path, pkg: pkg, allMarkers: markerValues, - fileMarkers: make(markers.MarkerValues, 0), + fileMarkers: make(markers.Values, 0), imports: &Imports{}, - importMarkers: make([]markers.ImportMarker, 0), + importMarkers: make([]markers.Import, 0), functions: &Functions{}, structs: &Structs{}, interfaces: &Interfaces{}, @@ -57,7 +57,7 @@ func (f *File) initialize() *File { for markerName, markerValues := range f.allMarkers { if markers.ImportMarkerName == markerName { for _, importMarker := range markerValues { - f.importMarkers = append(f.importMarkers, importMarker.(markers.ImportMarker)) + f.importMarkers = append(f.importMarkers, importMarker.(markers.Import)) } } else { f.fileMarkers[markerName] = append(f.fileMarkers[markerName], markerValues...) @@ -82,6 +82,7 @@ func (f *File) initialize() *File { name: importName, path: importPackage.Path.Value[1 : len(importPackage.Path.Value)-1], sideEffect: sideEffect, + file: f, position: Position{ importPosition.Line, importPosition.Column, @@ -100,7 +101,7 @@ func (f *File) Path() string { return f.path } -func (f *File) Markers() markers.MarkerValues { +func (f *File) Markers() markers.Values { return f.fileMarkers } @@ -112,7 +113,7 @@ func (f *File) Imports() *Imports { return f.imports } -func (f *File) ImportMarkers() []markers.ImportMarker { +func (f *File) ImportMarkers() []markers.Import { return f.importMarkers } diff --git a/visitor/file_test.go b/visitor/file_test.go index 367ee5b..8794610 100644 --- a/visitor/file_test.go +++ b/visitor/file_test.go @@ -18,6 +18,7 @@ type importInfo struct { name string path string sideEffect bool + file string position Position } @@ -68,6 +69,10 @@ func assertImports(t *testing.T, file *File, expectedImports []importInfo) bool t.Errorf("import path in file %s shoud be %s, but got %s", file.name, expectedImport.path, actualImport.Path()) } + if expectedImport.file != actualImport.File().Name() { + t.Errorf("file name for import %s shoud be %s, but got %s", actualImport.Path(), expectedImport.file, actualImport.File().Name()) + } + if actualImport.SideEffect() && !expectedImport.sideEffect { t.Errorf("import with path %s in file %s is not an import side effect, but should be an import side effect", expectedImport.path, file.name) } else if !actualImport.SideEffect() && expectedImport.sideEffect { diff --git a/visitor/function.go b/visitor/function.go index 69fa281..286428f 100644 --- a/visitor/function.go +++ b/visitor/function.go @@ -6,53 +6,137 @@ import ( "github.com/procyon-projects/marker/packages" "go/ast" "strings" + "sync" ) -type Variable struct { +type Parameter struct { name string typ Type } -func (v *Variable) Name() string { - return v.name +func (p *Parameter) Name() string { + return p.name } -func (v *Variable) Type() Type { - return v.typ +func (p *Parameter) Type() Type { + return p.typ } -func (v *Variable) String() string { - if v.name == "" { - return v.typ.Name() +func (p *Parameter) String() string { + if p.name == "" { + return p.typ.Name() } - return fmt.Sprintf("%s %s", v.name, v.typ.Name()) + return fmt.Sprintf("%s %s", p.name, p.typ.Name()) } -type Variables []*Variable +type Parameters struct { + elements []*Parameter +} + +func (p *Parameters) Len() int { + return len(p.elements) +} + +func (p *Parameters) At(index int) *Parameter { + if index >= 0 && index < len(p.elements) { + return p.elements[index] + } + + return nil +} + +func (p *Parameters) FindByName(name string) (*Parameter, bool) { + for _, parameter := range p.elements { + if parameter.name == name { + return parameter, true + } + } + + return nil, false +} + +type Result struct { + name string + typ Type +} + +func (r *Result) Name() string { + return r.name +} + +func (r *Result) Type() Type { + return r.typ +} + +func (r *Result) String() string { + if r.name == "" { + return r.typ.Name() + } + + return fmt.Sprintf("%s %s", r.name, r.typ.Name()) +} -func (v Variables) Len() int { - return len(v) +type Results struct { + elements []*Result } -func (v Variables) At(index int) *Variable { - if index >= 0 && index < len(v) { - return v[index] +func (r *Results) Len() int { + return len(r.elements) +} + +func (r *Results) At(index int) *Result { + if index >= 0 && index < len(r.elements) { + return r.elements[index] } return nil } +func (r *Results) FindByName(name string) (*Result, bool) { + for _, result := range r.elements { + if result.name == name { + return result, true + } + } + + return nil, false +} + +type Receiver struct { + name string + typ Type +} + +func (r *Receiver) Name() string { + return r.name +} + +func (r *Receiver) Type() Type { + return r.typ +} + +func (r *Receiver) String() string { + if r.name == "" { + return r.typ.Name() + } + + return fmt.Sprintf("%s %s", r.name, r.typ.Name()) +} + type Function struct { - name string - isExported bool - markers markers.MarkerValues - position Position - receiver *Variable - typeParams *TypeParams - params Variables - results Variables - variadic bool + name string + isExported bool + markers markers.Values + position Position + receiver *Receiver + typeParams *TypeParameters + receiverTypeParams *TypeParameters + typeParamAliases []string + params *Parameters + results *Results + variadic bool + ownerType Type file *File @@ -63,22 +147,25 @@ type Function struct { pkg *packages.Package visitor *packageVisitor - loadedTypeParams bool - loadedParams bool - loadedReturnValues bool + typeParamsOnce sync.Once + paramsOnce sync.Once + resultsOnce sync.Once } -func newFunction(funcDecl *ast.FuncDecl, funcField *ast.Field, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.MarkerValues) *Function { +func newFunction(funcDecl *ast.FuncDecl, funcType *ast.FuncType, funcField *ast.Field, ownerType Type, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.Values) *Function { function := &Function{ - file: file, - typeParams: &TypeParams{}, - params: Variables{}, - results: Variables{}, - markers: markers, - funcDecl: funcDecl, - funcField: funcField, - pkg: pkg, - visitor: visitor, + file: file, + typeParams: &TypeParameters{}, + receiverTypeParams: &TypeParameters{}, + params: &Parameters{}, + results: &Results{}, + markers: markers, + funcDecl: funcDecl, + funcField: funcField, + funcType: funcType, + pkg: pkg, + visitor: visitor, + ownerType: ownerType, } if funcDecl != nil { @@ -86,6 +173,8 @@ func newFunction(funcDecl *ast.FuncDecl, funcField *ast.Field, file *File, pkg * function.isExported = ast.IsExported(funcDecl.Name.Name) function.position = getPosition(file.pkg, funcDecl.Pos()) function.funcType = funcDecl.Type + } else if funcType != nil { + function.position = getPosition(file.pkg, function.funcType.Pos()) } else { if funcField.Names != nil { function.name = funcField.Names[0].Name @@ -103,7 +192,7 @@ func (f *Function) initialize() *Function { if f.funcDecl.Recv == nil { f.file.functions.elements = append(f.file.functions.elements, f) } else { - f.receiver = &Variable{} + f.receiver = &Receiver{} if f.funcDecl.Recv.List[0].Names != nil { f.receiver.name = f.funcDecl.Recv.List[0].Names[0].Name @@ -120,7 +209,6 @@ func (f *Function) receiverType(receiverExpr ast.Expr) Type { var receiverTypeSpec *ast.TypeSpec receiverTypeName := "" - isPointerReceiver := false isStructMethod := false switch typedReceiver := receiverExpr.(type) { @@ -134,17 +222,18 @@ func (f *Function) receiverType(receiverExpr ast.Expr) Type { receiverTypeName = receiverTypeSpec.Name.Name _, isStructMethod = receiverTypeSpec.Type.(*ast.StructType) } + case *ast.IndexExpr: + f.typeParamAliases = append(f.typeParamAliases, typedReceiver.Index.(*ast.Ident).Name) + return f.receiverType(typedReceiver.X) + case *ast.IndexListExpr: + for _, typeParamAlias := range typedReceiver.Indices { + f.typeParamAliases = append(f.typeParamAliases, typeParamAlias.(*ast.Ident).Name) + } + return f.receiverType(typedReceiver.X) case *ast.StarExpr: - if typedReceiver.X.(*ast.Ident).Obj == nil { - receiverTypeName = typedReceiver.X.(*ast.Ident).Name - unprocessedype := getTypeFromScope(receiverTypeName, f.visitor) - _, isStructMethod = unprocessedype.(*Struct) - } else { - receiverTypeSpec = typedReceiver.X.(*ast.Ident).Obj.Decl.(*ast.TypeSpec) - receiverTypeName = receiverTypeSpec.Name.Name - _, isStructMethod = receiverTypeSpec.Type.(*ast.StructType) + return &Pointer{ + base: f.receiverType(typedReceiver.X), } - isPointerReceiver = true } candidateType, ok := f.visitor.collector.findTypeByPkgIdAndName(f.file.pkg.ID, receiverTypeName) @@ -155,99 +244,90 @@ func (f *Function) receiverType(receiverExpr ast.Expr) Type { } structType := candidateType.(*Struct) + f.ownerType = structType structType.methods = append(structType.methods, f) + f.file.functions.elements = append(f.file.functions.elements, f) } else { if !ok { candidateType = newCustomType(receiverTypeSpec, f.file, f.pkg, f.visitor, nil) } customType := candidateType.(*CustomType) + f.ownerType = customType customType.methods = append(customType.methods, f) } - if isPointerReceiver { - return &Pointer{ - base: candidateType, - } - } - return candidateType } -func (f *Function) getTypeParams(fieldList []*ast.Field) *TypeParams { - typeParams := &TypeParams{ - params: make([]*TypeParam, 0), - } - - for _, field := range fieldList { - - typ := getTypeFromExpression(field.Type, f.file, f.visitor) +func (f *Function) Name() string { + return f.name +} - if field.Names == nil { - typeParams.params = append(typeParams.params, &TypeParam{ - typ: typ, - }) - } +func (f *Function) File() *File { + return f.file +} - for _, fieldName := range field.Names { - typeParams.params = append(typeParams.params, &TypeParam{ - name: fieldName.Name, - typ: typ, - }) - } +func (f *Function) Position() Position { + return f.position +} - } +func (f *Function) Underlying() Type { + return f +} - return typeParams +func (f *Function) Receiver() *Receiver { + return f.receiver } -func (f *Function) getTypeParameterByName(name string) *TypeParam { +func (f *Function) TypeParameters() *TypeParameters { f.loadTypeParams() - /*for _, typeParam := range f.typeParams.variables { - if typeParam.name == name { - return typeParam + if f.ownerType != nil { + _, isStruct := f.ownerType.(*Struct) + _, isCustomType := f.ownerType.(*CustomType) + if !isStruct && !isCustomType { + return &TypeParameters{} } - }*/ - return nil + return f.receiverTypeParams + } + + return f.typeParams } -func (f *Function) getGenericTypeFromExpression(exp ast.Expr) Type { - var typeParam *TypeParam +func (f *Function) Parameters() *Parameters { + f.loadParams() + return f.params +} - switch t := exp.(type) { - case *ast.Ident: - typeParam = f.getTypeParameterByName(t.Name) - case *ast.SelectorExpr: - } +func (f *Function) Results() *Results { + f.loadResultValues() + return f.results +} - if typeParam == nil { - return nil - } +func (f *Function) IsVariadic() bool { + f.loadParams() + return f.variadic +} - return &Generic{ - typeParam, - } +func (f *Function) Markers() markers.Values { + return f.markers } -func (f *Function) getVariables(fieldList []*ast.Field) Variables { - variables := Variables{} +func (f *Function) getResults(fieldList []*ast.Field) []*Result { + variables := make([]*Result, 0) for _, field := range fieldList { - typ := f.getGenericTypeFromExpression(field.Type) - - if typ == nil { - typ = getTypeFromExpression(field.Type, f.file, f.visitor) - } + typ := getTypeFromExpression(field.Type, f.file, f.visitor, nil, f.typeParams) if field.Names == nil { - variables = append(variables, &Variable{ + variables = append(variables, &Result{ typ: typ, }) } for _, fieldName := range field.Names { - variables = append(variables, &Variable{ + variables = append(variables, &Result{ name: fieldName.Name, typ: typ, }) @@ -258,61 +338,28 @@ func (f *Function) getVariables(fieldList []*ast.Field) Variables { return variables } -func (f *Function) loadTypeParams() { - - if f.loadedTypeParams { - return - } - - if f.funcType.TypeParams != nil { - f.typeParams.params = append(f.typeParams.params, f.getTypeParams(f.funcType.TypeParams.List).params...) - } - - f.loadedTypeParams = true -} - -func (f *Function) loadParams() { - if f.loadedParams { - return - } - - if f.funcType.Params != nil { - f.params = append(f.params, f.getVariables(f.funcType.Params.List)...) - } +func (f *Function) getParameters(fieldList []*ast.Field) []*Parameter { + variables := make([]*Parameter, 0) - if f.params.Len() != 0 { - _, f.variadic = f.params.At(f.params.Len() - 1).Type().(*Variadic) - } + for _, field := range fieldList { + typ := getTypeFromExpression(field.Type, f.file, f.visitor, nil, f.typeParams) - f.loadedParams = true -} + if field.Names == nil { + variables = append(variables, &Parameter{ + typ: typ, + }) + } -func (f *Function) loadResultValues() { - if f.loadedReturnValues { - return - } + for _, fieldName := range field.Names { + variables = append(variables, &Parameter{ + name: fieldName.Name, + typ: typ, + }) + } - if f.funcType.Results != nil { - f.results = append(f.results, f.getVariables(f.funcType.Results.List)...) } - f.loadedReturnValues = true -} - -func (f *Function) Name() string { - return f.name -} - -func (f *Function) File() *File { - return f.file -} - -func (f *Function) Position() Position { - return f.position -} - -func (f *Function) Underlying() Type { - return f + return variables } func (f *Function) String() string { @@ -324,27 +371,67 @@ func (f *Function) String() string { if f.receiver != nil { builder.WriteString("(") - builder.WriteString(f.receiver.Name()) - builder.WriteString(" ") + + if f.receiver.Name() != "" { + builder.WriteString(f.receiver.Name()) + builder.WriteString(" ") + } + builder.WriteString(f.receiver.Type().String()) - builder.WriteString(") ") + + if f.TypeParameters().Len() != 0 { + builder.WriteString("[") + for i := 0; i < f.TypeParameters().Len(); i++ { + typeParam := f.TypeParameters().At(i) + builder.WriteString(typeParam.Name()) + + if i != f.TypeParameters().Len()-1 { + builder.WriteString(",") + } + } + builder.WriteString("]") + } + + if f.name != "" { + builder.WriteString(") ") + } else { + builder.WriteString(")") + } } builder.WriteString(f.name) + + if f.ownerType == nil && f.TypeParameters().Len() != 0 { + builder.WriteString("[") + for i := 0; i < f.TypeParameters().Len(); i++ { + typeParam := f.TypeParameters().At(i) + builder.WriteString(typeParam.Name()) + + if i != f.TypeParameters().Len()-1 { + builder.WriteString(",") + } + } + builder.WriteString("]") + } + builder.WriteString("(") - if f.Params().Len() != 0 { - for i := 0; i < f.Params().Len(); i++ { - param := f.Params().At(i) + if f.Parameters().Len() != 0 { + for i := 0; i < f.Parameters().Len(); i++ { + param := f.Parameters().At(i) builder.WriteString(param.String()) - if i != f.Params().Len()-1 { + if i != f.Parameters().Len()-1 { builder.WriteString(",") } } } - builder.WriteString(") ") + if f.Results().Len() != 0 { + builder.WriteString(")") + } else { + builder.WriteString(") ") + } if f.Results().Len() > 1 { builder.WriteString("(") @@ -368,32 +455,127 @@ func (f *Function) String() string { return builder.String() } -func (f *Function) Receiver() *Variable { - return f.receiver -} +func (f *Function) loadTypeParams() { + f.typeParamsOnce.Do(func() { + if f.ownerType != nil { + + switch typedOwner := f.ownerType.(type) { + case *CustomType: + f.typeParams.elements = append(f.typeParams.elements, typedOwner.TypeParameters().elements...) + + for index, typeParamAlias := range f.typeParamAliases { + if typeParameter, exists := f.typeParams.FindByName(typeParamAlias); exists { + f.receiverTypeParams.elements = append(f.receiverTypeParams.elements, typeParameter) + continue + } + + typeParam := f.typeParams.At(index) + + if typeParam != nil && typeParam.Name() != typeParamAlias { + typeParameter := &TypeParameter{ + typeParamAlias, + typeParam.TypeConstraints(), + } + f.typeParams.elements = append(f.typeParams.elements, typeParameter) + f.receiverTypeParams.elements = append(f.receiverTypeParams.elements, typeParameter) + } + } + case *Interface: + f.typeParams.elements = append(f.typeParams.elements, typedOwner.TypeParameters().elements...) + case *Struct: + f.typeParams.elements = append(f.typeParams.elements, typedOwner.TypeParameters().elements...) + + for index, typeParamAlias := range f.typeParamAliases { + if typeParameter, exists := f.typeParams.FindByName(typeParamAlias); exists { + f.receiverTypeParams.elements = append(f.receiverTypeParams.elements, typeParameter) + continue + } + + typeParam := f.typeParams.At(index) + + if typeParam != nil && typeParam.Name() != typeParamAlias { + typeParameter := &TypeParameter{ + typeParamAlias, + typeParam.TypeConstraints(), + } + f.typeParams.elements = append(f.typeParams.elements, typeParameter) + f.receiverTypeParams.elements = append(f.receiverTypeParams.elements, typeParameter) + } + } + } -func (f *Function) TypeParams() *TypeParams { - f.loadTypeParams() - return f.typeParams -} + } -func (f *Function) Params() Variables { - f.loadParams() - return f.params -} + if f.funcType.TypeParams == nil { + return + } -func (f *Function) Results() Variables { - f.loadResultValues() - return f.results + for _, field := range f.funcType.TypeParams.List { + for _, fieldName := range field.Names { + typeParameter := &TypeParameter{ + name: fieldName.Name, + constraints: &TypeConstraints{ + []*TypeConstraint{}, + }, + } + f.typeParams.elements = append(f.typeParams.elements, typeParameter) + } + } + + for _, field := range f.funcType.TypeParams.List { + constraints := make([]*TypeConstraint, 0) + typ := getTypeFromExpression(field.Type, f.file, f.visitor, nil, f.typeParams) + + if typeSets, isTypeSets := typ.(TypeSets); isTypeSets { + for _, item := range typeSets { + if constraint, isConstraint := item.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: item}) + } + } + } else { + if constraint, isConstraint := typ.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: typ}) + } + } + + for _, fieldName := range field.Names { + typeParam, exists := f.typeParams.FindByName(fieldName.Name) + + if exists { + typeParam.constraints.elements = append(typeParam.constraints.elements, constraints...) + } + } + } + }) } -func (f *Function) IsVariadic() bool { - f.loadParams() - return f.variadic +func (f *Function) loadParams() { + f.paramsOnce.Do(func() { + f.loadTypeParams() + + if f.funcType.Params != nil { + f.params.elements = append(f.params.elements, f.getParameters(f.funcType.Params.List)...) + } + + if f.params.Len() != 0 { + _, f.variadic = f.params.At(f.params.Len() - 1).Type().(*Variadic) + } + + }) } -func (f *Function) Markers() markers.MarkerValues { - return f.markers +func (f *Function) loadResultValues() { + f.resultsOnce.Do(func() { + f.loadTypeParams() + + if f.funcType.Results != nil { + f.results.elements = append(f.results.elements, f.getResults(f.funcType.Results.List)...) + } + }) } type Functions struct { diff --git a/visitor/function_test.go b/visitor/function_test.go index d2087be..e30b1fb 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -15,7 +15,7 @@ type receiverInfo struct { } type functionInfo struct { - markers markers.MarkerValues + markers markers.Values isVariadic bool name string fileName string @@ -50,6 +50,9 @@ func (f functionInfo) String() string { builder.WriteString(param.name + " ") } + if param.isPointer { + builder.WriteString("*") + } builder.WriteString(param.typeName) if i != len(f.params)-1 { @@ -71,6 +74,9 @@ func (f functionInfo) String() string { builder.WriteString(result.name + " ") } + if result.isPointer { + builder.WriteString("*") + } builder.WriteString(result.typeName) if i != len(f.results)-1 { @@ -89,7 +95,7 @@ func (f functionInfo) String() string { // functions var ( breadFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "Bread", @@ -122,7 +128,7 @@ var ( } macaronFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "Macaron", @@ -155,7 +161,7 @@ var ( } makeACakeFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:function-level": { FunctionLevel{ Name: "MakeACake", @@ -184,7 +190,7 @@ var ( } biscuitCakeFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:function-level": { FunctionLevel{ Name: "BiscuitCake", @@ -225,7 +231,7 @@ var ( } funfettiFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "Funfetti", @@ -254,7 +260,7 @@ var ( } iceCreamFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "IceCream", @@ -287,7 +293,7 @@ var ( } cupCakeFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "CupCake", @@ -320,7 +326,7 @@ var ( } tartFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "Tart", @@ -344,7 +350,7 @@ var ( } donutFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "Donut", @@ -368,7 +374,7 @@ var ( } puddingFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "Pudding", @@ -392,7 +398,7 @@ var ( } pieFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "Pie", @@ -416,7 +422,7 @@ var ( } muffinFunction = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-method-level": { InterfaceMethodLevel{ Name: "muffin", @@ -433,8 +439,9 @@ var ( params: []variableInfo{}, results: []variableInfo{ { - name: "", - typeName: "string", + name: "", + typeName: "string", + isPointer: true, }, { name: "", @@ -444,7 +451,7 @@ var ( } eatMethod = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:struct-method-level": { StructMethodLevel{ Name: "Eat", @@ -473,7 +480,7 @@ var ( } buyMethod = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:struct-method-level": { StructMethodLevel{ Name: "Buy", @@ -502,7 +509,7 @@ var ( } fortuneCookieMethod = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:struct-method-level": { StructMethodLevel{ Name: "FortuneCookie", @@ -536,7 +543,7 @@ var ( } oreoMethod = functionInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:struct-method-level": { StructMethodLevel{ Name: "Oreo", @@ -574,7 +581,7 @@ var ( } genericFunction = functionInfo{ - markers: markers.MarkerValues{}, + markers: markers.Values{}, name: "GenericFunction", fileName: "generics.go", position: Position{ diff --git a/visitor/generic_type.go b/visitor/generic_type.go new file mode 100644 index 0000000..23e0be9 --- /dev/null +++ b/visitor/generic_type.go @@ -0,0 +1,46 @@ +package visitor + +import "strings" + +type GenericType struct { + rawType Type + arguments []Type +} + +func (g *GenericType) Name() string { + return g.rawType.Name() +} + +func (g *GenericType) ActualTypeArguments() TypeSets { + return g.arguments +} + +func (g *GenericType) RawType() Type { + return g.rawType +} + +func (g *GenericType) Underlying() Type { + return g +} + +func (g *GenericType) String() string { + var builder strings.Builder + builder.WriteString(g.rawType.Name()) + + if g.ActualTypeArguments().Len() != 0 { + builder.WriteString("[") + + for index := 0; index < g.ActualTypeArguments().Len(); index++ { + typeParam := g.ActualTypeArguments().At(index) + builder.WriteString(typeParam.String()) + + if index != g.ActualTypeArguments().Len()-1 { + builder.WriteString(",") + } + } + + builder.WriteString("]") + } + + return builder.String() +} diff --git a/visitor/interface.go b/visitor/interface.go index ac19a64..da987ff 100644 --- a/visitor/interface.go +++ b/visitor/interface.go @@ -1,40 +1,28 @@ package visitor import ( + "fmt" "github.com/procyon-projects/marker" "github.com/procyon-projects/marker/packages" "go/ast" "go/token" "go/types" "strings" + "sync" ) -type Constraint struct { -} - -func (c *Constraint) Name() string { - return "" -} - -func (c *Constraint) Underlying() Type { - return c -} - -func (c *Constraint) String() string { - return "" -} - type Interface struct { - name string - isExported bool - isAnonymous bool - position Position - markers markers.MarkerValues - embeddeds []Type - constrains []*Constraint - allMethods []*Function - methods []*Function - file *File + name string + isExported bool + isAnonymous bool + position Position + markers markers.Values + embeddedInterfaces []*Interface + embeddedTypes []Type + typeParams *TypeParameters + allMethods []*Function + methods []*Function + file *File isProcessed bool @@ -45,30 +33,37 @@ type Interface struct { pkg *packages.Package visitor *packageVisitor - constraintsLoaded bool - embeddedTypesLoaded bool - methodsLoaded bool - allMethodsLoaded bool + typeParamsOnce sync.Once + embeddedInterfacesOnce sync.Once + embeddedTypesOnce sync.Once + methodsOnce sync.Once + allMethodsOnce sync.Once } -func newInterface(specType *ast.TypeSpec, interfaceType *ast.InterfaceType, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.MarkerValues) *Interface { +func newInterface(specType *ast.TypeSpec, interfaceType *ast.InterfaceType, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.Values) *Interface { i := &Interface{ - methods: make([]*Function, 0), - allMethods: make([]*Function, 0), - embeddeds: make([]Type, 0), - constrains: make([]*Constraint, 0), - markers: markers, - file: file, - isProcessed: true, - specType: specType, - pkg: pkg, - visitor: visitor, + methods: make([]*Function, 0), + allMethods: make([]*Function, 0), + embeddedTypes: make([]Type, 0), + embeddedInterfaces: make([]*Interface, 0), + typeParams: &TypeParameters{}, + markers: markers, + file: file, + isProcessed: true, + specType: specType, + pkg: pkg, + visitor: visitor, } - return i.initialize(specType, interfaceType, pkg) + return i.initialize(specType, interfaceType, file, pkg) } -func (i *Interface) initialize(specType *ast.TypeSpec, interfaceType *ast.InterfaceType, pkg *packages.Package) *Interface { +func (i *Interface) initialize(specType *ast.TypeSpec, interfaceType *ast.InterfaceType, file *File, pkg *packages.Package) *Interface { + i.isProcessed = true + i.specType = specType + i.file = file + i.pkg = pkg + if specType != nil { i.name = specType.Name.Name i.isExported = ast.IsExported(specType.Name.Name) @@ -83,7 +78,9 @@ func (i *Interface) initialize(specType *ast.TypeSpec, interfaceType *ast.Interf default: } - i.file.interfaces.elements = append(i.file.interfaces.elements, i) + if _, exists := file.interfaces.FindByName(i.name); !exists { + i.file.interfaces.elements = append(i.file.interfaces.elements, i) + } } else if interfaceType != nil { if interfaceType.Pos() != token.NoPos { //i.position = getPosition(pkg, interfaceType.Pos()) @@ -94,73 +91,12 @@ func (i *Interface) initialize(specType *ast.TypeSpec, interfaceType *ast.Interf return i } -func (i *Interface) getInterfaceMethods() []*Function { - methods := make([]*Function, 0) - - markers := i.visitor.allPackageMarkers[i.pkg.ID] - - for _, rawMethod := range i.fieldList { - _, ok := rawMethod.Type.(*ast.FuncType) - - if ok { - methods = append(methods, newFunction(nil, rawMethod, i.file, i.pkg, i.visitor, markers[rawMethod])) - } - } - - return methods -} - -func (i *Interface) getInterfaceEmbeddedTypes() []Type { - embeddedTypes := make([]Type, 0) - - for _, field := range i.fieldList { - _, ok := field.Type.(*ast.FuncType) - - if !ok { - embeddedTypes = append(embeddedTypes, getTypeFromExpression(field.Type, i.file, i.visitor)) - } - } - - return embeddedTypes -} - -func (i *Interface) loadEmbeddedTypes() { - if i.embeddedTypesLoaded { - return - } - - i.embeddeds = i.getInterfaceEmbeddedTypes() - i.embeddedTypesLoaded = true -} - -func (i *Interface) loadMethods() { - if i.methodsLoaded { - return - } - - i.methods = i.getInterfaceMethods() - i.allMethods = append(i.allMethods, i.methods...) - i.methodsLoaded = true -} - -func (i *Interface) loadAllMethods() { - if i.allMethodsLoaded { - return - } - - i.loadMethods() - i.loadEmbeddedTypes() - - for _, embeddedType := range i.embeddeds { - interfaceType, ok := embeddedType.(*Interface) - - if ok { - interfaceType.loadAllMethods() - i.allMethods = append(i.allMethods, interfaceType.allMethods...) - } +func (i *Interface) Name() string { + if i.name == "" && len(i.fieldList) == 0 { + return "interface{}" } - i.allMethodsLoaded = true + return i.name } func (i *Interface) IsEmpty() bool { @@ -183,32 +119,11 @@ func (i *Interface) Underlying() Type { return i } -func (i *Interface) String() string { - var builder strings.Builder - return builder.String() -} - -func (i *Interface) IsConstraint() bool { - return false -} - -func (i *Interface) Constraints() []*Constraint { - return i.constrains -} - -func (i *Interface) Name() string { - if i.name == "" && len(i.fieldList) == 0 { - return "interface{}" - } - - return i.name -} - func (i *Interface) IsExported() bool { return i.isExported } -func (i *Interface) Markers() markers.MarkerValues { +func (i *Interface) Markers() markers.Values { return i.markers } @@ -224,15 +139,27 @@ func (i *Interface) ExplicitMethods() *Functions { } } +func (i *Interface) NumEmbeddedInterfaces() int { + i.loadEmbeddedInterfaces() + return len(i.methods) +} + +func (i *Interface) EmbeddedInterfaces() *Interfaces { + i.loadEmbeddedInterfaces() + return &Interfaces{ + elements: i.embeddedInterfaces, + } +} + func (i *Interface) NumEmbeddedTypes() int { i.loadEmbeddedTypes() - return len(i.embeddeds) + return len(i.embeddedTypes) } func (i *Interface) EmbeddedTypes() *Types { i.loadEmbeddedTypes() return &Types{ - i.embeddeds, + i.embeddedTypes, } } @@ -248,10 +175,187 @@ func (i *Interface) Methods() *Functions { } } +func (i *Interface) IsConstraint() bool { + i.loadEmbeddedInterfaces() + + // diff is greater than 1 means that there are many non-interface types defined in interface + // and this constraint can never be satisfied + if len(i.embeddedTypes)-len(i.embeddedInterfaces) > 1 { + return false + } + + hasConstraintTypes := len(i.embeddedTypes)-len(i.embeddedInterfaces) == 1 + + for _, embeddedInterface := range i.embeddedInterfaces { + if embeddedInterface.IsConstraint() { + if hasConstraintTypes { + return false + } else { + hasConstraintTypes = true + } + } + } + + return hasConstraintTypes +} + +func (i *Interface) TypeParameters() *TypeParameters { + i.loadTypeParams() + return i.typeParams +} + +func (i *Interface) String() string { + if i.name == "" && len(i.fieldList) == 0 { + return "interface{}" + } + + var builder strings.Builder + if i.file != nil && i.file.pkg.Name != "builtin" { + builder.WriteString(fmt.Sprintf("%s.%s", i.file.Package().Name, i.name)) + } + + for index := 0; index < i.TypeParameters().Len(); index++ { + typeParam := i.TypeParameters().At(index) + builder.WriteString(typeParam.String()) + if index != i.TypeParameters().Len()-1 { + builder.WriteString(",") + } + builder.WriteString("]") + } + + return builder.String() +} + func (i *Interface) InterfaceType() *types.Interface { return i.interfaceType } +func (i *Interface) getInterfaceMethods() []*Function { + methods := make([]*Function, 0) + + markers := i.visitor.allPackageMarkers[i.pkg.ID] + + for _, rawMethod := range i.fieldList { + _, ok := rawMethod.Type.(*ast.FuncType) + + if ok { + methods = append(methods, newFunction(nil, nil, rawMethod, i, i.file, i.pkg, i.visitor, markers[rawMethod])) + } + } + + return methods +} + +func (i *Interface) getEmbeddedTypes() []Type { + embeddedTypes := make([]Type, 0) + + for _, field := range i.fieldList { + _, ok := field.Type.(*ast.FuncType) + + if !ok { + embeddedTypes = append(embeddedTypes, getTypeFromExpression(field.Type, i.file, i.visitor, nil, i.typeParams)) + } + } + + return embeddedTypes +} + +func (i *Interface) getEmbeddedInterfaces() []*Interface { + embeddedInterfaces := make([]*Interface, 0) + + for _, embeddedType := range i.embeddedTypes { + if iface, isInterface := embeddedType.(*Interface); isInterface { + embeddedInterfaces = append(embeddedInterfaces, iface) + } + } + + return embeddedInterfaces +} + +func (i *Interface) loadEmbeddedTypes() { + i.embeddedTypesOnce.Do(func() { + i.loadTypeParams() + i.embeddedTypes = i.getEmbeddedTypes() + }) +} + +func (i *Interface) loadEmbeddedInterfaces() { + i.embeddedInterfacesOnce.Do(func() { + i.loadEmbeddedTypes() + i.embeddedInterfaces = i.getEmbeddedInterfaces() + }) +} + +func (i *Interface) loadMethods() { + i.methodsOnce.Do(func() { + i.loadTypeParams() + i.methods = i.getInterfaceMethods() + i.allMethods = append(i.allMethods, i.methods...) + }) +} + +func (i *Interface) loadTypeParams() { + i.typeParamsOnce.Do(func() { + if i.specType == nil || i.specType.TypeParams == nil { + return + } + + for _, field := range i.specType.TypeParams.List { + for _, fieldName := range field.Names { + typeParameter := &TypeParameter{ + name: fieldName.Name, + constraints: &TypeConstraints{ + []*TypeConstraint{}, + }, + } + i.typeParams.elements = append(i.typeParams.elements, typeParameter) + } + } + + for _, field := range i.specType.TypeParams.List { + constraints := make([]*TypeConstraint, 0) + typ := getTypeFromExpression(field.Type, i.file, i.visitor, nil, i.typeParams) + + if typeSets, isTypeSets := typ.(TypeSets); isTypeSets { + for _, item := range typeSets { + if constraint, isConstraint := item.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: item}) + } + } + } else { + if constraint, isConstraint := typ.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: typ}) + } + } + + for _, fieldName := range field.Names { + typeParam, exists := i.typeParams.FindByName(fieldName.Name) + + if exists { + typeParam.constraints.elements = append(typeParam.constraints.elements, constraints...) + } + } + } + + }) +} + +func (i *Interface) loadAllMethods() { + i.allMethodsOnce.Do(func() { + i.loadMethods() + i.loadEmbeddedTypes() + + for _, embeddedInterface := range i.embeddedInterfaces { + embeddedInterface.loadAllMethods() + i.allMethods = append(i.allMethods, embeddedInterface.allMethods...) + } + }) +} + type Interfaces struct { elements []*Interface } diff --git a/visitor/interface_test.go b/visitor/interface_test.go index 52f8e21..f771dc3 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -8,7 +8,7 @@ import ( ) type interfaceInfo struct { - markers markers.MarkerValues + markers markers.Values name string fileName string position Position @@ -21,7 +21,7 @@ type interfaceInfo struct { // interfaces var ( bakeryShopInterface = interfaceInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-type-level": { InterfaceTypeLevel{ Name: "BakeryShop", @@ -52,7 +52,7 @@ var ( } dessertInterface = interfaceInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-type-level": { InterfaceTypeLevel{ Name: "Dessert", @@ -87,7 +87,7 @@ var ( } newYearsEveCookieInterface = interfaceInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-type-level": { InterfaceTypeLevel{ Name: "newYearsEveCookie", @@ -110,7 +110,7 @@ var ( } sweetShopInterface = interfaceInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:interface-type-level": { InterfaceTypeLevel{ Name: "SweetShop", diff --git a/visitor/pointer.go b/visitor/pointer.go new file mode 100644 index 0000000..7b1f6d2 --- /dev/null +++ b/visitor/pointer.go @@ -0,0 +1,25 @@ +package visitor + +import ( + "fmt" +) + +type Pointer struct { + base Type +} + +func (p *Pointer) Name() string { + return fmt.Sprintf("*%s", p.base.Name()) +} + +func (p *Pointer) Elem() Type { + return p.base +} + +func (p *Pointer) Underlying() Type { + return p +} + +func (p *Pointer) String() string { + return fmt.Sprintf("*%s", p.base.Name()) +} diff --git a/visitor/position.go b/visitor/position.go new file mode 100644 index 0000000..ddde43c --- /dev/null +++ b/visitor/position.go @@ -0,0 +1,19 @@ +package visitor + +import ( + "github.com/procyon-projects/marker/packages" + "go/token" +) + +type Position struct { + Line int + Column int +} + +func getPosition(pkg *packages.Package, tokenPosition token.Pos) Position { + position := pkg.Fset.Position(tokenPosition) + return Position{ + Line: position.Line, + Column: position.Column, + } +} diff --git a/visitor/struct.go b/visitor/struct.go index 4ec61e1..b9b137b 100644 --- a/visitor/struct.go +++ b/visitor/struct.go @@ -1,11 +1,14 @@ package visitor import ( + "fmt" "github.com/procyon-projects/marker" "github.com/procyon-projects/marker/packages" "go/ast" "go/token" "go/types" + "strings" + "sync" ) type Field struct { @@ -14,7 +17,7 @@ type Field struct { tags string typ Type position Position - markers markers.MarkerValues + markers markers.Values file *File isEmbedded bool } @@ -35,6 +38,10 @@ func (f *Field) IsEmbedded() bool { return f.isEmbedded } +func (f *Field) Markers() markers.Values { + return f.markers +} + func (f *Field) Tags() string { return f.tags } @@ -74,11 +81,12 @@ type Struct struct { isExported bool isAnonymous bool position Position - markers markers.MarkerValues + markers markers.Values fields []*Field allFields []*Field methods []*Function allMethods []*Function + typeParams *TypeParameters file *File isProcessed bool @@ -90,20 +98,21 @@ type Struct struct { pkg *packages.Package visitor *packageVisitor - methodsLoaded bool - allMethodsLoaded bool - - fieldsLoaded bool - allFieldsLoaded bool + typeParamsOnce sync.Once + methodsOnce sync.Once + allMethodsOnce sync.Once + fieldsOnce sync.Once + allFieldsOnce sync.Once } -func newStruct(specType *ast.TypeSpec, structType *ast.StructType, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.MarkerValues) *Struct { +func newStruct(specType *ast.TypeSpec, structType *ast.StructType, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.Values) *Struct { s := &Struct{ markers: markers, file: file, fields: make([]*Field, 0), allFields: make([]*Field, 0), methods: make([]*Function, 0), + typeParams: &TypeParameters{}, isProcessed: true, specType: specType, pkg: pkg, @@ -114,13 +123,20 @@ func newStruct(specType *ast.TypeSpec, structType *ast.StructType, file *File, p } func (s *Struct) initialize(specType *ast.TypeSpec, structType *ast.StructType, file *File, pkg *packages.Package) *Struct { + s.isProcessed = true + s.specType = specType + s.file = file + s.pkg = pkg + if specType != nil { s.name = specType.Name.Name s.isExported = ast.IsExported(specType.Name.Name) s.position = getPosition(pkg, specType.Pos()) s.namedType = file.pkg.Types.Scope().Lookup(specType.Name.Name).Type().(*types.Named) s.fieldList = s.specType.Type.(*ast.StructType).Fields.List - s.file.structs.elements = append(s.file.structs.elements, s) + if _, exists := file.structs.FindByName(s.name); !exists { + s.file.structs.elements = append(s.file.structs.elements, s) + } } else if structType != nil { if structType.Pos() != token.NoPos { //i.position = getPosition(pkg, interfaceType.Pos()) @@ -132,157 +148,6 @@ func (s *Struct) initialize(specType *ast.TypeSpec, structType *ast.StructType, return s } -func (s *Struct) getFieldsFromFieldList() []*Field { - fields := make([]*Field, 0) - - markers := s.visitor.allPackageMarkers[s.pkg.ID] - - for _, rawField := range s.fieldList { - tags := "" - - if rawField.Tag != nil { - tags = rawField.Tag.Value - } - - if rawField.Names == nil { - embeddedType := getTypeFromExpression(rawField.Type, s.file, s.visitor) - - field := &Field{ - name: embeddedType.Name(), - isExported: ast.IsExported(embeddedType.Name()), - position: Position{}, - markers: markers[rawField], - file: s.file, - tags: tags, - typ: embeddedType, - isEmbedded: true, - } - - fields = append(fields, field) - continue - } - - for _, fieldName := range rawField.Names { - typ := getTypeFromExpression(rawField.Type, s.file, s.visitor) - - field := &Field{ - name: fieldName.Name, - isExported: ast.IsExported(fieldName.Name), - position: getPosition(s.file.pkg, fieldName.Pos()), - markers: markers[rawField], - file: s.file, - tags: tags, - typ: typ, - isEmbedded: false, - } - - fields = append(fields, field) - } - - } - - return fields -} - -func (s *Struct) loadFields() { - if s.fieldsLoaded { - return - } - - s.fields = append(s.fields, s.getFieldsFromFieldList()...) - s.fieldsLoaded = true -} - -func (s *Struct) loadAllFields() { - if s.allFieldsLoaded { - return - } - - s.loadFields() - - for _, field := range s.fields { - - if !field.IsEmbedded() { - s.allFields = append(s.allFields, field) - continue - } - - var baseType = field.Type() - pointerType, ok := field.Type().(*Pointer) - - if ok { - baseType = pointerType.Elem() - } - - importedType, ok := baseType.(*ImportedType) - - if ok { - baseType = importedType.Underlying() - } - - structType, ok := baseType.(*Struct) - - if ok { - s.allFields = append(s.allFields, structType.FieldsInHierarchy().ToSlice()...) - } - - } - - s.allFieldsLoaded = true -} - -func (s *Struct) loadMethods() { - if s.methodsLoaded { - return - } - - s.allMethods = append(s.allMethods, s.methods...) - s.methodsLoaded = true -} - -func (s *Struct) loadAllMethods() { - if s.allMethodsLoaded { - return - } - - s.loadMethods() - s.loadFields() - - for _, field := range s.fields { - - if !field.IsEmbedded() { - continue - } - - var baseType = field.Type() - pointerType, ok := field.Type().(*Pointer) - - if ok { - baseType = pointerType.Elem() - } - - importedType, ok := baseType.(*ImportedType) - - if ok { - baseType = importedType.Underlying() - } - - structType, ok := baseType.(*Struct) - - if ok { - s.allMethods = append(s.allMethods, structType.MethodsInHierarchy().ToSlice()...) - } - - interfaceType, ok := baseType.(*Interface) - - if ok { - s.allMethods = append(s.allMethods, interfaceType.Methods().ToSlice()...) - } - } - - s.allMethodsLoaded = true -} - func (s *Struct) File() *File { return s.file } @@ -296,11 +161,29 @@ func (s *Struct) Underlying() Type { } func (s *Struct) String() string { - return "" + if s.name == "" && len(s.fieldList) == 0 { + return "struct{}" + } + + var builder strings.Builder + if s.file != nil && s.file.pkg.Name != "builtin" { + builder.WriteString(fmt.Sprintf("%s.%s", s.file.Package().Name, s.name)) + } + + for index := 0; index < s.TypeParameters().Len(); index++ { + typeParam := s.TypeParameters().At(index) + builder.WriteString(typeParam.String()) + if index != s.TypeParameters().Len()-1 { + builder.WriteString(",") + } + builder.WriteString("]") + } + + return builder.String() } func (s *Struct) Name() string { - if len(s.fieldList) == 0 { + if s.name == "" && len(s.fieldList) == 0 { return "struct{}" } @@ -319,7 +202,7 @@ func (s *Struct) IsAnonymous() bool { return s.isAnonymous } -func (s *Struct) Markers() markers.MarkerValues { +func (s *Struct) Markers() markers.Values { return s.markers } @@ -405,6 +288,11 @@ func (s *Struct) MethodsInHierarchy() *Functions { } } +func (s *Struct) TypeParameters() *TypeParameters { + s.loadTypeParams() + return s.typeParams +} + func (s *Struct) Implements(i *Interface) bool { if i == nil || i.interfaceType == nil || s.namedType == nil { return false @@ -423,6 +311,215 @@ func (s *Struct) Implements(i *Interface) bool { return false } +func (s *Struct) getFieldsFromFieldList() []*Field { + fields := make([]*Field, 0) + + markers := s.visitor.allPackageMarkers[s.pkg.ID] + + for _, rawField := range s.fieldList { + tags := "" + + if rawField.Tag != nil { + tags = rawField.Tag.Value + } + + if rawField.Names == nil { + embeddedType := getTypeFromExpression(rawField.Type, s.file, s.visitor, nil, nil) + typ := embeddedType + + pointerType, isPointerType := typ.(*Pointer) + if isPointerType { + typ = pointerType.Elem() + } + + genericType, isGenericType := typ.(*GenericType) + + if isGenericType { + typ = genericType.RawType() + } + + nameParts := strings.SplitN(typ.Name(), ".", 2) + name := nameParts[0] + + if len(nameParts) == 2 { + name = nameParts[1] + } + + field := &Field{ + name: name, + isExported: ast.IsExported(name), + // TODO set position + position: Position{}, + markers: markers[rawField], + file: s.file, + tags: tags, + typ: embeddedType, + isEmbedded: true, + } + + fields = append(fields, field) + continue + } + + for _, fieldName := range rawField.Names { + typ := getTypeFromExpression(rawField.Type, s.file, s.visitor, nil, nil) + + field := &Field{ + name: fieldName.Name, + isExported: ast.IsExported(fieldName.Name), + position: getPosition(s.file.pkg, fieldName.Pos()), + markers: markers[rawField], + file: s.file, + tags: tags, + typ: typ, + isEmbedded: false, + } + + fields = append(fields, field) + } + + } + + return fields +} + +func (s *Struct) loadFields() { + s.fieldsOnce.Do(func() { + s.loadTypeParams() + s.fields = append(s.fields, s.getFieldsFromFieldList()...) + }) +} + +func (s *Struct) loadAllFields() { + s.allFieldsOnce.Do(func() { + s.loadFields() + + for _, field := range s.fields { + + if !field.IsEmbedded() { + s.allFields = append(s.allFields, field) + continue + } + + var baseType = field.Type() + pointerType, isPointer := field.Type().(*Pointer) + + if isPointer { + baseType = pointerType.Elem() + } + + genericType, isGenericType := baseType.(*GenericType) + + if isGenericType { + baseType = genericType.RawType() + } + + structType, isStruct := baseType.(*Struct) + + if isStruct { + s.allFields = append(s.allFields, structType.FieldsInHierarchy().ToSlice()...) + } + } + }) +} + +func (s *Struct) loadMethods() { + s.methodsOnce.Do(func() { + s.loadTypeParams() + s.allMethods = append(s.allMethods, s.methods...) + }) +} + +func (s *Struct) loadAllMethods() { + s.allMethodsOnce.Do(func() { + s.loadMethods() + s.loadFields() + + for _, field := range s.fields { + + if !field.IsEmbedded() { + continue + } + + var baseType = field.Type() + pointerType, isPointer := field.Type().(*Pointer) + + if isPointer { + baseType = pointerType.Elem() + } + + genericType, isGenericType := baseType.(*GenericType) + + if isGenericType { + baseType = genericType.RawType() + } + + structType, isStructType := baseType.(*Struct) + + if isStructType { + s.allMethods = append(s.allMethods, structType.MethodsInHierarchy().ToSlice()...) + } + + interfaceType, isInterfaceType := baseType.(*Interface) + + if isInterfaceType { + s.allMethods = append(s.allMethods, interfaceType.Methods().ToSlice()...) + } + } + }) + +} + +func (s *Struct) loadTypeParams() { + s.typeParamsOnce.Do(func() { + if s.specType == nil || s.specType.TypeParams == nil { + return + } + + for _, field := range s.specType.TypeParams.List { + for _, fieldName := range field.Names { + typeParameter := &TypeParameter{ + name: fieldName.Name, + constraints: &TypeConstraints{ + []*TypeConstraint{}, + }, + } + s.typeParams.elements = append(s.typeParams.elements, typeParameter) + } + } + + for _, field := range s.specType.TypeParams.List { + constraints := make([]*TypeConstraint, 0) + typ := getTypeFromExpression(field.Type, s.file, s.visitor, nil, s.typeParams) + + if typeSets, isTypeSets := typ.(TypeSets); isTypeSets { + for _, item := range typeSets { + if constraint, isConstraint := item.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: item}) + } + } + } else { + if constraint, isConstraint := typ.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: typ}) + } + } + + for _, fieldName := range field.Names { + typeParam, exists := s.typeParams.FindByName(fieldName.Name) + + if exists { + typeParam.constraints.elements = append(typeParam.constraints.elements, constraints...) + } + } + } + + }) +} + type Structs struct { elements []*Struct } diff --git a/visitor/struct_test.go b/visitor/struct_test.go index cbf0bb4..9edf880 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -18,7 +18,7 @@ type structInfo struct { fileName string isExported bool position Position - markers markers.MarkerValues + markers markers.Values methods map[string]functionInfo allMethods map[string]functionInfo fields map[string]fieldInfo @@ -32,7 +32,7 @@ type structInfo struct { // structs var ( friedCookieStruct = structInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:struct-type-level": { StructTypeLevel{ Name: "FriedCookie", @@ -80,7 +80,7 @@ var ( } cookieStruct = structInfo{ - markers: markers.MarkerValues{ + markers: markers.Values{ "marker:struct-type-level": { StructTypeLevel{ Name: "cookie", diff --git a/visitor/type.go b/visitor/type.go index 28aba71..76d49a9 100644 --- a/visitor/type.go +++ b/visitor/type.go @@ -1,13 +1,10 @@ package visitor import ( - "fmt" - "github.com/procyon-projects/marker/packages" "go/ast" "go/token" "go/types" "strconv" - "strings" ) type Type interface { @@ -42,133 +39,29 @@ func (t *Types) FindByName(name string) (Type, bool) { return nil, false } -type Position struct { - Line int - Column int -} - -func getPosition(pkg *packages.Package, tokenPosition token.Pos) Position { - position := pkg.Fset.Position(tokenPosition) - return Position{ - Line: position.Line, - Column: position.Column, - } -} - -type ImportedType struct { - pkg *packages.Package - typ Type -} - -func (i *ImportedType) Package() *packages.Package { - return i.pkg -} - -func (i *ImportedType) Underlying() Type { - return i.typ -} - -func (i *ImportedType) String() string { - return "" -} - -func (i *ImportedType) Name() string { - return fmt.Sprintf("%s.%s", i.pkg.Name, i.typ.Name()) -} - -type Variadic struct { - elem Type -} - -func (v *Variadic) Name() string { - return v.elem.Name() -} - -func (v *Variadic) Elem() Type { - return v.elem -} - -func (v *Variadic) Underlying() Type { - return v -} +type TypeSets []Type -func (v *Variadic) String() string { +func (t TypeSets) Name() string { return "" } -type Pointer struct { - base Type +func (t TypeSets) Len() int { + return len(t) } -func (p *Pointer) Name() string { - return "" -} - -func (p *Pointer) Elem() Type { - return p.base -} - -func (p *Pointer) Underlying() Type { - return p -} - -func (p *Pointer) String() string { - var builder strings.Builder - builder.WriteString("*") - builder.WriteString(p.base.Name()) - return builder.String() -} - -type TypeParam struct { - name string - typ Type -} - -func (t *TypeParam) Name() string { - return t.name -} - -func (t *TypeParam) Type() Type { - return t.typ -} - -type TypeParams struct { - params []*TypeParam -} - -func (t *TypeParams) Len() int { - return len(t.params) -} - -func (t *TypeParams) At(index int) *TypeParam { - if index >= 0 && index < len(t.params) { - return t.params[index] +func (t TypeSets) At(index int) Type { + if index >= 0 && index < len(t) { + return t[index] } return nil } -type Generic struct { - typeParam *TypeParam -} - -func (g *Generic) Name() string { - return g.typeParam.name -} - -func (g *Generic) ParamName() string { - return g.typeParam.name -} - -func (g *Generic) TypeParam() *TypeParam { - return g.typeParam +func (t TypeSets) Underlying() Type { + return t } -func (g *Generic) Underlying() Type { - return g.typeParam.typ -} - -func (g *Generic) String() string { +func (t TypeSets) String() string { return "" } @@ -185,24 +78,18 @@ func getTypeFromScope(name string, visitor *packageVisitor) Type { if ok { switch typedName.Underlying().(type) { case *types.Struct: - structType := &Struct{ - name: name, - isProcessed: false, - } + structType := newStruct(nil, nil, nil, pkg, visitor, nil) + structType.isProcessed = false visitor.collector.unprocessedTypes[pkg.ID][name] = structType return structType case *types.Interface: - interfaceType := &Interface{ - name: name, - isProcessed: false, - } + interfaceType := newInterface(nil, nil, nil, pkg, visitor, nil) + interfaceType.isProcessed = false visitor.collector.unprocessedTypes[pkg.ID][name] = interfaceType return interfaceType default: - customType := &CustomType{ - name: name, - isProcessed: false, - } + customType := newCustomType(nil, nil, pkg, visitor, nil) + customType.isProcessed = false visitor.collector.unprocessedTypes[pkg.ID][name] = customType return customType } @@ -222,19 +109,28 @@ func collectTypeFromTypeSpec(typeSpec *ast.TypeSpec, visitor *packageVisitor) Ty switch t := typ.(type) { case *Interface: if !t.isProcessed { - file.interfaces.elements = append(file.interfaces.elements, t) + if _, exists := file.interfaces.FindByName(t.name); !exists { + file.interfaces.elements = append(file.interfaces.elements, t) + } + t.initialize(typeSpec, nil, file, pkg) } t.markers = visitor.packageMarkers[typeSpec] return t case *Struct: if !t.isProcessed { - file.structs.elements = append(file.structs.elements, t) + if _, exists := file.structs.FindByName(t.name); !exists { + file.structs.elements = append(file.structs.elements, t) + } + t.initialize(typeSpec, nil, file, pkg) } t.markers = visitor.packageMarkers[typeSpec] return t case *CustomType: if !t.isProcessed { - file.customTypes.elements = append(file.customTypes.elements, t) + if _, exists := file.customTypes.FindByName(t.name); !exists { + file.customTypes.elements = append(file.customTypes.elements, t) + } + t.initialize(typeSpec, file, pkg) } t.markers = visitor.packageMarkers[typeSpec] return t @@ -251,7 +147,7 @@ func collectTypeFromTypeSpec(typeSpec *ast.TypeSpec, visitor *packageVisitor) Ty } } -func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor) Type { +func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor, ownerType Type, typeParameters *TypeParameters) Type { pkg := visitor.pkg collector := visitor.collector @@ -265,12 +161,10 @@ func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor) T return typ } - if typed.Name == "error" { - errorType, _ := collector.findTypeByPkgIdAndName("builtin", "error") - return errorType - } else if typed.Name == "any" { - anyType, _ := collector.findTypeByPkgIdAndName("builtin", "any") - return anyType + typ, ok = collector.findTypeByPkgIdAndName("builtin", typed.Name) + + if ok { + return typ } typ, ok = collector.findTypeByPkgIdAndName(pkg.ID, typed.Name) @@ -280,9 +174,28 @@ func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor) T } if typed.Obj == nil { + if typeParameters != nil { + if typeParameter, exists := typeParameters.FindByName(typed.Name); exists { + return typeParameter + } + } return getTypeFromScope(typed.Name, visitor) } + if field, isField := typed.Obj.Decl.(*ast.Field); isField { + if typeParameters == nil { + // TODO return invalid type + return nil + } + + if typeParameter, exists := typeParameters.FindByName(field.Names[0].Name); exists { + return typeParameter + } + + // TODO return invalid type + return nil + } + return collectTypeFromTypeSpec(typed.Obj.Decl.(*ast.TypeSpec), visitor) case *ast.SelectorExpr: importName := typed.X.(*ast.Ident).Name @@ -290,13 +203,13 @@ func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor) T return collector.findTypeByImportAndTypeName(importName, typeName, file) case *ast.StarExpr: return &Pointer{ - base: getTypeFromExpression(typed.X, file, visitor), + base: getTypeFromExpression(typed.X, file, visitor, ownerType, typeParameters), } case *ast.ArrayType: if typed.Len == nil { return &Slice{ - elem: getTypeFromExpression(typed.Elt, file, visitor), + elem: getTypeFromExpression(typed.Elt, file, visitor, ownerType, typeParameters), } } else { basicLit, isBasicLit := typed.Len.(*ast.BasicLit) @@ -304,19 +217,19 @@ func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor) T if isBasicLit { length, _ := strconv.ParseInt(basicLit.Value, 10, 64) return &Array{ - elem: getTypeFromExpression(typed.Elt, file, visitor), + elem: getTypeFromExpression(typed.Elt, file, visitor, ownerType, typeParameters), len: length, } } return &Array{ - elem: getTypeFromExpression(typed.Elt, file, visitor), + elem: getTypeFromExpression(typed.Elt, file, visitor, ownerType, typeParameters), len: -1, } } case *ast.ChanType: chanType := &Chan{ - elem: getTypeFromExpression(typed.Value, file, visitor), + elem: getTypeFromExpression(typed.Value, file, visitor, ownerType, typeParameters), } if typed.Dir&ast.SEND == ast.SEND { @@ -330,19 +243,85 @@ func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor) T return chanType case *ast.Ellipsis: return &Variadic{ - elem: getTypeFromExpression(typed.Elt, file, visitor), + elem: getTypeFromExpression(typed.Elt, file, visitor, ownerType, typeParameters), } case *ast.FuncType: return &Function{} case *ast.MapType: return &Map{ - key: getTypeFromExpression(typed.Key, file, visitor), - elem: getTypeFromExpression(typed.Value, file, visitor), + key: getTypeFromExpression(typed.Key, file, visitor, ownerType, typeParameters), + elem: getTypeFromExpression(typed.Value, file, visitor, ownerType, typeParameters), } case *ast.InterfaceType: return newInterface(nil, typed, nil, nil, visitor, nil) case *ast.StructType: return newStruct(nil, typed, nil, nil, visitor, nil) + case *ast.IndexExpr: + genericType := &GenericType{ + rawType: getTypeFromExpression(typed.X, file, visitor, ownerType, typeParameters), + arguments: make([]Type, 0), + } + + genericType.arguments = append(genericType.arguments, getTypeFromExpression(typed.Index, file, visitor, ownerType, typeParameters)) + return genericType + case *ast.IndexListExpr: + genericType := &GenericType{ + rawType: getTypeFromExpression(typed.X, file, visitor, ownerType, typeParameters), + arguments: make([]Type, 0), + } + + for _, argument := range typed.Indices { + genericType.arguments = append(genericType.arguments, getTypeFromExpression(argument, file, visitor, ownerType, typeParameters)) + } + + return genericType + case *ast.BinaryExpr: + constraints := make(TypeSets, 0) + + firstType := getTypeFromExpression(typed.X, file, visitor, ownerType, typeParameters) + + if typeSets, isTypeSet := firstType.(TypeSets); isTypeSet { + for _, typ := range typeSets { + if constraint, isConstraint := typ.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: typ}) + } + } + } else { + if constraint, isConstraint := firstType.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: firstType}) + } + } + + secondType := getTypeFromExpression(typed.Y, file, visitor, ownerType, typeParameters) + + if typeSets, isTypeSet := secondType.(TypeSets); isTypeSet { + for _, typ := range typeSets { + if constraint, isConstraint := typ.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: typ}) + } + } + } else { + if constraint, isConstraint := secondType.(*TypeConstraint); isConstraint { + constraints = append(constraints, constraint) + } else { + constraints = append(constraints, &TypeConstraint{typ: secondType}) + } + } + + return constraints + case *ast.UnaryExpr: + if typed.Op == token.TILDE { + return &TypeConstraint{ + tildeOperator: true, + typ: getTypeFromExpression(typed.X, nil, visitor, ownerType, typeParameters), + } + } } return nil diff --git a/visitor/type_constraint.go b/visitor/type_constraint.go new file mode 100644 index 0000000..35610cb --- /dev/null +++ b/visitor/type_constraint.go @@ -0,0 +1,48 @@ +package visitor + +import "fmt" + +type TypeConstraint struct { + typ Type + tildeOperator bool +} + +func (c *TypeConstraint) Name() string { + return c.typ.String() +} + +func (c *TypeConstraint) Type() Type { + return c.typ +} + +func (c *TypeConstraint) Underlying() Type { + return c +} + +func (c *TypeConstraint) Satisfy(t Type) bool { + // TODO implement this method + return false +} + +func (c *TypeConstraint) String() string { + if c.tildeOperator { + return fmt.Sprintf("~%s", c.typ.String()) + } + return c.typ.String() +} + +type TypeConstraints struct { + elements []*TypeConstraint +} + +func (c *TypeConstraints) Len() int { + return len(c.elements) +} + +func (c *TypeConstraints) At(index int) *TypeConstraint { + if index >= 0 && index < len(c.elements) { + return c.elements[index] + } + + return nil +} diff --git a/visitor/type_parameter.go b/visitor/type_parameter.go new file mode 100644 index 0000000..012deb7 --- /dev/null +++ b/visitor/type_parameter.go @@ -0,0 +1,64 @@ +package visitor + +import "strings" + +type TypeParameter struct { + name string + constraints *TypeConstraints +} + +func (t *TypeParameter) Name() string { + return t.name +} + +func (t *TypeParameter) TypeConstraints() *TypeConstraints { + return t.constraints +} + +func (t *TypeParameter) Underlying() Type { + return t +} + +func (t *TypeParameter) String() string { + var builder strings.Builder + builder.WriteString(t.name + " ") + + if t.TypeConstraints().Len() != 0 { + for i := 0; i < t.TypeConstraints().Len(); i++ { + constraint := t.TypeConstraints().At(i) + builder.WriteString(constraint.String()) + + if i != t.TypeConstraints().Len()-1 { + builder.WriteString("|") + } + } + } + + return builder.String() +} + +type TypeParameters struct { + elements []*TypeParameter +} + +func (t *TypeParameters) Len() int { + return len(t.elements) +} + +func (t *TypeParameters) At(index int) *TypeParameter { + if index >= 0 && index < len(t.elements) { + return t.elements[index] + } + + return nil +} + +func (t *TypeParameters) FindByName(name string) (*TypeParameter, bool) { + for _, typeParameter := range t.elements { + if typeParameter.name == name { + return typeParameter, true + } + } + + return nil, false +} diff --git a/visitor/variadic.go b/visitor/variadic.go new file mode 100644 index 0000000..9aee7ef --- /dev/null +++ b/visitor/variadic.go @@ -0,0 +1,21 @@ +package visitor + +type Variadic struct { + elem Type +} + +func (v *Variadic) Name() string { + return v.elem.Name() +} + +func (v *Variadic) Elem() Type { + return v.elem +} + +func (v *Variadic) Underlying() Type { + return v +} + +func (v *Variadic) String() string { + return "" +} diff --git a/visitor/visitor.go b/visitor/visitor.go index 0b8ae0b..5883a1c 100644 --- a/visitor/visitor.go +++ b/visitor/visitor.go @@ -14,8 +14,8 @@ type packageVisitor struct { collector *packageCollector pkg *packages.Package - packageMarkers map[ast.Node]markers.MarkerValues - allPackageMarkers map[string]map[ast.Node]markers.MarkerValues + packageMarkers map[ast.Node]markers.Values + allPackageMarkers map[string]map[ast.Node]markers.Values file *File @@ -54,7 +54,7 @@ func (visitor *packageVisitor) Visit(node ast.Node) ast.Visitor { return visitor case *ast.FuncDecl: visitor.funcDecl = typedNode - newFunction(typedNode, nil, visitor.file, visitor.pkg, visitor, visitor.packageMarkers[typedNode]) + newFunction(typedNode, nil, nil, nil, visitor.file, visitor.pkg, visitor, visitor.packageMarkers[typedNode]) return nil case *ast.TypeSpec: collectTypeFromTypeSpec(typedNode, visitor) @@ -64,7 +64,7 @@ func (visitor *packageVisitor) Visit(node ast.Node) ast.Visitor { } } -func visitPackage(pkg *packages.Package, collector *packageCollector, allPackageMarkers map[string]map[ast.Node]markers.MarkerValues) { +func visitPackage(pkg *packages.Package, collector *packageCollector, allPackageMarkers map[string]map[ast.Node]markers.Values) { pkgVisitor := &packageVisitor{ collector: collector, pkg: pkg, @@ -88,7 +88,7 @@ func EachFile(collector *markers.Collector, pkgs []*packages.Package, callback F } var errs []error - packageMarkers := make(map[string]map[ast.Node]markers.MarkerValues) + packageMarkers := make(map[string]map[ast.Node]markers.Values) for _, pkg := range pkgs { markerValues, err := collector.Collect(pkg) diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index a15cd86..a86e288 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -38,8 +38,9 @@ type FunctionLevel struct { } type variableInfo struct { - name string - typeName string + name string + typeName string + isPointer bool } func TestVisitor_VisitPackage(t *testing.T) { @@ -73,12 +74,14 @@ func TestVisitor_VisitPackage(t *testing.T) { name: "", path: "fmt", sideEffect: false, + file: "dessert.go", position: Position{Line: 7, Column: 2}, }, { name: "_", path: "strings", sideEffect: true, + file: "dessert.go", position: Position{Line: 8, Column: 2}, }, }, @@ -122,6 +125,7 @@ func TestVisitor_VisitPackage(t *testing.T) { name: "", path: "net/http", sideEffect: false, + file: "string.go", position: Position{Line: 3, Column: 8}, }, }, @@ -187,7 +191,7 @@ func TestVisitor_VisitPackage(t *testing.T) { } } -func assertMarkers(t *testing.T, expectedMarkers markers.MarkerValues, actualMarkers markers.MarkerValues, msg string) { +func assertMarkers(t *testing.T, expectedMarkers markers.Values, actualMarkers markers.Values, msg string) { if actualMarkers.Count() != expectedMarkers.Count() { t.Errorf("the number of the markers of the %s should be %d, but got %d", msg, expectedMarkers.Count(), actualMarkers.Count()) return @@ -199,7 +203,7 @@ func assertMarkers(t *testing.T, expectedMarkers markers.MarkerValues, actualMar continue } - actualMarkerValues := actualMarkers.AllMarkers(markerName) + actualMarkerValues, _ := actualMarkers.FindByName(markerName) for index, expectedMarkerValue := range markerValues { actualMarker := actualMarkerValues[index] From 7479e121c18b12752749c6c981cd4b7bcec2e156 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Wed, 26 Oct 2022 22:03:23 +0300 Subject: [PATCH 02/16] Complete CustomType and TypeParameter struct implementations --- go.mod | 8 +- go.sum | 16 +- packages/package_test.go | 23 +-- processor/common.go | 4 +- processor/generate.go | 2 +- processor/validate.go | 2 +- test/any/generics.go | 32 +++- test/any/method.go | 5 + test/menu/coffee.go | 4 + visitor/custom_type_test.go | 78 +++++----- visitor/file_test.go | 4 - visitor/function.go | 44 +++--- visitor/function_test.go | 286 +++++++++++++++++++++++++++++++++--- visitor/interface.go | 4 +- visitor/interface_test.go | 64 ++++++-- visitor/struct.go | 14 +- visitor/struct_test.go | 65 +++++++- visitor/type.go | 2 +- visitor/visitor_test.go | 44 +++++- 19 files changed, 565 insertions(+), 136 deletions(-) create mode 100644 test/any/method.go diff --git a/go.mod b/go.mod index 3080d48..181d456 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,8 @@ go 1.18 require ( github.com/spf13/cobra v1.2.1 github.com/stretchr/testify v1.8.0 - golang.org/x/tools v0.1.10 + golang.org/x/exp v0.0.0-20221026153819-32f3d567a233 + golang.org/x/tools v0.2.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -15,7 +16,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.4.0 // indirect - golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect - golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect - golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + golang.org/x/mod v0.6.0 // indirect + golang.org/x/sys v0.1.0 // indirect ) diff --git a/go.sum b/go.sum index a008177..0965579 100644 --- a/go.sum +++ b/go.sum @@ -220,7 +220,6 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= @@ -261,6 +260,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= +golang.org/x/exp v0.0.0-20221026153819-32f3d567a233 h1:9bNbSKT4RPLEzne0Xh1v3NaNecsa1DKjkOuTbY6V9rI= +golang.org/x/exp v0.0.0-20221026153819-32f3d567a233/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -286,8 +287,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 h1:kQgndtyPBW/JIYERgdxfwMYh3AVStj88WQTlNDi2a+o= -golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= +golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= +golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -388,8 +389,8 @@ golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -452,12 +453,11 @@ golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= -golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= -golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= +golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE= +golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= diff --git a/packages/package_test.go b/packages/package_test.go index a42e4f8..e48581e 100644 --- a/packages/package_test.go +++ b/packages/package_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/mock" "os" "os/exec" + "path/filepath" "strings" "testing" "time" @@ -65,7 +66,7 @@ func TestGetPackageInfo(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, pkg) assert.Equal(t, packageInfo, pkg) - assert.Equal(t, "anyGoPath/pkg/mod/github.com/procyon-projects/marker@v2.0.5", packageInfo.ModulePath()) + assert.Equal(t, filepath.FromSlash("anyGoPath/pkg/mod/github.com/procyon-projects/marker@v2.0.5"), packageInfo.ModulePath()) } func TestGetMarkerPackageShouldReturnErrorIfAnyErrorOccurs(t *testing.T) { @@ -208,7 +209,7 @@ func TestMarkerPackagePath(t *testing.T) { mockExecutor.On("Execute", goPathCmd).Return([]byte("anyGoPath\n "), nil) - assert.Equal(t, "anyGoPath/marker/pkg/github.com/procyon-projects/marker/anyVersion", + assert.Equal(t, filepath.FromSlash("anyGoPath/marker/pkg/github.com/procyon-projects/marker/anyVersion"), MarkerPackagePath("github.com/procyon-projects/marker", "anyVersion")) } @@ -224,7 +225,7 @@ func TestMarkerPackagePathFromPackageInfo(t *testing.T) { mockExecutor.On("Execute", goPathCmd).Return([]byte("anyGoPath\n "), nil) - assert.Equal(t, "anyGoPath/marker/pkg/github.com/procyon-projects/marker/anyVersion", + assert.Equal(t, filepath.FromSlash("anyGoPath/marker/pkg/github.com/procyon-projects/marker/anyVersion"), MarkerPackagePathFromPackageInfo(&PackageInfo{ Path: "github.com/procyon-projects/marker", Version: "anyVersion", @@ -243,7 +244,7 @@ func TestMarkerProcessorYamlPath(t *testing.T) { mockExecutor.On("Execute", goPathCmd).Return([]byte("anyGoPath\n "), nil) - assert.Equal(t, "anyGoPath/pkg/mod/github.com/procyon-projects/marker@anyVersion/marker.processors.yaml", + assert.Equal(t, filepath.FromSlash("anyGoPath/pkg/mod/github.com/procyon-projects/marker@anyVersion/marker.processors.yaml"), MarkerProcessorYamlPath(&PackageInfo{ Path: "github.com/procyon-projects/marker", Version: "anyVersion", @@ -266,7 +267,7 @@ func TestMarkerPackageYamlPath(t *testing.T) { assert.True(t, strings.HasSuffix(MarkerPackageYamlPath(&PackageInfo{ Path: "github.com/procyon-projects/marker", Version: "anyVersion", - }), "marker/pkg/github.com/procyon-projects/marker/anyVersion/marker.procesors.yaml")) + }), filepath.FromSlash("marker/pkg/github.com/procyon-projects/marker/anyVersion/marker.procesors.yaml"))) } func TestGoModDir(t *testing.T) { @@ -285,19 +286,19 @@ func TestInstallPackageShouldInstallPackage(t *testing.T) { } goInstallCmd := &exec.Cmd{ - Path: "/usr/local/go/bin/go", + Path: execLookupPath, Args: []string{"go", "install", "github.com/procyon-projects/chrono/...@latest"}, Env: []string{}, Stdout: os.Stdout, Stderr: os.Stderr, } - environmentVariables := []string{fmt.Sprintf("GOBIN=%s", "anyGoPath/marker/pkg/github.com/procyon-projects/chrono/latest")} + environmentVariables := []string{fmt.Sprintf("GOBIN=%s", filepath.FromSlash("anyGoPath/marker/pkg/github.com/procyon-projects/chrono/latest"))} goInstallCmd.Env = append(goInstallCmd.Env, os.Environ()...) goInstallCmd.Env = append(goInstallCmd.Env, environmentVariables...) - mockExecutor.On("Execute", goPathCmd).Return([]byte("anyGoPath\n "), nil) mockExecutor.On("Execute", goInstallCmd).Return(nil, nil) + mockExecutor.On("Execute", goPathCmd).Return([]byte("anyGoPath\n "), nil) err := InstallPackage(&PackageInfo{ Path: "github.com/procyon-projects/chrono", @@ -317,19 +318,19 @@ func TestInstallPackageReturnsErrorIfInstallationIsFailed(t *testing.T) { } goInstallCmd := &exec.Cmd{ - Path: "/usr/local/go/bin/go", + Path: execLookupPath, Args: []string{"go", "install", "github.com/procyon-projects/chrono/...@latest"}, Env: []string{}, Stdout: os.Stdout, Stderr: os.Stderr, } - environmentVariables := []string{fmt.Sprintf("GOBIN=%s", "anyGoPath/marker/pkg/github.com/procyon-projects/chrono/latest")} + environmentVariables := []string{fmt.Sprintf("GOBIN=%s", filepath.FromSlash("anyGoPath/marker/pkg/github.com/procyon-projects/chrono/latest"))} goInstallCmd.Env = append(goInstallCmd.Env, os.Environ()...) goInstallCmd.Env = append(goInstallCmd.Env, environmentVariables...) - mockExecutor.On("Execute", goPathCmd).Return([]byte("anyGoPath\n "), nil) mockExecutor.On("Execute", goInstallCmd).Return(nil, errors.New("anyInstallationError")) + mockExecutor.On("Execute", goPathCmd).Return([]byte("anyGoPath\n "), nil) err := InstallPackage(&PackageInfo{ Path: "github.com/procyon-projects/chrono", diff --git a/processor/common.go b/processor/common.go index 2939c46..0984ef9 100644 --- a/processor/common.go +++ b/processor/common.go @@ -42,9 +42,9 @@ func getConfig(configFilePath string) (*Config, error) { return config, nil } -// GetPackageDirectories finds the go module directory and returns +// PackageDirectories finds the go module directory and returns // the package directories. -func GetPackageDirectories() ([]string, error) { +func PackageDirectories() ([]string, error) { var err error var modDir string modDir, err = packages.GoModDir() diff --git a/processor/generate.go b/processor/generate.go index 16158be..f2c183a 100644 --- a/processor/generate.go +++ b/processor/generate.go @@ -50,7 +50,7 @@ var generateCmd = &cobra.Command{ } var dirs []string - dirs, err = GetPackageDirectories() + dirs, err = PackageDirectories() if err != nil { return errors.New("go.module not found") diff --git a/processor/validate.go b/processor/validate.go index ae38535..2788660 100644 --- a/processor/validate.go +++ b/processor/validate.go @@ -49,7 +49,7 @@ var validateCmd = &cobra.Command{ var dirs []string - dirs, err = GetPackageDirectories() + dirs, err = PackageDirectories() if err != nil { return errors.New("go.module not found") diff --git a/test/any/generics.go b/test/any/generics.go index a7dfb9e..7924e7d 100644 --- a/test/any/generics.go +++ b/test/any/generics.go @@ -1,5 +1,35 @@ package any -func GenericFunction[T string]() { +import ( + "context" + "golang.org/x/exp/constraints" +) +func GenericFunction[K []map[T]X, T int | bool, X ~string](x []K) T { + var value T + return value } + +type Repository[T, ID any] interface { + Save(entity T) T +} + +type Controller[C context.Context, T any] struct { + AnyField1 string + AnyField2 int +} + +func (c Controller[K, C]) Index(ctx K, h C) { + +} + +type TestController struct { + Controller[context.Context, int16] +} + +type Number interface { + constraints.Ordered + ToString() +} + +type HttpHandler[C context.Context] func(ctx C) diff --git a/test/any/method.go b/test/any/method.go new file mode 100644 index 0000000..12fade1 --- /dev/null +++ b/test/any/method.go @@ -0,0 +1,5 @@ +package any + +func (HttpHandler[C]) Print(ctx C) { + +} diff --git a/test/menu/coffee.go b/test/menu/coffee.go index cea8fa2..8a3c1dc 100644 --- a/test/menu/coffee.go +++ b/test/menu/coffee.go @@ -11,3 +11,7 @@ const ( Latte TurkishCoffee ) + +func (c *cookie) PrintCookie(v interface{}) []string { + return nil +} diff --git a/visitor/custom_type_test.go b/visitor/custom_type_test.go index f1f3d46..befc960 100644 --- a/visitor/custom_type_test.go +++ b/visitor/custom_type_test.go @@ -7,48 +7,63 @@ import ( ) type customTypeInfo struct { - name string - aliasTypeName string - isExported bool + name string + underlyingTypeName string + isExported bool + methods map[string]functionInfo } var ( errorCustomTypes = map[string]customTypeInfo{ "errorList": { - name: "errorList", - aliasTypeName: "[]error", - isExported: false, + name: "errorList", + underlyingTypeName: "[]error", + isExported: false, + methods: map[string]functionInfo{ + "Print": printErrorMethod, + "ToErrors": toErrorsMethod, + }, }, } permissionCustomTypes = map[string]customTypeInfo{ "Permission": { - name: "Permission", - aliasTypeName: "int", - isExported: true, + name: "Permission", + underlyingTypeName: "int", + isExported: true, }, "RequestMethod": { - name: "RequestMethod", - aliasTypeName: "string", - isExported: true, + name: "RequestMethod", + underlyingTypeName: "string", + isExported: true, }, "Chan": { - name: "Chan", - aliasTypeName: "int", - isExported: true, + name: "Chan", + underlyingTypeName: "int", + isExported: true, }, } coffeeCustomTypes = map[string]customTypeInfo{ "Coffee": { - name: "Coffee", - aliasTypeName: "int", - isExported: true, + name: "Coffee", + underlyingTypeName: "int", + isExported: true, }, } freshCustomTypes = map[string]customTypeInfo{ "Lemonade": { - name: "Lemonade", - aliasTypeName: "uint", - isExported: true, + name: "Lemonade", + underlyingTypeName: "uint", + isExported: true, + }, + } + genericsCustomTypes = map[string]customTypeInfo{ + "HttpHandler": { + name: "HttpHandler", + underlyingTypeName: "func (ctx C)", + isExported: true, + methods: map[string]functionInfo{ + "Print": printHttpHandlerMethod, + }, }, } ) @@ -79,20 +94,14 @@ func assertCustomTypes(t *testing.T, file *File, customTypes map[string]customTy continue } - assert.Equal(t, actualCustomType, actualCustomType.Underlying()) assert.Equal(t, fileCustomType, actualCustomType, "CustomTypes.At should return %w, but got %w", fileCustomType, actualCustomType) if expectedCustomType.name != actualCustomType.Name() { t.Errorf("custom type name in file %s shoud be %s, but got %s", file.name, expectedCustomTypeName, actualCustomType.Name()) } - if expectedCustomType.aliasTypeName != actualCustomType.AliasType().Name() { - t.Errorf("alias type of custom type %s in file %s shoud be %s, but got %s", file.name, expectedCustomType.name, expectedCustomType.aliasTypeName, actualCustomType.AliasType().Name()) - } - - customTypeStrValue := fmt.Sprintf("type %s %s", expectedCustomType.name, expectedCustomType.aliasTypeName) - if customTypeStrValue != actualCustomType.String() { - t.Errorf("String() method of custom type %s shoud return %s, but got %s", expectedCustomTypeName, customTypeStrValue, actualCustomType.String()) + if expectedCustomType.underlyingTypeName != actualCustomType.Underlying().String() { + t.Errorf("underlying type of custom type %s in file %s shoud be %s, but got %s", file.name, expectedCustomType.name, expectedCustomType.underlyingTypeName, actualCustomType.Underlying().String()) } if actualCustomType.IsExported() && !expectedCustomType.isExported { @@ -101,16 +110,7 @@ func assertCustomTypes(t *testing.T, file *File, customTypes map[string]customTy t.Errorf("custom type with name %s is not exported, but should be exported field", expectedCustomTypeName) } - /*if expectedConstant.value != actualConstant.Value() { - t.Errorf("value of constant %s in file %s shoud be %s, but got %s", actualConstant.Name(), file.name, expectedConstant.value, actualConstant.Value()) - } - - if expectedConstant.typeName != actualConstant.Type().Name() { - t.Errorf("type name of constant %s in file %s shoud be %s, but got %s", actualConstant.Name(), file.name, expectedConstant.typeName, actualConstant.Type().Name()) - } - - assert.Equal(t, expectedConstant.position, actualConstant.Position(), "the position of constant %s in file %s should be %w, but got %w", expectedConstant.name, actualConstant.File().Name(), expectedConstant.position, actualConstant.Position()) - */ + assertFunctions(t, fmt.Sprintf("custom type %s", actualCustomType.Name()), actualCustomType.Methods(), expectedCustomType.methods) index++ } diff --git a/visitor/file_test.go b/visitor/file_test.go index 8794610..28a1651 100644 --- a/visitor/file_test.go +++ b/visitor/file_test.go @@ -69,10 +69,6 @@ func assertImports(t *testing.T, file *File, expectedImports []importInfo) bool t.Errorf("import path in file %s shoud be %s, but got %s", file.name, expectedImport.path, actualImport.Path()) } - if expectedImport.file != actualImport.File().Name() { - t.Errorf("file name for import %s shoud be %s, but got %s", actualImport.Path(), expectedImport.file, actualImport.File().Name()) - } - if actualImport.SideEffect() && !expectedImport.sideEffect { t.Errorf("import with path %s in file %s is not an import side effect, but should be an import side effect", expectedImport.path, file.name) } else if !actualImport.SideEffect() && expectedImport.sideEffect { diff --git a/visitor/function.go b/visitor/function.go index 286428f..7482231 100644 --- a/visitor/function.go +++ b/visitor/function.go @@ -154,18 +154,26 @@ type Function struct { func newFunction(funcDecl *ast.FuncDecl, funcType *ast.FuncType, funcField *ast.Field, ownerType Type, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.Values) *Function { function := &Function{ - file: file, - typeParams: &TypeParameters{}, - receiverTypeParams: &TypeParameters{}, - params: &Parameters{}, - results: &Results{}, - markers: markers, - funcDecl: funcDecl, - funcField: funcField, - funcType: funcType, - pkg: pkg, - visitor: visitor, - ownerType: ownerType, + file: file, + typeParams: &TypeParameters{ + []*TypeParameter{}, + }, + receiverTypeParams: &TypeParameters{ + []*TypeParameter{}, + }, + params: &Parameters{ + []*Parameter{}, + }, + results: &Results{ + []*Result{}, + }, + markers: markers, + funcDecl: funcDecl, + funcField: funcField, + funcType: funcType, + pkg: pkg, + visitor: visitor, + ownerType: ownerType, } if funcDecl != nil { @@ -255,6 +263,7 @@ func (f *Function) receiverType(receiverExpr ast.Expr) Type { customType := candidateType.(*CustomType) f.ownerType = customType customType.methods = append(customType.methods, f) + f.file.functions.elements = append(f.file.functions.elements, f) } return candidateType @@ -371,13 +380,12 @@ func (f *Function) String() string { if f.receiver != nil { builder.WriteString("(") - if f.receiver.Name() != "" { builder.WriteString(f.receiver.Name()) builder.WriteString(" ") } - builder.WriteString(f.receiver.Type().String()) + builder.WriteString(f.receiver.Type().Name()) if f.TypeParameters().Len() != 0 { builder.WriteString("[") @@ -405,7 +413,7 @@ func (f *Function) String() string { builder.WriteString("[") for i := 0; i < f.TypeParameters().Len(); i++ { typeParam := f.TypeParameters().At(i) - builder.WriteString(typeParam.Name()) + builder.WriteString(typeParam.String()) if i != f.TypeParameters().Len()-1 { builder.WriteString(",") @@ -427,7 +435,7 @@ func (f *Function) String() string { } } - if f.Results().Len() != 0 { + if f.Results().Len() == 0 { builder.WriteString(")") } else { builder.WriteString(") ") @@ -506,7 +514,7 @@ func (f *Function) loadTypeParams() { } - if f.funcType.TypeParams == nil { + if f.funcType == nil || f.funcType.TypeParams == nil { return } @@ -557,7 +565,7 @@ func (f *Function) loadParams() { f.paramsOnce.Do(func() { f.loadTypeParams() - if f.funcType.Params != nil { + if f.funcType != nil && f.funcType.Params != nil { f.params.elements = append(f.params.elements, f.getParameters(f.funcType.Params.List)...) } diff --git a/visitor/function_test.go b/visitor/function_test.go index e30b1fb..7c2b523 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -23,6 +23,7 @@ type functionInfo struct { receiver *receiverInfo params []variableInfo results []variableInfo + typeParams []variableInfo } func (f functionInfo) String() string { @@ -31,16 +32,59 @@ func (f functionInfo) String() string { if f.receiver != nil { builder.WriteString("(") - builder.WriteString(f.receiver.name) - builder.WriteString(" ") + if f.receiver.name != "" { + builder.WriteString(f.receiver.name) + builder.WriteString(" ") + } if f.receiver.isPointer { builder.WriteString("*") } builder.WriteString(f.receiver.typeName) + if len(f.typeParams) != 0 { + builder.WriteString("[") + for i := 0; i < len(f.typeParams); i++ { + param := f.typeParams[i] + if param.name != "" { + builder.WriteString(param.name) + } + + if param.typeName != "" { + builder.WriteString(" " + param.typeName) + } + + if i != len(f.typeParams)-1 { + builder.WriteString(",") + } + } + builder.WriteString("]") + } + builder.WriteString(") ") } - builder.WriteString(f.name) + if f.name != "" { + builder.WriteString(f.name) + } else { + builder.WriteString(" ") + } + + if f.receiver == nil && len(f.typeParams) != 0 { + builder.WriteString("[") + for i := 0; i < len(f.typeParams); i++ { + param := f.typeParams[i] + if param.name != "" { + builder.WriteString(param.name + " ") + } + + builder.WriteString(param.typeName) + + if i != len(f.typeParams)-1 { + builder.WriteString(",") + } + } + builder.WriteString("]") + } + builder.WriteString("(") if len(f.params) != 0 { @@ -61,7 +105,11 @@ func (f functionInfo) String() string { } } - builder.WriteString(") ") + if len(f.results) == 0 { + builder.WriteString(")") + } else { + builder.WriteString(") ") + } if len(f.results) > 1 { builder.WriteString("(") @@ -73,7 +121,6 @@ func (f functionInfo) String() string { if result.name != "" { builder.WriteString(result.name + " ") } - if result.isPointer { builder.WriteString("*") } @@ -94,6 +141,211 @@ func (f functionInfo) String() string { // functions var ( + saveFunction = functionInfo{ + markers: markers.Values{}, + name: "Save", + fileName: "generics.go", + position: Position{ + Line: 14, + Column: 6, + }, + isVariadic: false, + params: []variableInfo{ + { + name: "entity", + typeName: "T", + }, + }, + results: []variableInfo{ + { + typeName: "T", + }, + }, + } + toStringFunction = functionInfo{ + markers: markers.Values{}, + name: "ToString", + fileName: "generics.go", + position: Position{ + Line: 32, + Column: 10, + }, + isVariadic: false, + params: []variableInfo{}, + results: []variableInfo{}, + } + indexMethod = functionInfo{ + markers: markers.Values{}, + name: "Index", + fileName: "generics.go", + position: Position{ + Line: 22, + Column: 1, + }, + isVariadic: false, + receiver: &receiverInfo{ + name: "c", + isPointer: false, + typeName: "Controller", + }, + params: []variableInfo{ + { + name: "ctx", + typeName: "K", + }, + { + name: "h", + typeName: "C", + }, + }, + typeParams: []variableInfo{ + { + name: "K", + typeName: "", + }, + { + name: "C", + typeName: "", + }, + }, + } + printCookieMethod = functionInfo{ + markers: markers.Values{}, + name: "PrintCookie", + fileName: "coffee.go", + position: Position{ + Line: 15, + Column: 1, + }, + isVariadic: false, + receiver: &receiverInfo{ + name: "c", + isPointer: true, + typeName: "cookie", + }, + params: []variableInfo{ + { + name: "v", + typeName: "interface{}", + }, + }, + results: []variableInfo{ + { + name: "", + typeName: "[]string", + }, + }, + } + printHttpHandlerMethod = functionInfo{ + markers: markers.Values{}, + name: "Print", + fileName: "method.go", + position: Position{ + Line: 3, + Column: 1, + }, + isVariadic: false, + receiver: &receiverInfo{ + isPointer: false, + typeName: "HttpHandler", + }, + params: []variableInfo{ + { + name: "ctx", + typeName: "C", + }, + }, + results: []variableInfo{}, + typeParams: []variableInfo{ + { + name: "C", + typeName: "", + }, + }, + } + printErrorMethod = functionInfo{ + markers: markers.Values{}, + name: "Print", + fileName: "error.go", + position: Position{ + Line: 5, + Column: 1, + }, + isVariadic: false, + receiver: &receiverInfo{ + isPointer: false, + name: "e", + typeName: "errorList", + }, + params: []variableInfo{}, + results: []variableInfo{}, + typeParams: []variableInfo{}, + } + toErrorsMethod = functionInfo{ + markers: markers.Values{ + "deprecated": { + markers.Deprecated{ + Value: "any deprecation message", + }, + }, + }, + name: "ToErrors", + fileName: "error.go", + position: Position{ + Line: 12, + Column: 1, + }, + isVariadic: false, + receiver: &receiverInfo{ + isPointer: false, + name: "e", + typeName: "errorList", + }, + params: []variableInfo{}, + results: []variableInfo{ + { + name: "", + typeName: "[]error", + }, + }, + typeParams: []variableInfo{}, + } + genericFunction = functionInfo{ + markers: markers.Values{}, + name: "GenericFunction", + fileName: "generics.go", + position: Position{ + Line: 8, + Column: 1, + }, + isVariadic: false, + params: []variableInfo{ + { + name: "x", + typeName: "[]K", + }, + }, + results: []variableInfo{ + { + name: "", + typeName: "T", + }, + }, + typeParams: []variableInfo{ + { + name: "K", + typeName: "[]map[T]X", + }, + { + name: "T", + typeName: "int|bool", + }, + { + name: "X", + typeName: "~string", + }, + }, + } breadFunction = functionInfo{ markers: markers.Values{ "marker:interface-method-level": { @@ -155,7 +407,7 @@ var ( }, { name: "", - typeName: "fmt.Stringer", + typeName: "Stringer", }, }, } @@ -579,19 +831,6 @@ var ( }, }, } - - genericFunction = functionInfo{ - markers: markers.Values{}, - name: "GenericFunction", - fileName: "generics.go", - position: Position{ - Line: 3, - Column: 1, - }, - isVariadic: false, - params: []variableInfo{}, - results: []variableInfo{}, - } ) func assertFunctions(t *testing.T, descriptor string, actualMethods *Functions, expectedMethods map[string]functionInfo) bool { @@ -627,7 +866,8 @@ func assertFunctions(t *testing.T, descriptor string, actualMethods *Functions, t.Errorf("the function %s should not be a variadic function for %s", expectedMethodName, descriptor) } - typeParam := actualMethod.TypeParams() + // TODO Type Params + typeParam := actualMethod.TypeParameters() if typeParam != nil { typeParam.Len() } @@ -637,7 +877,7 @@ func assertFunctions(t *testing.T, descriptor string, actualMethods *Functions, assert.Equal(t, expectedMethod.position, actualMethod.Position(), "the position of the function %s for %s should be %w, but got %w", expectedMethodName, descriptor, expectedMethod.position, actualMethod.Position()) - assertFunctionParameters(t, expectedMethod.params, actualMethod.Params(), fmt.Sprintf("function %s (%s)", expectedMethodName, descriptor)) + assertFunctionParameters(t, expectedMethod.params, actualMethod.Parameters(), fmt.Sprintf("function %s (%s)", expectedMethodName, descriptor)) assertFunctionResult(t, expectedMethod.results, actualMethod.Results(), fmt.Sprintf("function %s (%s)", expectedMethodName, descriptor)) @@ -647,7 +887,7 @@ func assertFunctions(t *testing.T, descriptor string, actualMethods *Functions, return true } -func assertFunctionParameters(t *testing.T, expectedParams []variableInfo, actualParams Variables, msg string) { +func assertFunctionParameters(t *testing.T, expectedParams []variableInfo, actualParams *Parameters, msg string) { if actualParams.Len() != len(expectedParams) { t.Errorf("the number of the %s parameters should be %d, but got %d", msg, len(expectedParams), actualParams.Len()) return @@ -667,7 +907,7 @@ func assertFunctionParameters(t *testing.T, expectedParams []variableInfo, actua } } -func assertFunctionResult(t *testing.T, expectedResults []variableInfo, actualResults Variables, msg string) { +func assertFunctionResult(t *testing.T, expectedResults []variableInfo, actualResults *Results, msg string) { if actualResults.Len() != len(expectedResults) { t.Errorf("the number of the %s results should be %d, but got %d", msg, len(expectedResults), actualResults.Len()) return diff --git a/visitor/interface.go b/visitor/interface.go index da987ff..eea6828 100644 --- a/visitor/interface.go +++ b/visitor/interface.go @@ -141,7 +141,7 @@ func (i *Interface) ExplicitMethods() *Functions { func (i *Interface) NumEmbeddedInterfaces() int { i.loadEmbeddedInterfaces() - return len(i.methods) + return len(i.embeddedInterfaces) } func (i *Interface) EmbeddedInterfaces() *Interfaces { @@ -347,7 +347,7 @@ func (i *Interface) loadTypeParams() { func (i *Interface) loadAllMethods() { i.allMethodsOnce.Do(func() { i.loadMethods() - i.loadEmbeddedTypes() + i.loadEmbeddedInterfaces() for _, embeddedInterface := range i.embeddedInterfaces { embeddedInterface.loadAllMethods() diff --git a/visitor/interface_test.go b/visitor/interface_test.go index f771dc3..68d3e32 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -8,18 +8,51 @@ import ( ) type interfaceInfo struct { - markers markers.Values - name string - fileName string - position Position - explicitMethods map[string]functionInfo - methods map[string]functionInfo - embeddedTypes []string - isExported bool + markers markers.Values + name string + fileName string + position Position + explicitMethods map[string]functionInfo + methods map[string]functionInfo + embeddedTypes []string + embeddedInterfaces []string + isExported bool } // interfaces var ( + repositoryInterface = interfaceInfo{ + name: "Repository", + fileName: "generics.go", + isExported: true, + position: Position{ + Line: 13, + Column: 6, + }, + explicitMethods: map[string]functionInfo{ + "Save": saveFunction, + }, + methods: map[string]functionInfo{ + "Save": saveFunction, + }, + } + numberInterface = interfaceInfo{ + name: "Number", + fileName: "generics.go", + isExported: true, + position: Position{ + Line: 30, + Column: 6, + }, + explicitMethods: map[string]functionInfo{ + "ToString": toStringFunction, + }, + methods: map[string]functionInfo{ + "ToString": toStringFunction, + }, + embeddedTypes: []string{"Ordered"}, + embeddedInterfaces: []string{"Ordered"}, + } bakeryShopInterface = interfaceInfo{ markers: markers.Values{ "marker:interface-type-level": { @@ -48,7 +81,8 @@ var ( "muffin": muffinFunction, "Bread": breadFunction, }, - embeddedTypes: []string{"Dessert"}, + embeddedTypes: []string{"Dessert"}, + embeddedInterfaces: []string{"Dessert"}, } dessertInterface = interfaceInfo{ @@ -138,7 +172,8 @@ var ( "Pie": pieFunction, "muffin": muffinFunction, }, - embeddedTypes: []string{"newYearsEveCookie", "Dessert"}, + embeddedTypes: []string{"newYearsEveCookie", "Dessert"}, + embeddedInterfaces: []string{"newYearsEveCookie", "Dessert"}, } ) @@ -191,6 +226,10 @@ func assertInterfaces(t *testing.T, file *File, interfaces map[string]interfaceI t.Errorf("the number of the explicit methods of the interface %s should be %d, but got %d", expectedInterfaceName, len(expectedInterface.explicitMethods), actualInterface.NumExplicitMethods()) } + if actualInterface.NumEmbeddedInterfaces() != len(expectedInterface.embeddedInterfaces) { + t.Errorf("the number of the embedded interfaces of the interface %s should be %d, but got %d", expectedInterfaceName, len(expectedInterface.embeddedInterfaces), actualInterface.NumEmbeddedInterfaces()) + } + if actualInterface.NumEmbeddedTypes() != len(expectedInterface.embeddedTypes) { t.Errorf("the number of the embedded types of the interface %s should be %d, but got %d", expectedInterfaceName, len(expectedInterface.embeddedTypes), actualInterface.NumEmbeddedTypes()) } @@ -200,6 +239,11 @@ func assertInterfaces(t *testing.T, file *File, interfaces map[string]interfaceI assert.Equal(t, expectedInterface.position, actualInterface.Position(), "the position of the interface %s should be %w, but got %w", expectedInterfaceName, expectedInterface.position, actualInterface.Position()) + // TODO fix + actualInterface.IsConstraint() + actualInterface.EmbeddedInterfaces() + actualInterface.EmbeddedTypes() + assertInterfaceEmbeddedTypes(t, fmt.Sprintf("interface %s", actualInterface.Name()), actualInterface.EmbeddedTypes(), expectedInterface.embeddedTypes) assertFunctions(t, fmt.Sprintf("interface %s", actualInterface.Name()), actualInterface.Methods(), expectedInterface.methods) assertFunctions(t, fmt.Sprintf("interface %s", actualInterface.Name()), actualInterface.ExplicitMethods(), expectedInterface.explicitMethods) diff --git a/visitor/struct.go b/visitor/struct.go index b9b137b..c41c804 100644 --- a/visitor/struct.go +++ b/visitor/struct.go @@ -107,12 +107,14 @@ type Struct struct { func newStruct(specType *ast.TypeSpec, structType *ast.StructType, file *File, pkg *packages.Package, visitor *packageVisitor, markers markers.Values) *Struct { s := &Struct{ - markers: markers, - file: file, - fields: make([]*Field, 0), - allFields: make([]*Field, 0), - methods: make([]*Function, 0), - typeParams: &TypeParameters{}, + markers: markers, + file: file, + fields: make([]*Field, 0), + allFields: make([]*Field, 0), + methods: make([]*Function, 0), + typeParams: &TypeParameters{ + []*TypeParameter{}, + }, isProcessed: true, specType: specType, pkg: pkg, diff --git a/visitor/struct_test.go b/visitor/struct_test.go index 9edf880..a0bf334 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -31,6 +31,62 @@ type structInfo struct { // structs var ( + controllerStruct = structInfo{ + markers: markers.Values{}, + fileName: "generics.go", + isExported: true, + position: Position{ + Line: 17, + Column: 6, + }, + methods: map[string]functionInfo{ + "Index": indexMethod, + }, + allMethods: map[string]functionInfo{ + "Index": indexMethod, + }, + fields: map[string]fieldInfo{ + "AnyField1": { + isExported: true, + isEmbeddedField: false, + typeName: "string", + }, + "AnyField2": { + isExported: true, + isEmbeddedField: false, + typeName: "int", + }, + }, + embeddedFields: map[string]fieldInfo{}, + numFields: 2, + totalFields: 2, + numEmbeddedFields: 0, + } + testControllerStruct = structInfo{ + markers: markers.Values{}, + fileName: "generics.go", + isExported: true, + position: Position{ + Line: 26, + Column: 6, + }, + methods: map[string]functionInfo{}, + allMethods: map[string]functionInfo{ + "Index": indexMethod, + }, + fields: map[string]fieldInfo{}, + embeddedFields: map[string]fieldInfo{ + "Controller": { + isExported: true, + isEmbeddedField: true, + typeName: "Controller", + }, + }, + numFields: 1, + totalFields: 2, + numEmbeddedFields: 1, + } + friedCookieStruct = structInfo{ markers: markers.Values{ "marker:struct-type-level": { @@ -54,6 +110,7 @@ var ( "Buy": buyMethod, "Oreo": oreoMethod, "FortuneCookie": fortuneCookieMethod, + "PrintCookie": printCookieMethod, }, fields: map[string]fieldInfo{ "cookie": { @@ -99,10 +156,12 @@ var ( methods: map[string]functionInfo{ "FortuneCookie": fortuneCookieMethod, "Oreo": oreoMethod, + "PrintCookie": printCookieMethod, }, allMethods: map[string]functionInfo{ "FortuneCookie": fortuneCookieMethod, "Oreo": oreoMethod, + "PrintCookie": printCookieMethod, }, fields: map[string]fieldInfo{ "ChocolateChip": { @@ -126,7 +185,7 @@ var ( func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool { if len(structs) != file.Structs().Len() { - t.Errorf("the number of the functions should be %d, but got %d", len(structs), file.Structs().Len()) + t.Errorf("the number of the structs should be %d, but got %d", len(structs), file.Structs().Len()) return false } @@ -157,9 +216,9 @@ func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool t.Errorf("struct with name %s is not exported, but should be exported", actualStruct.Name()) } - if actualStruct.NumMethods() == 0 && !actualStruct.IsEmpty() { + if actualStruct.NumFields() == 0 && !actualStruct.IsEmpty() { t.Errorf("the struct %s should be empty", actualStruct.Name()) - } else if actualStruct.NumMethods() != 0 && actualStruct.IsEmpty() { + } else if actualStruct.NumFields() != 0 && actualStruct.IsEmpty() { t.Errorf("the struct %s should not be empty", actualStruct.Name()) } diff --git a/visitor/type.go b/visitor/type.go index 76d49a9..3c39cbb 100644 --- a/visitor/type.go +++ b/visitor/type.go @@ -246,7 +246,7 @@ func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor, o elem: getTypeFromExpression(typed.Elt, file, visitor, ownerType, typeParameters), } case *ast.FuncType: - return &Function{} + return newFunction(nil, typed, nil, ownerType, file, nil, visitor, nil) case *ast.MapType: return &Map{ key: getTypeFromExpression(typed.Key, file, visitor, ownerType, typeParameters), diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index a86e288..5bcd2c0 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -63,6 +63,9 @@ func TestVisitor_VisitPackage(t *testing.T) { "coffee.go": { constants: coffeeConstants, customTypes: coffeeCustomTypes, + functions: map[string]functionInfo{ + "PrintCookie": printCookieMethod, + }, }, "fresh.go": { constants: freshConstants, @@ -86,8 +89,12 @@ func TestVisitor_VisitPackage(t *testing.T) { }, }, functions: map[string]functionInfo{ - "MakeACake": makeACakeFunction, - "BiscuitCake": biscuitCakeFunction, + "MakeACake": makeACakeFunction, + "BiscuitCake": biscuitCakeFunction, + "Eat": eatMethod, + "Buy": buyMethod, + "FortuneCookie": fortuneCookieMethod, + "Oreo": oreoMethod, }, interfaces: map[string]interfaceInfo{ "BakeryShop": bakeryShopInterface, @@ -105,6 +112,10 @@ func TestVisitor_VisitPackage(t *testing.T) { "error.go": { constants: []constantInfo{}, customTypes: errorCustomTypes, + functions: map[string]functionInfo{ + "Print": printErrorMethod, + "ToErrors": toErrorsMethod, + }, }, "permission.go": { constants: permissionConstants, @@ -117,6 +128,35 @@ func TestVisitor_VisitPackage(t *testing.T) { constants: []constantInfo{}, functions: map[string]functionInfo{ "GenericFunction": genericFunction, + "Index": indexMethod, + }, + interfaces: map[string]interfaceInfo{ + "Repository": repositoryInterface, + "Number": numberInterface, + }, + structs: map[string]structInfo{ + "Controller": controllerStruct, + "TestController": testControllerStruct, + }, + imports: []importInfo{ + { + name: "", + path: "context", + sideEffect: false, + position: Position{Line: 4, Column: 2}, + }, + { + name: "", + path: "golang.org/x/exp/constraints", + sideEffect: false, + position: Position{Line: 5, Column: 2}, + }, + }, + customTypes: genericsCustomTypes, + }, + "method.go": { + functions: map[string]functionInfo{ + "Print": printHttpHandlerMethod, }, }, "string.go": { From 567e55442963b1827a8923023583186bf7d06c04 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Wed, 26 Oct 2022 23:54:19 +0300 Subject: [PATCH 03/16] Complete CustomType and TypeParameter struct implementations --- visitor/function.go | 29 ++++++++++++++----- visitor/function_test.go | 4 +++ visitor/generic_type_test.go | 55 ++++++++++++++++++++++++++++++++++++ visitor/pointer_test.go | 41 +++++++++++++++++++++++++++ visitor/variadic.go | 4 ++- visitor/variadic_test.go | 42 +++++++++++++++++++++++++++ 6 files changed, 167 insertions(+), 8 deletions(-) create mode 100644 visitor/generic_type_test.go create mode 100644 visitor/pointer_test.go create mode 100644 visitor/variadic_test.go diff --git a/visitor/function.go b/visitor/function.go index 7482231..82942f1 100644 --- a/visitor/function.go +++ b/visitor/function.go @@ -350,7 +350,7 @@ func (f *Function) getResults(fieldList []*ast.Field) []*Result { func (f *Function) getParameters(fieldList []*ast.Field) []*Parameter { variables := make([]*Parameter, 0) - for _, field := range fieldList { + for index, field := range fieldList { typ := getTypeFromExpression(field.Type, f.file, f.visitor, nil, f.typeParams) if field.Names == nil { @@ -366,6 +366,13 @@ func (f *Function) getParameters(fieldList []*ast.Field) []*Parameter { }) } + if index == len(fieldList)-1 { + f.variadic = true + } + } + + if len(variables) != 0 { + _, f.variadic = variables[len(variables)-1].Type().(*Variadic) } return variables @@ -427,7 +434,20 @@ func (f *Function) String() string { if f.Parameters().Len() != 0 { for i := 0; i < f.Parameters().Len(); i++ { param := f.Parameters().At(i) - builder.WriteString(param.String()) + builder.WriteString(param.Name()) + + if i == f.params.Len()-1 && f.IsVariadic() { + if param.name != "" { + builder.WriteString(" ") + } + builder.WriteString("...") + builder.WriteString(param.Type().Name()) + } else { + if param.name != "" { + builder.WriteString(" ") + } + builder.WriteString(param.Type().Name()) + } if i != f.Parameters().Len()-1 { builder.WriteString(",") @@ -568,11 +588,6 @@ func (f *Function) loadParams() { if f.funcType != nil && f.funcType.Params != nil { f.params.elements = append(f.params.elements, f.getParameters(f.funcType.Params.List)...) } - - if f.params.Len() != 0 { - _, f.variadic = f.params.At(f.params.Len() - 1).Type().(*Variadic) - } - }) } diff --git a/visitor/function_test.go b/visitor/function_test.go index 7c2b523..3736ece 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -94,6 +94,10 @@ func (f functionInfo) String() string { builder.WriteString(param.name + " ") } + if i == len(f.params)-1 && f.isVariadic { + builder.WriteString("...") + } + if param.isPointer { builder.WriteString("*") } diff --git a/visitor/generic_type_test.go b/visitor/generic_type_test.go new file mode 100644 index 0000000..fd68cb3 --- /dev/null +++ b/visitor/generic_type_test.go @@ -0,0 +1,55 @@ +package visitor + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGenericType_Name(t *testing.T) { + g := &GenericType{ + rawType: &Struct{ + name: "Test", + }, + arguments: TypeSets{basicTypesMap["int"], basicTypesMap["string"]}, + } + + assert.Equal(t, "Test", g.Name()) +} + +func TestGenericType_String(t *testing.T) { + g := &GenericType{ + rawType: &Struct{ + name: "Test", + }, + arguments: TypeSets{basicTypesMap["int"], basicTypesMap["string"]}, + } + + assert.Equal(t, "Test[int,string]", g.String()) +} + +func TestGenericType_RawType(t *testing.T) { + rawType := &Struct{ + name: "", + typeParams: &TypeParameters{ + []*TypeParameter{}, + }, + } + g := &GenericType{ + rawType: rawType, + } + + assert.Equal(t, rawType, g.RawType()) + assert.Equal(t, "struct{}", g.RawType().Name()) + assert.Equal(t, "struct{}", g.RawType().String()) +} + +func TestGenericType_Underlying(t *testing.T) { + g := &GenericType{ + rawType: &Struct{ + name: "Test", + }, + arguments: TypeSets{basicTypesMap["int"], basicTypesMap["string"]}, + } + + assert.Equal(t, g, g.Underlying()) +} diff --git a/visitor/pointer_test.go b/visitor/pointer_test.go new file mode 100644 index 0000000..bd845ac --- /dev/null +++ b/visitor/pointer_test.go @@ -0,0 +1,41 @@ +package visitor + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestPointer_Name(t *testing.T) { + p := &Pointer{ + base: basicTypesMap["bool"], + } + + assert.Equal(t, "*bool", p.Name()) +} + +func TestPointer_String(t *testing.T) { + p := &Pointer{ + base: basicTypesMap["bool"], + } + + assert.Equal(t, "*bool", p.String()) +} + +func TestPointer_Elem(t *testing.T) { + elem := basicTypesMap["byte"] + p := &Pointer{ + base: elem, + } + + assert.Equal(t, elem, p.Elem()) + assert.Equal(t, "byte", p.Elem().Name()) + assert.Equal(t, "byte", p.Elem().String()) +} + +func TestPointer_Underlying(t *testing.T) { + p := &Pointer{ + base: basicTypesMap["bool"], + } + + assert.Equal(t, p, p.Underlying()) +} diff --git a/visitor/variadic.go b/visitor/variadic.go index 9aee7ef..f26c511 100644 --- a/visitor/variadic.go +++ b/visitor/variadic.go @@ -1,5 +1,7 @@ package visitor +import "fmt" + type Variadic struct { elem Type } @@ -17,5 +19,5 @@ func (v *Variadic) Underlying() Type { } func (v *Variadic) String() string { - return "" + return fmt.Sprintf("...%s", v.elem.Name()) } diff --git a/visitor/variadic_test.go b/visitor/variadic_test.go new file mode 100644 index 0000000..f8a4dcd --- /dev/null +++ b/visitor/variadic_test.go @@ -0,0 +1,42 @@ +package visitor + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestVariadic_Name(t *testing.T) { + v := &Variadic{ + elem: basicTypesMap["bool"], + } + + assert.Equal(t, "bool", v.Name()) +} + +func TestVariadic_String(t *testing.T) { + v := &Variadic{ + elem: basicTypesMap["bool"], + } + + assert.Equal(t, "...bool", v.String()) +} + +func TestVariadic_Elem(t *testing.T) { + elem := basicTypesMap["byte"] + a := &Variadic{ + elem: elem, + } + + assert.Equal(t, elem, a.Elem()) + assert.Equal(t, "byte", a.Elem().Name()) + assert.Equal(t, "byte", a.Elem().String()) +} + +func TestVariadic_Underlying(t *testing.T) { + elem := basicTypesMap["byte"] + v := &Variadic{ + elem: elem, + } + + assert.Equal(t, v, v.Underlying()) +} From b0e4f363f9b0cb20a46ae275d804ae9365925d63 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Sun, 30 Oct 2022 01:58:03 +0300 Subject: [PATCH 04/16] Complete CustomType and TypeParameter struct implementations --- test/any/generics.go | 6 +- test/any/math.go | 6 +- test/any/method.go | 2 +- visitor/constant_test.go | 193 +++++++++++++++++++------------- visitor/custom_type.go | 4 +- visitor/custom_type_test.go | 14 ++- visitor/function_test.go | 20 ++++ visitor/interface.go | 18 ++- visitor/interface_test.go | 27 +++++ visitor/struct.go | 18 ++- visitor/struct_test.go | 29 +++-- visitor/type_constraint_test.go | 70 ++++++++++++ visitor/type_parameter_test.go | 108 ++++++++++++++++++ visitor/visitor_test.go | 5 +- 14 files changed, 415 insertions(+), 105 deletions(-) create mode 100644 visitor/type_constraint_test.go create mode 100644 visitor/type_parameter_test.go diff --git a/test/any/generics.go b/test/any/generics.go index 7924e7d..b0c08f4 100644 --- a/test/any/generics.go +++ b/test/any/generics.go @@ -32,4 +32,8 @@ type Number interface { ToString() } -type HttpHandler[C context.Context] func(ctx C) +type HttpHandler[C context.Context, K string | int] func(ctx C) K + +type EventPublisher[E any] interface { + Publish(e E) +} diff --git a/test/any/math.go b/test/any/math.go index 7dea8be..3481087 100644 --- a/test/any/math.go +++ b/test/any/math.go @@ -1,9 +1,9 @@ package any const IntegerMathOperation = -(3*4 + 2 - 6) / 2 -const FloatMathOperation = (3*4.0 + 5 - 6.2) / 2 +const floatMathOperation = (3*4.0 + 5 - 6.2) / 2 const ModOperation = 5 % 2 -const EqualOperation = 2 == 2 +const equalOperation = 2 == 2 const NotEqualOperation = 2 != 2 const GreaterThan = 3 > 2 const GreaterThanOrEqual = 3 >= 2 @@ -12,4 +12,4 @@ const LessThanOrEqual = 3 <= 5 const XorOperation = 4 ^ 2 const AndNotOperation = 4 &^ 2 const AndOperation = 4 & 2 -const OrOperation = 4 | 2 +const orOperation = 4 | 2 diff --git a/test/any/method.go b/test/any/method.go index 12fade1..4037682 100644 --- a/test/any/method.go +++ b/test/any/method.go @@ -1,5 +1,5 @@ package any -func (HttpHandler[C]) Print(ctx C) { +func (HttpHandler[C, K]) Print(ctx C) { } diff --git a/visitor/constant_test.go b/visitor/constant_test.go index 3790577..1ccc404 100644 --- a/visitor/constant_test.go +++ b/visitor/constant_test.go @@ -6,10 +6,11 @@ import ( ) type constantInfo struct { - name string - position Position - value any - typeName string + name string + position Position + value any + typeName string + isExported bool } var ( @@ -20,8 +21,9 @@ var ( Line: 9, Column: 2, }, - value: -1, - typeName: "Coffee", + value: -1, + typeName: "Coffee", + isExported: true, }, { name: "Americano", @@ -29,8 +31,9 @@ var ( Line: 10, Column: 2, }, - value: -2, - typeName: "Coffee", + value: -2, + typeName: "Coffee", + isExported: true, }, { name: "Latte", @@ -38,8 +41,9 @@ var ( Line: 11, Column: 2, }, - value: -3, - typeName: "Coffee", + value: -3, + typeName: "Coffee", + isExported: true, }, { name: "TurkishCoffee", @@ -47,8 +51,9 @@ var ( Line: 12, Column: 2, }, - value: -4, - typeName: "Coffee", + value: -4, + typeName: "Coffee", + isExported: true, }, } freshConstants = []constantInfo{ @@ -58,8 +63,9 @@ var ( Line: 9, Column: 2, }, - value: 0, - typeName: "Lemonade", + value: 0, + typeName: "Lemonade", + isExported: true, }, { name: "BlueberryLemonade", @@ -67,8 +73,9 @@ var ( Line: 10, Column: 2, }, - value: 1, - typeName: "Lemonade", + value: 1, + typeName: "Lemonade", + isExported: true, }, { name: "WatermelonLemonade", @@ -76,8 +83,9 @@ var ( Line: 11, Column: 2, }, - value: 2, - typeName: "Lemonade", + value: 2, + typeName: "Lemonade", + isExported: true, }, { name: "MangoLemonade", @@ -85,8 +93,9 @@ var ( Line: 12, Column: 2, }, - value: 3, - typeName: "Lemonade", + value: 3, + typeName: "Lemonade", + isExported: true, }, { name: "StrawberryLemonade", @@ -94,8 +103,9 @@ var ( Line: 13, Column: 2, }, - value: 4, - typeName: "Lemonade", + value: 4, + typeName: "Lemonade", + isExported: true, }, } stringConstants = []constantInfo{ @@ -105,8 +115,9 @@ var ( Line: 5, Column: 7, }, - value: "AnyString", - typeName: "string", + value: "AnyString", + typeName: "string", + isExported: true, }, { name: "methods", @@ -114,8 +125,9 @@ var ( Line: 6, Column: 7, }, - value: "GETPUT", - typeName: "string", + value: "GETPUT", + typeName: "string", + isExported: false, }, } permissionConstants = []constantInfo{ @@ -125,8 +137,9 @@ var ( Line: 9, Column: 2, }, - value: 1, - typeName: "Permission", + value: 1, + typeName: "Permission", + isExported: true, }, { name: "Write", @@ -134,8 +147,9 @@ var ( Line: 10, Column: 2, }, - value: 2, - typeName: "Permission", + value: 2, + typeName: "Permission", + isExported: true, }, { name: "ReadWrite", @@ -143,8 +157,9 @@ var ( Line: 11, Column: 2, }, - value: 3, - typeName: "Permission", + value: 3, + typeName: "Permission", + isExported: true, }, { name: "RequestGet", @@ -152,8 +167,9 @@ var ( Line: 17, Column: 2, }, - value: "GET", - typeName: "RequestMethod", + value: "GET", + typeName: "RequestMethod", + isExported: true, }, { name: "RequestPost", @@ -161,8 +177,9 @@ var ( Line: 18, Column: 2, }, - value: "POST", - typeName: "RequestMethod", + value: "POST", + typeName: "RequestMethod", + isExported: true, }, { name: "RequestPatch", @@ -170,8 +187,9 @@ var ( Line: 19, Column: 2, }, - value: "PATCH", - typeName: "RequestMethod", + value: "PATCH", + typeName: "RequestMethod", + isExported: true, }, { name: "RequestDelete", @@ -179,8 +197,9 @@ var ( Line: 20, Column: 2, }, - value: "DELETE", - typeName: "RequestMethod", + value: "DELETE", + typeName: "RequestMethod", + isExported: true, }, { name: "SendDir", @@ -188,8 +207,9 @@ var ( Line: 26, Column: 2, }, - value: 2, - typeName: "Chan", + value: 2, + typeName: "Chan", + isExported: true, }, { name: "ReceiveDir", @@ -197,8 +217,9 @@ var ( Line: 27, Column: 2, }, - value: 1, - typeName: "Chan", + value: 1, + typeName: "Chan", + isExported: true, }, { name: "BothDir", @@ -206,8 +227,9 @@ var ( Line: 28, Column: 2, }, - value: 3, - typeName: "Chan", + value: 3, + typeName: "Chan", + isExported: true, }, } @@ -218,17 +240,19 @@ var ( Line: 3, Column: 7, }, - value: -4, - typeName: "untyped int", + value: -4, + typeName: "untyped int", + isExported: true, }, { - name: "FloatMathOperation", + name: "floatMathOperation", position: Position{ Line: 4, Column: 7, }, - value: 5.4, - typeName: "untyped int", + value: 5.4, + typeName: "untyped int", + isExported: false, }, { name: "ModOperation", @@ -236,17 +260,19 @@ var ( Line: 5, Column: 7, }, - value: 1, - typeName: "untyped int", + value: 1, + typeName: "untyped int", + isExported: true, }, { - name: "EqualOperation", + name: "equalOperation", position: Position{ Line: 6, Column: 7, }, - value: true, - typeName: "bool", + value: true, + typeName: "bool", + isExported: false, }, { name: "NotEqualOperation", @@ -254,8 +280,9 @@ var ( Line: 7, Column: 7, }, - value: false, - typeName: "bool", + value: false, + typeName: "bool", + isExported: true, }, { name: "GreaterThan", @@ -263,8 +290,9 @@ var ( Line: 8, Column: 7, }, - value: true, - typeName: "bool", + value: true, + typeName: "bool", + isExported: true, }, { name: "GreaterThanOrEqual", @@ -272,8 +300,9 @@ var ( Line: 9, Column: 7, }, - value: true, - typeName: "bool", + value: true, + typeName: "bool", + isExported: true, }, { name: "LessThan", @@ -281,8 +310,9 @@ var ( Line: 10, Column: 7, }, - value: true, - typeName: "bool", + value: true, + typeName: "bool", + isExported: true, }, { name: "LessThanOrEqual", @@ -290,8 +320,9 @@ var ( Line: 11, Column: 7, }, - value: true, - typeName: "bool", + value: true, + typeName: "bool", + isExported: true, }, { name: "XorOperation", @@ -299,8 +330,9 @@ var ( Line: 12, Column: 7, }, - value: 6, - typeName: "untyped int", + value: 6, + typeName: "untyped int", + isExported: true, }, { name: "AndNotOperation", @@ -308,8 +340,9 @@ var ( Line: 13, Column: 7, }, - value: 4, - typeName: "untyped int", + value: 4, + typeName: "untyped int", + isExported: true, }, { name: "AndOperation", @@ -317,17 +350,19 @@ var ( Line: 14, Column: 7, }, - value: 0, - typeName: "untyped int", + value: 0, + typeName: "untyped int", + isExported: true, }, { - name: "OrOperation", + name: "orOperation", position: Position{ Line: 15, Column: 7, }, - value: 6, - typeName: "untyped int", + value: 6, + typeName: "untyped int", + isExported: false, }, } ) @@ -344,7 +379,7 @@ func assertConstants(t *testing.T, file *File, constants []constantInfo) bool { actualConstant, exists := file.Constants().FindByName(expectedConstant.name) if !exists || actualConstant == nil { - t.Errorf("constant with name %s in file %s is not found", file.name, expectedConstant.name) + t.Errorf("constant with name %s in file %s is not found", expectedConstant.name, file.name) continue } @@ -362,6 +397,12 @@ func assertConstants(t *testing.T, file *File, constants []constantInfo) bool { t.Errorf("type name of constant %s in file %s shoud be %s, but got %s", actualConstant.Name(), file.name, expectedConstant.typeName, actualConstant.Type().Name()) } + if actualConstant.IsExported() && !expectedConstant.isExported { + t.Errorf("constant with name %s in file %s is exported, but should be unexported field", expectedConstant.name, file.name) + } else if !actualConstant.IsExported() && expectedConstant.isExported { + t.Errorf("constant with name %s in file %s is not exported, but should be exported field", expectedConstant.name, file.name) + } + assert.Equal(t, expectedConstant.position, actualConstant.Position(), "the position of constant %s in file %s should be %w, but got %w", expectedConstant.name, actualConstant.File().Name(), expectedConstant.position, actualConstant.Position()) } diff --git a/visitor/custom_type.go b/visitor/custom_type.go index 656f664..23658de 100644 --- a/visitor/custom_type.go +++ b/visitor/custom_type.go @@ -106,6 +106,8 @@ func (c *CustomType) String() string { if c.file != nil && c.file.pkg.Name != "builtin" { builder.WriteString(fmt.Sprintf("%s.%s", c.file.Package().Name, c.name)) + } else if c.name != "" { + builder.WriteString(c.name) } if c.TypeParameters().Len() != 0 { @@ -119,9 +121,9 @@ func (c *CustomType) String() string { builder.WriteString(",") } - builder.WriteString("]") } + builder.WriteString("]") } return builder.String() diff --git a/visitor/custom_type_test.go b/visitor/custom_type_test.go index befc960..532dc8a 100644 --- a/visitor/custom_type_test.go +++ b/visitor/custom_type_test.go @@ -11,6 +11,7 @@ type customTypeInfo struct { underlyingTypeName string isExported bool methods map[string]functionInfo + stringValue string } var ( @@ -23,6 +24,7 @@ var ( "Print": printErrorMethod, "ToErrors": toErrorsMethod, }, + stringValue: "any.errorList", }, } permissionCustomTypes = map[string]customTypeInfo{ @@ -30,16 +32,19 @@ var ( name: "Permission", underlyingTypeName: "int", isExported: true, + stringValue: "any.Permission", }, "RequestMethod": { name: "RequestMethod", underlyingTypeName: "string", isExported: true, + stringValue: "any.RequestMethod", }, "Chan": { name: "Chan", underlyingTypeName: "int", isExported: true, + stringValue: "any.Chan", }, } coffeeCustomTypes = map[string]customTypeInfo{ @@ -47,6 +52,7 @@ var ( name: "Coffee", underlyingTypeName: "int", isExported: true, + stringValue: "menu.Coffee", }, } freshCustomTypes = map[string]customTypeInfo{ @@ -54,16 +60,18 @@ var ( name: "Lemonade", underlyingTypeName: "uint", isExported: true, + stringValue: "menu.Lemonade", }, } genericsCustomTypes = map[string]customTypeInfo{ "HttpHandler": { name: "HttpHandler", - underlyingTypeName: "func (ctx C)", + underlyingTypeName: "func (ctx C) K", isExported: true, methods: map[string]functionInfo{ "Print": printHttpHandlerMethod, }, + stringValue: "any.HttpHandler[C context.Context,K string|int]", }, } ) @@ -110,6 +118,10 @@ func assertCustomTypes(t *testing.T, file *File, customTypes map[string]customTy t.Errorf("custom type with name %s is not exported, but should be exported field", expectedCustomTypeName) } + if expectedCustomType.stringValue != actualCustomType.String() { + t.Errorf("Output returning from String() method for custom type with name %s does not equal to %s, but got %s", expectedCustomTypeName, expectedCustomType.stringValue, actualCustomType.String()) + } + assertFunctions(t, fmt.Sprintf("custom type %s", actualCustomType.Name()), actualCustomType.Methods(), expectedCustomType.methods) index++ } diff --git a/visitor/function_test.go b/visitor/function_test.go index 3736ece..bb802e3 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -213,6 +213,22 @@ var ( }, }, } + publishMethod = functionInfo{ + markers: markers.Values{}, + name: "Publish", + fileName: "generics.go", + position: Position{ + Line: 38, + Column: 9, + }, + isVariadic: false, + params: []variableInfo{ + { + name: "e", + typeName: "E", + }, + }, + } printCookieMethod = functionInfo{ markers: markers.Values{}, name: "PrintCookie", @@ -265,6 +281,10 @@ var ( name: "C", typeName: "", }, + { + name: "K", + typeName: "", + }, }, } printErrorMethod = functionInfo{ diff --git a/visitor/interface.go b/visitor/interface.go index eea6828..4ebffd2 100644 --- a/visitor/interface.go +++ b/visitor/interface.go @@ -212,14 +212,22 @@ func (i *Interface) String() string { var builder strings.Builder if i.file != nil && i.file.pkg.Name != "builtin" { builder.WriteString(fmt.Sprintf("%s.%s", i.file.Package().Name, i.name)) + } else if i.name != "" { + builder.WriteString(i.name) } - for index := 0; index < i.TypeParameters().Len(); index++ { - typeParam := i.TypeParameters().At(index) - builder.WriteString(typeParam.String()) - if index != i.TypeParameters().Len()-1 { - builder.WriteString(",") + if i.TypeParameters().Len() != 0 { + builder.WriteString("[") + + for index := 0; index < i.TypeParameters().Len(); index++ { + typeParam := i.TypeParameters().At(index) + builder.WriteString(typeParam.String()) + + if index != i.TypeParameters().Len()-1 { + builder.WriteString(",") + } } + builder.WriteString("]") } diff --git a/visitor/interface_test.go b/visitor/interface_test.go index 68d3e32..b2a4498 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -17,6 +17,7 @@ type interfaceInfo struct { embeddedTypes []string embeddedInterfaces []string isExported bool + stringValue string } // interfaces @@ -35,6 +36,7 @@ var ( methods: map[string]functionInfo{ "Save": saveFunction, }, + stringValue: "any.Repository[T any,ID any]", } numberInterface = interfaceInfo{ name: "Number", @@ -52,6 +54,23 @@ var ( }, embeddedTypes: []string{"Ordered"}, embeddedInterfaces: []string{"Ordered"}, + stringValue: "any.Number", + } + eventPublisherInterface = interfaceInfo{ + name: "EventPublisher", + fileName: "generics.go", + isExported: true, + position: Position{ + Line: 37, + Column: 6, + }, + explicitMethods: map[string]functionInfo{ + "Publish": publishMethod, + }, + methods: map[string]functionInfo{ + "Publish": publishMethod, + }, + stringValue: "any.EventPublisher[E any]", } bakeryShopInterface = interfaceInfo{ markers: markers.Values{ @@ -83,6 +102,7 @@ var ( }, embeddedTypes: []string{"Dessert"}, embeddedInterfaces: []string{"Dessert"}, + stringValue: "menu.BakeryShop", } dessertInterface = interfaceInfo{ @@ -118,6 +138,7 @@ var ( "Pie": pieFunction, "muffin": muffinFunction, }, + stringValue: "menu.Dessert", } newYearsEveCookieInterface = interfaceInfo{ @@ -141,6 +162,7 @@ var ( explicitMethods: map[string]functionInfo{ "Funfetti": funfettiFunction, }, + stringValue: "menu.newYearsEveCookie", } sweetShopInterface = interfaceInfo{ @@ -174,6 +196,7 @@ var ( }, embeddedTypes: []string{"newYearsEveCookie", "Dessert"}, embeddedInterfaces: []string{"newYearsEveCookie", "Dessert"}, + stringValue: "menu.SweetShop", } ) @@ -244,6 +267,10 @@ func assertInterfaces(t *testing.T, file *File, interfaces map[string]interfaceI actualInterface.EmbeddedInterfaces() actualInterface.EmbeddedTypes() + if expectedInterface.stringValue != actualInterface.String() { + t.Errorf("Output returning from String() method for interface type with name %s does not equal to %s, but got %s", expectedInterfaceName, expectedInterface.stringValue, actualInterface.String()) + } + assertInterfaceEmbeddedTypes(t, fmt.Sprintf("interface %s", actualInterface.Name()), actualInterface.EmbeddedTypes(), expectedInterface.embeddedTypes) assertFunctions(t, fmt.Sprintf("interface %s", actualInterface.Name()), actualInterface.Methods(), expectedInterface.methods) assertFunctions(t, fmt.Sprintf("interface %s", actualInterface.Name()), actualInterface.ExplicitMethods(), expectedInterface.explicitMethods) diff --git a/visitor/struct.go b/visitor/struct.go index c41c804..bb7f299 100644 --- a/visitor/struct.go +++ b/visitor/struct.go @@ -170,14 +170,22 @@ func (s *Struct) String() string { var builder strings.Builder if s.file != nil && s.file.pkg.Name != "builtin" { builder.WriteString(fmt.Sprintf("%s.%s", s.file.Package().Name, s.name)) + } else if s.name != "" { + builder.WriteString(s.name) } - for index := 0; index < s.TypeParameters().Len(); index++ { - typeParam := s.TypeParameters().At(index) - builder.WriteString(typeParam.String()) - if index != s.TypeParameters().Len()-1 { - builder.WriteString(",") + if s.TypeParameters().Len() != 0 { + builder.WriteString("[") + + for index := 0; index < s.TypeParameters().Len(); index++ { + typeParam := s.TypeParameters().At(index) + builder.WriteString(typeParam.String()) + + if index != s.TypeParameters().Len()-1 { + builder.WriteString(",") + } } + builder.WriteString("]") } diff --git a/visitor/struct_test.go b/visitor/struct_test.go index a0bf334..b531bc3 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -27,14 +27,16 @@ type structInfo struct { totalFields int numEmbeddedFields int implements map[string]struct{} + stringValue string } // structs var ( controllerStruct = structInfo{ - markers: markers.Values{}, - fileName: "generics.go", - isExported: true, + markers: markers.Values{}, + stringValue: "any.Controller[C context.Context,T any]", + fileName: "generics.go", + isExported: true, position: Position{ Line: 17, Column: 6, @@ -63,9 +65,10 @@ var ( numEmbeddedFields: 0, } testControllerStruct = structInfo{ - markers: markers.Values{}, - fileName: "generics.go", - isExported: true, + markers: markers.Values{}, + stringValue: "any.TestController", + fileName: "generics.go", + isExported: true, position: Position{ Line: 26, Column: 6, @@ -95,8 +98,9 @@ var ( }, }, }, - fileName: "dessert.go", - isExported: true, + stringValue: "menu.FriedCookie", + fileName: "dessert.go", + isExported: true, position: Position{ Line: 30, Column: 6, @@ -147,8 +151,9 @@ var ( }, }, }, - fileName: "dessert.go", - isExported: false, + stringValue: "menu.cookie", + fileName: "dessert.go", + isExported: false, position: Position{ Line: 56, Column: 6, @@ -242,6 +247,10 @@ func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool t.Errorf("the number of the embededed fields of the struct %s should be %d, but got %d", expectedStructName, expectedStruct.numEmbeddedFields, actualStruct.NumFields()) } + if expectedStruct.stringValue != actualStruct.String() { + t.Errorf("Output returning from String() method for struct type with name %s does not equal to %s, but got %s", expectedStructName, expectedStruct.stringValue, actualStruct.String()) + } + assert.Equal(t, actualStruct, actualStruct.Underlying()) assert.Equal(t, expectedStruct.position, actualStruct.Position(), "the position of the struct %s should be %w, but got %w", diff --git a/visitor/type_constraint_test.go b/visitor/type_constraint_test.go new file mode 100644 index 0000000..a32f70d --- /dev/null +++ b/visitor/type_constraint_test.go @@ -0,0 +1,70 @@ +package visitor + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestTypeConstraint_Name(t *testing.T) { + c := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + + assert.Equal(t, "bool", c.Name()) +} + +func TestTypeConstraint_String(t *testing.T) { + c := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + + assert.Equal(t, "~bool", c.String()) +} + +func TestTypeConstraint_Type(t *testing.T) { + typ := basicTypesMap["byte"] + c := &TypeConstraint{ + typ: typ, + tildeOperator: true, + } + + assert.Equal(t, typ, c.Type()) +} + +func TestTypeConstraint_Underlying(t *testing.T) { + c := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + + assert.Equal(t, c, c.Underlying()) +} + +func TestTypeConstraints_At(t *testing.T) { + constraint := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + + typeConstraints := &TypeConstraints{ + elements: []*TypeConstraint{constraint}, + } + + assert.NotNil(t, typeConstraints.At(0)) + assert.Nil(t, typeConstraints.At(1)) +} + +func TestTypeConstraints_Len(t *testing.T) { + constraint := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + + typeConstraints := &TypeConstraints{ + elements: []*TypeConstraint{constraint}, + } + + assert.Equal(t, len(typeConstraints.elements), typeConstraints.Len()) +} diff --git a/visitor/type_parameter_test.go b/visitor/type_parameter_test.go new file mode 100644 index 0000000..67212d3 --- /dev/null +++ b/visitor/type_parameter_test.go @@ -0,0 +1,108 @@ +package visitor + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestTypeParameter_Name(t *testing.T) { + constraint := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + typeConstraints := &TypeConstraints{ + []*TypeConstraint{constraint}, + } + typeParameter := &TypeParameter{ + name: "anyTypeParameter", + constraints: typeConstraints, + } + + assert.Equal(t, typeParameter.name, typeParameter.Name()) +} + +func TestTypeParameter_String(t *testing.T) { + constraint := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + typeConstraints := &TypeConstraints{ + []*TypeConstraint{constraint}, + } + typeParameter := &TypeParameter{ + name: "anyTypeParameter", + constraints: typeConstraints, + } + + assert.Equal(t, fmt.Sprintf("%s ~%s", typeParameter.name, "bool"), typeParameter.String()) +} + +func TestTypeParameter_TypeConstraints(t *testing.T) { + constraint := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + typeConstraints := &TypeConstraints{ + []*TypeConstraint{constraint}, + } + typeParameter := &TypeParameter{ + name: "anyTypeParameter", + constraints: typeConstraints, + } + assert.Equal(t, typeConstraints, typeParameter.TypeConstraints()) +} + +func TestTypeParameter_Underlying(t *testing.T) { + constraint := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + typeConstraints := &TypeConstraints{ + []*TypeConstraint{constraint}, + } + typeParameter := &TypeParameter{ + name: "anyTypeParameter", + constraints: typeConstraints, + } + assert.Equal(t, typeParameter, typeParameter.Underlying()) +} + +func TestTypeParameters_At(t *testing.T) { + constraint := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + typeConstraints := &TypeConstraints{ + []*TypeConstraint{constraint}, + } + typeParameter := &TypeParameter{ + name: "anyTypeParameter", + constraints: typeConstraints, + } + typeParameters := &TypeParameters{ + elements: []*TypeParameter{typeParameter}, + } + + assert.NotNil(t, typeParameters.At(0)) + assert.Nil(t, typeParameters.At(1)) +} + +func TestTypeParameters_Len(t *testing.T) { + constraint := &TypeConstraint{ + typ: basicTypesMap["bool"], + tildeOperator: true, + } + typeConstraints := &TypeConstraints{ + []*TypeConstraint{constraint}, + } + typeParameter := &TypeParameter{ + name: "anyTypeParameter", + constraints: typeConstraints, + } + typeParameters := &TypeParameters{ + elements: []*TypeParameter{typeParameter}, + } + + assert.Equal(t, len(typeParameters.elements), typeParameters.Len()) +} diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index 5bcd2c0..6954b2f 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -131,8 +131,9 @@ func TestVisitor_VisitPackage(t *testing.T) { "Index": indexMethod, }, interfaces: map[string]interfaceInfo{ - "Repository": repositoryInterface, - "Number": numberInterface, + "Repository": repositoryInterface, + "Number": numberInterface, + "EventPublisher": eventPublisherInterface, }, structs: map[string]structInfo{ "Controller": controllerStruct, From 07227571d75976b9902a5f2a9281fdca188f27de Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Fri, 7 Apr 2023 20:21:49 +0300 Subject: [PATCH 05/16] Rename package imports --- README.md | 2 +- cmd/marker/common.go | 4 ++-- cmd/marker/constants.go | 2 +- cmd/marker/download.go | 4 ++-- cmd/marker/generate.go | 16 ++++++++++++---- cmd/marker/init.go | 2 +- cmd/marker/list.go | 4 ++-- cmd/marker/main.go | 2 +- cmd/marker/processor.go | 12 ++++++------ cmd/marker/validate.go | 2 +- collector.go | 2 +- collector_test.go | 8 ++++---- go.mod | 3 ++- go.sum | 2 ++ packages/loader.go | 2 +- packages/loader_test.go | 16 ++++++++-------- packages/package.go | 2 +- packages/package_test.go | 2 +- processor/common.go | 2 +- processor/context.go | 4 ++-- processor/generate.go | 4 ++-- processor/validate.go | 4 ++-- visitor/collector.go | 2 +- visitor/constant.go | 2 +- visitor/custom_type.go | 4 ++-- visitor/file.go | 4 ++-- visitor/function.go | 4 ++-- visitor/function_test.go | 2 +- visitor/interface.go | 4 ++-- visitor/interface_test.go | 2 +- visitor/position.go | 2 +- visitor/struct.go | 4 ++-- visitor/struct_test.go | 2 +- visitor/visitor.go | 4 ++-- visitor/visitor_test.go | 4 ++-- 35 files changed, 76 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 542de66..9ac5044 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ To Install Marker CLI quickly, follow the installation instructions. 1. You first need Go installed (version 1.18+ is required), then you can use the below Go command to install Marker CLI. - `$ go get -u github.com/procyon-projects/marker/...` + `$ go get -u github.com/procyon-projects/markers/...` 2. Verify that you've installed Marker CLI by typing the following command. `$ marker version` diff --git a/cmd/marker/common.go b/cmd/marker/common.go index 1fa824f..ae10917 100644 --- a/cmd/marker/common.go +++ b/cmd/marker/common.go @@ -3,8 +3,8 @@ package main import ( "errors" "fmt" - "github.com/procyon-projects/marker/internal/cmd" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers/internal/cmd" + "github.com/procyon-projects/markers/packages" "io" "os" "os/exec" diff --git a/cmd/marker/constants.go b/cmd/marker/constants.go index 03dddaa..364d052 100644 --- a/cmd/marker/constants.go +++ b/cmd/marker/constants.go @@ -2,6 +2,6 @@ package main const ( AppName = "marker" - Package = "github.com/procyon-projects/marker" + Package = "github.com/procyon-projects/markers" Version = "v0.2.8-dev" ) diff --git a/cmd/marker/download.go b/cmd/marker/download.go index 593c1eb..412369f 100644 --- a/cmd/marker/download.go +++ b/cmd/marker/download.go @@ -3,8 +3,8 @@ package main import ( "errors" "fmt" - "github.com/procyon-projects/marker/packages" - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/packages" + "github.com/procyon-projects/markers/processor" "github.com/spf13/cobra" "os" ) diff --git a/cmd/marker/generate.go b/cmd/marker/generate.go index 767205f..84ee8e5 100644 --- a/cmd/marker/generate.go +++ b/cmd/marker/generate.go @@ -1,20 +1,28 @@ package main import ( - "github.com/procyon-projects/marker/packages" - "github.com/procyon-projects/marker/processor" - "github.com/procyon-projects/marker/visitor" + "github.com/procyon-projects/markers/packages" + "github.com/procyon-projects/markers/processor" + "github.com/procyon-projects/markers/visitor" "os" "os/exec" ) func Generate(ctx *processor.Context) { - pkg, _ := ctx.LoadResult().Lookup("github.com/procyon-projects/marker/test/package1") + pkg, _ := ctx.LoadResult().Lookup("github.com/procyon-projects/markers/test/package1") err := visitor.EachFile(ctx.Collector(), []*packages.Package{pkg}, func(file *visitor.File, err error) error { if file.NumImportMarkers() == 0 { return nil } + /* + f := file.Functions().At(0) + markers, ok := f.Markers().FindByName("shelf:entity") + + if ok { + + }*/ + return err }) diff --git a/cmd/marker/init.go b/cmd/marker/init.go index cc47e73..4624399 100644 --- a/cmd/marker/init.go +++ b/cmd/marker/init.go @@ -3,7 +3,7 @@ package main import ( "encoding/json" "errors" - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/processor" "github.com/spf13/cobra" "log" "os" diff --git a/cmd/marker/list.go b/cmd/marker/list.go index 0f495af..734cfe7 100644 --- a/cmd/marker/list.go +++ b/cmd/marker/list.go @@ -3,8 +3,8 @@ package main import ( "encoding/json" "errors" - "github.com/procyon-projects/marker/packages" - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/packages" + "github.com/procyon-projects/markers/processor" "github.com/spf13/cobra" "log" ) diff --git a/cmd/marker/main.go b/cmd/marker/main.go index 66d1ed8..feb80fe 100644 --- a/cmd/marker/main.go +++ b/cmd/marker/main.go @@ -1,7 +1,7 @@ package main import ( - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/processor" "log" ) diff --git a/cmd/marker/processor.go b/cmd/marker/processor.go index c183d24..75aab1c 100644 --- a/cmd/marker/processor.go +++ b/cmd/marker/processor.go @@ -2,8 +2,8 @@ package main import ( "errors" - "github.com/procyon-projects/marker/packages" - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/packages" + "github.com/procyon-projects/markers/processor" "github.com/spf13/cobra" "gopkg.in/yaml.v3" "log" @@ -26,8 +26,8 @@ var mainFileContent = `// Code generated by marker; DO NOT EDIT. package main import ( - "github.com/procyon-projects/marker/cmd" - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/cmd" + "github.com/procyon-projects/markers/processor" "log" ) @@ -46,7 +46,7 @@ func main() { var generateFileContent = `package main import ( - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/processor" ) func Generate(ctx *processor.Context) { @@ -57,7 +57,7 @@ func Generate(ctx *processor.Context) { var validateFileContent = `package main import ( - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/processor" ) func Validate(ctx *processor.Context) { diff --git a/cmd/marker/validate.go b/cmd/marker/validate.go index 444db4e..1c905f1 100644 --- a/cmd/marker/validate.go +++ b/cmd/marker/validate.go @@ -1,7 +1,7 @@ package main import ( - "github.com/procyon-projects/marker/processor" + "github.com/procyon-projects/markers/processor" "os" "os/exec" ) diff --git a/collector.go b/collector.go index b518f2f..4222776 100644 --- a/collector.go +++ b/collector.go @@ -3,7 +3,7 @@ package markers import ( "errors" "fmt" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers/packages" "go/ast" "go/token" "strings" diff --git a/collector_test.go b/collector_test.go index 7c1677a..193de62 100644 --- a/collector_test.go +++ b/collector_test.go @@ -1,14 +1,14 @@ package markers import ( - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers/packages" "github.com/stretchr/testify/assert" "testing" ) func TestCollector_Collect(t *testing.T) { - result, _ := packages.LoadPackages("github.com/procyon-projects/marker/test/...") - pkg, _ := result.Lookup("github.com/procyon-projects/marker/test/menu") + result, _ := packages.LoadPackages("github.com/procyon-projects/markers/test/...") + pkg, _ := result.Lookup("github.com/procyon-projects/markers/test/menu") registry := NewRegistry() collector := NewCollector(registry) @@ -17,7 +17,7 @@ func TestCollector_Collect(t *testing.T) { assert.NotNil(t, nodes) assert.NoError(t, err) - pkg, _ = result.Lookup("github.com/procyon-projects/marker/test/any") + pkg, _ = result.Lookup("github.com/procyon-projects/markers/test/any") nodes, err = collector.Collect(pkg) assert.NotNil(t, nodes) diff --git a/go.mod b/go.mod index 181d456..3ba2773 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/procyon-projects/marker +module github.com/procyon-projects/markers go 1.18 @@ -14,6 +14,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/procyon-projects/marker v0.2.8-dev // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.4.0 // indirect golang.org/x/mod v0.6.0 // indirect diff --git a/go.sum b/go.sum index 0965579..1677183 100644 --- a/go.sum +++ b/go.sum @@ -195,6 +195,8 @@ github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/procyon-projects/marker v0.2.8-dev h1:ofTGupFUY2wUrKpTtM9ANLI1ikwpAyMJ0Ep+dfBLGLg= +github.com/procyon-projects/marker v0.2.8-dev/go.mod h1:pIZpONoM8//4jQ2IFDn8qilsVrb7iVku8s+vV/aW2I0= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= diff --git a/packages/loader.go b/packages/loader.go index 5a77abf..57bf662 100644 --- a/packages/loader.go +++ b/packages/loader.go @@ -3,7 +3,7 @@ package packages import ( "errors" "fmt" - "github.com/procyon-projects/marker/internal/cmd" + "github.com/procyon-projects/markers/internal/cmd" "go/token" "golang.org/x/tools/go/packages" "os/exec" diff --git a/packages/loader_test.go b/packages/loader_test.go index 4b558c8..0ef0c69 100644 --- a/packages/loader_test.go +++ b/packages/loader_test.go @@ -6,7 +6,7 @@ import ( ) func TestLoadResult_Packages(t *testing.T) { - loadResult, err := LoadPackages("github.com/procyon-projects/marker/test/...") + loadResult, err := LoadPackages("github.com/procyon-projects/markers/test/...") assert.Nil(t, err) assert.NotNil(t, loadResult) @@ -30,25 +30,25 @@ func TestLoadResult_StandardPackage(t *testing.T) { } func TestLoadResult_Lookup(t *testing.T) { - loadResult, err := LoadPackages("github.com/procyon-projects/marker/test/...") + loadResult, err := LoadPackages("github.com/procyon-projects/markers/test/...") assert.Nil(t, err) assert.NotNil(t, loadResult) assert.Len(t, loadResult.Packages(), 2) - pkg, err := loadResult.Lookup("github.com/procyon-projects/marker/test/menu") + pkg, err := loadResult.Lookup("github.com/procyon-projects/markers/test/menu") assert.Nil(t, err) assert.NotNil(t, pkg) assert.Equal(t, "menu", pkg.Name) assert.False(t, pkg.IsStandardPackage()) - assert.Equal(t, "github.com/procyon-projects/marker/test/menu", pkg.ID) - assert.Equal(t, "github.com/procyon-projects/marker/test/menu", pkg.PkgPath) + assert.Equal(t, "github.com/procyon-projects/markers/test/menu", pkg.ID) + assert.Equal(t, "github.com/procyon-projects/markers/test/menu", pkg.PkgPath) - pkg, err = loadResult.Lookup("github.com/procyon-projects/marker/test/any") + pkg, err = loadResult.Lookup("github.com/procyon-projects/markers/test/any") assert.Nil(t, err) assert.NotNil(t, pkg) assert.Equal(t, "any", pkg.Name) assert.False(t, pkg.IsStandardPackage()) - assert.Equal(t, "github.com/procyon-projects/marker/test/any", pkg.ID) - assert.Equal(t, "github.com/procyon-projects/marker/test/any", pkg.PkgPath) + assert.Equal(t, "github.com/procyon-projects/markers/test/any", pkg.ID) + assert.Equal(t, "github.com/procyon-projects/markers/test/any", pkg.PkgPath) } diff --git a/packages/package.go b/packages/package.go index f0f1a08..eb5a36c 100644 --- a/packages/package.go +++ b/packages/package.go @@ -4,7 +4,7 @@ import ( "encoding/json" "errors" "fmt" - "github.com/procyon-projects/marker/internal/cmd" + "github.com/procyon-projects/markers/internal/cmd" "golang.org/x/tools/go/packages" "os" "os/exec" diff --git a/packages/package_test.go b/packages/package_test.go index e48581e..333d1eb 100644 --- a/packages/package_test.go +++ b/packages/package_test.go @@ -4,7 +4,7 @@ import ( "encoding/json" "errors" "fmt" - "github.com/procyon-projects/marker/internal/cmd" + "github.com/procyon-projects/markers/internal/cmd" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "os" diff --git a/processor/common.go b/processor/common.go index 0984ef9..f55236f 100644 --- a/processor/common.go +++ b/processor/common.go @@ -2,7 +2,7 @@ package processor import ( "encoding/json" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers/packages" "os" "path" "path/filepath" diff --git a/processor/context.go b/processor/context.go index aa2f626..adcc7df 100644 --- a/processor/context.go +++ b/processor/context.go @@ -2,8 +2,8 @@ package processor import ( "fmt" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "log" "path/filepath" "regexp" diff --git a/processor/generate.go b/processor/generate.go index f2c183a..3395fe7 100644 --- a/processor/generate.go +++ b/processor/generate.go @@ -3,8 +3,8 @@ package processor import ( "errors" "fmt" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "github.com/spf13/cobra" "path" ) diff --git a/processor/validate.go b/processor/validate.go index 2788660..a180662 100644 --- a/processor/validate.go +++ b/processor/validate.go @@ -3,8 +3,8 @@ package processor import ( "errors" "fmt" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "github.com/spf13/cobra" "path" ) diff --git a/visitor/collector.go b/visitor/collector.go index 409bb15..5f8f194 100644 --- a/visitor/collector.go +++ b/visitor/collector.go @@ -1,6 +1,6 @@ package visitor -import "github.com/procyon-projects/marker/packages" +import "github.com/procyon-projects/markers/packages" type packageCollector struct { hasSeen map[string]bool diff --git a/visitor/constant.go b/visitor/constant.go index 56df178..729a5c3 100644 --- a/visitor/constant.go +++ b/visitor/constant.go @@ -1,7 +1,7 @@ package visitor import ( - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers/packages" "go/ast" "go/token" "reflect" diff --git a/visitor/custom_type.go b/visitor/custom_type.go index 23658de..5e90ec1 100644 --- a/visitor/custom_type.go +++ b/visitor/custom_type.go @@ -2,8 +2,8 @@ package visitor import ( "fmt" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "go/ast" "strings" "sync" diff --git a/visitor/file.go b/visitor/file.go index 8df8fb2..4fa2104 100644 --- a/visitor/file.go +++ b/visitor/file.go @@ -1,8 +1,8 @@ package visitor import ( - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "go/ast" "path/filepath" ) diff --git a/visitor/function.go b/visitor/function.go index 82942f1..b2b6e81 100644 --- a/visitor/function.go +++ b/visitor/function.go @@ -2,8 +2,8 @@ package visitor import ( "fmt" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "go/ast" "strings" "sync" diff --git a/visitor/function_test.go b/visitor/function_test.go index bb802e3..b43fae2 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -2,7 +2,7 @@ package visitor import ( "fmt" - "github.com/procyon-projects/marker" + "github.com/procyon-projects/markers" "github.com/stretchr/testify/assert" "strings" "testing" diff --git a/visitor/interface.go b/visitor/interface.go index 4ebffd2..a0e6867 100644 --- a/visitor/interface.go +++ b/visitor/interface.go @@ -2,8 +2,8 @@ package visitor import ( "fmt" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "go/ast" "go/token" "go/types" diff --git a/visitor/interface_test.go b/visitor/interface_test.go index b2a4498..745026a 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -2,7 +2,7 @@ package visitor import ( "fmt" - "github.com/procyon-projects/marker" + "github.com/procyon-projects/markers" "github.com/stretchr/testify/assert" "testing" ) diff --git a/visitor/position.go b/visitor/position.go index ddde43c..ce731a8 100644 --- a/visitor/position.go +++ b/visitor/position.go @@ -1,7 +1,7 @@ package visitor import ( - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers/packages" "go/token" ) diff --git a/visitor/struct.go b/visitor/struct.go index bb7f299..c8958ee 100644 --- a/visitor/struct.go +++ b/visitor/struct.go @@ -2,8 +2,8 @@ package visitor import ( "fmt" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "go/ast" "go/token" "go/types" diff --git a/visitor/struct_test.go b/visitor/struct_test.go index b531bc3..2358a35 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -2,7 +2,7 @@ package visitor import ( "fmt" - "github.com/procyon-projects/marker" + "github.com/procyon-projects/markers" "github.com/stretchr/testify/assert" "testing" ) diff --git a/visitor/visitor.go b/visitor/visitor.go index 5883a1c..c1f8f49 100644 --- a/visitor/visitor.go +++ b/visitor/visitor.go @@ -2,8 +2,8 @@ package visitor import ( "errors" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "go/ast" "go/token" ) diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index 6954b2f..0925953 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -2,8 +2,8 @@ package visitor import ( "fmt" - "github.com/procyon-projects/marker" - "github.com/procyon-projects/marker/packages" + "github.com/procyon-projects/markers" + "github.com/procyon-projects/markers/packages" "github.com/stretchr/testify/assert" "testing" ) From e7e0a1e2198328af4cdc2e0724c62000cfb08766 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Fri, 7 Apr 2023 20:34:57 +0300 Subject: [PATCH 06/16] Rename package imports --- test/any/permission.go | 2 +- test/menu/coffee.go | 2 +- test/menu/dessert.go | 2 +- test/menu/fresh.go | 2 +- visitor/visitor_test.go | 6 +++--- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/any/permission.go b/test/any/permission.go index f3ce469..851fac9 100644 --- a/test/any/permission.go +++ b/test/any/permission.go @@ -1,4 +1,4 @@ -// +import=marker, Pkg=github.com/procyon-projects/marker +// +import=marker, Pkg=github.com/procyon-projects/markers // +marker:package-level:Name=permission.go package any diff --git a/test/menu/coffee.go b/test/menu/coffee.go index 8a3c1dc..25b394e 100644 --- a/test/menu/coffee.go +++ b/test/menu/coffee.go @@ -1,4 +1,4 @@ -// +import=marker, Pkg=github.com/procyon-projects/marker +// +import=marker, Pkg=github.com/procyon-projects/markers // +marker:package-level:Name=coffee.go package menu diff --git a/test/menu/dessert.go b/test/menu/dessert.go index 134f0bb..c5d0ada 100644 --- a/test/menu/dessert.go +++ b/test/menu/dessert.go @@ -1,4 +1,4 @@ -// +import=marker, Pkg=github.com/procyon-projects/marker +// +import=marker, Pkg=github.com/procyon-projects/markers // +marker:package-level:Name=dessert.go package menu diff --git a/test/menu/fresh.go b/test/menu/fresh.go index 04daff2..052bdc9 100644 --- a/test/menu/fresh.go +++ b/test/menu/fresh.go @@ -1,4 +1,4 @@ -// +import=marker, Pkg=github.com/procyon-projects/marker +// +import=marker, Pkg=github.com/procyon-projects/markers // +marker:package-level:Name=fresh.go package menu diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index 0925953..bbe1a10 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -59,7 +59,7 @@ func TestVisitor_VisitPackage(t *testing.T) { } testCasePkgs := map[string]map[string]testFile{ - "github.com/procyon-projects/marker/test/menu": { + "github.com/procyon-projects/markers/test/menu": { "coffee.go": { constants: coffeeConstants, customTypes: coffeeCustomTypes, @@ -108,7 +108,7 @@ func TestVisitor_VisitPackage(t *testing.T) { }, }, }, - "github.com/procyon-projects/marker/test/any": { + "github.com/procyon-projects/markers/test/any": { "error.go": { constants: []constantInfo{}, customTypes: errorCustomTypes, @@ -179,7 +179,7 @@ func TestVisitor_VisitPackage(t *testing.T) { registry := markers.NewRegistry() for _, m := range markerList { - err := registry.Register(m.Name, "github.com/procyon-projects/marker", m.Level, m.Output) + err := registry.Register(m.Name, "github.com/procyon-projects/markers", m.Level, m.Output) if err != nil { t.Errorf("marker %s could not be registered", m.Name) return From 5072d13aa3ab02beccc45a3e6a508d219a89716b Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Fri, 7 Apr 2023 21:23:38 +0300 Subject: [PATCH 07/16] Add more unit tests --- visitor/file_test.go | 4 ++++ visitor/function_test.go | 6 +++++- visitor/visitor_test.go | 10 ++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/visitor/file_test.go b/visitor/file_test.go index 28a1651..707777d 100644 --- a/visitor/file_test.go +++ b/visitor/file_test.go @@ -69,6 +69,10 @@ func assertImports(t *testing.T, file *File, expectedImports []importInfo) bool t.Errorf("import path in file %s shoud be %s, but got %s", file.name, expectedImport.path, actualImport.Path()) } + if expectedImport.file != actualImport.File().Name() { + t.Errorf("the file name for import '%s' should be %s, but got %s", expectedImport.path, expectedImport.file, actualImport.File().Name()) + } + if actualImport.SideEffect() && !expectedImport.sideEffect { t.Errorf("import with path %s in file %s is not an import side effect, but should be an import side effect", expectedImport.path, file.name) } else if !actualImport.SideEffect() && expectedImport.sideEffect { diff --git a/visitor/function_test.go b/visitor/function_test.go index b43fae2..658e113 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -925,7 +925,7 @@ func assertFunctionParameters(t *testing.T, expectedParams []variableInfo, actua t.Errorf("at index %d, the parameter name of the %s should be %s, but got %s", index, msg, expectedFunctionParam.name, actualFunctionParam.name) } - if expectedFunctionParam.typeName != actualFunctionParam.Type().Name() { + if expectedFunctionParam.String() != actualFunctionParam.Type().Name() { t.Errorf("at index %d, the parameter type name of the %s should be %s, but got %s", index, msg, expectedFunctionParam.typeName, actualFunctionParam.Type().Name()) } } @@ -944,6 +944,10 @@ func assertFunctionResult(t *testing.T, expectedResults []variableInfo, actualRe if expectedFunctionParam.name != actualFunctionParam.Name() { t.Errorf("at index %d, the parameter result of the %s should be %s, but got %s", index, msg, expectedFunctionParam.name, actualFunctionParam.name) } + + if expectedFunctionParam.String() != actualFunctionParam.Type().Name() { + t.Errorf("at index %d, the parameter result type of the %s should be %s, but got %s", index, msg, expectedFunctionParam.typeName, actualFunctionParam.Type().Name()) + } } } diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index bbe1a10..2a71d9c 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -43,6 +43,14 @@ type variableInfo struct { isPointer bool } +func (v variableInfo) String() string { + if v.isPointer { + return fmt.Sprintf("*%s", v.typeName) + } + + return v.typeName +} + func TestVisitor_VisitPackage(t *testing.T) { markerList := []struct { Name string @@ -143,12 +151,14 @@ func TestVisitor_VisitPackage(t *testing.T) { { name: "", path: "context", + file: "generics.go", sideEffect: false, position: Position{Line: 4, Column: 2}, }, { name: "", path: "golang.org/x/exp/constraints", + file: "generics.go", sideEffect: false, position: Position{Line: 5, Column: 2}, }, From b4071327cac6e998a15e2e1385201a7a4f17d55a Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Fri, 7 Apr 2023 22:30:49 +0300 Subject: [PATCH 08/16] Add more unit tests --- argument.go | 2 +- argument_test.go | 55 +++++++++++++++++++++++++++++++++++++ definition.go | 6 +++- marker_test.go | 6 +++- markers_test.go | 8 ++++++ registry_test.go | 38 +++++++++++++++++++++++++ visitor/custom_type_test.go | 4 +++ 7 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 argument_test.go diff --git a/argument.go b/argument.go index 906fe9d..a5f65ad 100644 --- a/argument.go +++ b/argument.go @@ -14,7 +14,7 @@ type Argument struct { Default any } -func ExtractArgument(structField reflect.StructField) (Argument, error) { +func extractArgument(structField reflect.StructField) (Argument, error) { parameterName := upperCamelCase(structField.Name) parameterTag, parameterTagExists := structField.Tag.Lookup("parameter") diff --git a/argument_test.go b/argument_test.go new file mode 100644 index 0000000..e697cef --- /dev/null +++ b/argument_test.go @@ -0,0 +1,55 @@ +package markers + +import ( + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +func TestExtractArgument(t *testing.T) { + reflMarker := reflect.TypeOf(Marker{}) + + arg, err := extractArgument(reflMarker.Field(0)) + assert.Nil(t, err) + assert.Equal(t, "Value", arg.Name) + assert.Empty(t, arg.Default) + assert.True(t, arg.Required) + assert.False(t, arg.Deprecated) + + assert.Equal(t, StringType, arg.TypeInfo.ActualType) + assert.False(t, arg.TypeInfo.IsPointer) + assert.Nil(t, arg.TypeInfo.ItemType) + assert.Empty(t, arg.TypeInfo.Enum) + + arg, err = extractArgument(reflMarker.Field(2)) + assert.Nil(t, err) + assert.Equal(t, "Repeatable", arg.Name) + assert.Empty(t, arg.Default) + assert.False(t, arg.Required) + assert.False(t, arg.Deprecated) + + assert.Equal(t, BoolType, arg.TypeInfo.ActualType) + assert.False(t, arg.TypeInfo.IsPointer) + assert.Nil(t, arg.TypeInfo.ItemType) + assert.Empty(t, arg.TypeInfo.Enum) + + arg, err = extractArgument(reflMarker.Field(4)) + assert.Nil(t, err) + assert.Equal(t, "Targets", arg.Name) + assert.Empty(t, arg.Default) + assert.True(t, arg.Required) + assert.False(t, arg.Deprecated) + + assert.Equal(t, SliceType, arg.TypeInfo.ActualType) + assert.Equal(t, StringType, arg.TypeInfo.ItemType.ActualType) + assert.False(t, arg.TypeInfo.IsPointer) + assert.Equal(t, map[string]interface{}{ + "FIELD_LEVEL": "FIELD_LEVEL", + "FUNCTION_LEVEL": "FUNCTION_LEVEL", + "INTERFACE_METHOD_LEVEL": "INTERFACE_METHOD_LEVEL", + "INTERFACE_TYPE_LEVEL": "INTERFACE_TYPE_LEVEL", + "PACKAGE_LEVEL": "PACKAGE_LEVEL", + "STRUCT_METHOD_LEVEL": "STRUCT_METHOD_LEVEL", + "STRUCT_TYPE_LEVEL": "STRUCT_TYPE_LEVEL", + }, arg.TypeInfo.Enum) +} diff --git a/definition.go b/definition.go index 7b05f89..4c3f7ed 100644 --- a/definition.go +++ b/definition.go @@ -69,6 +69,10 @@ func MakeDefinition(name, pkg string, level TargetLevel, output any) (*Definitio } func (definition *Definition) validate() error { + if definition.Name == "" { + return fmt.Errorf("marker name cannot be empty") + } + if definition.TargetLevel == 0 { return fmt.Errorf("specify target levels for the definition: %v", definition.Name) } @@ -105,7 +109,7 @@ func (definition *Definition) extract() error { continue } - argumentInfo, err := ExtractArgument(field) + argumentInfo, err := extractArgument(field) if err != nil { return err diff --git a/marker_test.go b/marker_test.go index bc75c45..201ef7f 100644 --- a/marker_test.go +++ b/marker_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestMarkerValues_AllMarkers(t *testing.T) { +func TestMarkerValues_FindByName(t *testing.T) { markerValues := make(Values) markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest1") markerValues["anyMarker1"] = append(markerValues["anyMarker1"], "anyTest2") @@ -18,6 +18,10 @@ func TestMarkerValues_AllMarkers(t *testing.T) { markers, exists = markerValues.FindByName("anyMarker2") assert.True(t, exists) assert.Equal(t, []interface{}{"anyTest3"}, markers) + + markers, exists = markerValues.FindByName("anyMarker3") + assert.False(t, exists) + assert.Nil(t, markers) } func TestMarkerValues_Count(t *testing.T) { diff --git a/markers_test.go b/markers_test.go index 02cd5ab..6132884 100644 --- a/markers_test.go +++ b/markers_test.go @@ -48,3 +48,11 @@ func TestImportMarker_Validate_IfPkgIsMissing(t *testing.T) { } assert.Error(t, importMarker.Validate()) } + +func TestImportMarker_ValidateShouldReturnNilIfValidationIsOkay(t *testing.T) { + importMarker := &Import{ + Value: "anyValue", + Pkg: "anyPkg", + } + assert.Nil(t, importMarker.Validate()) +} diff --git a/registry_test.go b/registry_test.go index fdef16d..0b8d0cf 100644 --- a/registry_test.go +++ b/registry_test.go @@ -91,6 +91,44 @@ func TestRegistry_RegisterWithDefinition(t *testing.T) { assert.Len(t, registry.packageMap["anyPkg"], len(testCases)) } +func TestRegistry_RegisterReturnsErrorIfDefinitionIsNil(t *testing.T) { + registry := NewRegistry() + err := registry.RegisterWithDefinition(nil) + assert.Equal(t, "definition cannot be nil", err.Error()) +} + +func TestRegistry_RegisterReturnsErrorIfDefinitionNameIsEmpty(t *testing.T) { + registry := NewRegistry() + err := registry.RegisterWithDefinition(&Definition{}) + assert.Equal(t, "marker name cannot be empty", err.Error()) +} + +func TestRegistry_RegisterReturnsErrorIfTargetLevelIsNotSpecified(t *testing.T) { + registry := NewRegistry() + err := registry.RegisterWithDefinition(&Definition{ + Name: "any:marker", + }) + assert.Equal(t, "specify target levels for the definition: any:marker", err.Error()) +} + +func TestRegistry_RegisterReturnsErrorIfMarkerNameContainsLowerCaseCharacters(t *testing.T) { + registry := NewRegistry() + err := registry.RegisterWithDefinition(&Definition{ + Name: "anyMarker", + TargetLevel: AllLevels, + }) + assert.Equal(t, "marker 'anyMarker' should only contain lower case characters", err.Error()) +} + +func TestRegistry_RegisterReturnsErrorIfMarkerNameContainsWhitespace(t *testing.T) { + registry := NewRegistry() + err := registry.RegisterWithDefinition(&Definition{ + Name: "any:\tmarker", + TargetLevel: AllLevels, + }) + assert.Equal(t, "marker 'any:\tmarker' cannot contain any whitespace", err.Error()) +} + func TestRegistry_RegisterMarkerAlreadyRegistered(t *testing.T) { registry := NewRegistry() registry.Register("marker:test", "anyPkg", TypeLevel, &testTypeLevelMarker{}) diff --git a/visitor/custom_type_test.go b/visitor/custom_type_test.go index 532dc8a..8dd6f2b 100644 --- a/visitor/custom_type_test.go +++ b/visitor/custom_type_test.go @@ -122,6 +122,10 @@ func assertCustomTypes(t *testing.T, file *File, customTypes map[string]customTy t.Errorf("Output returning from String() method for custom type with name %s does not equal to %s, but got %s", expectedCustomTypeName, expectedCustomType.stringValue, actualCustomType.String()) } + if actualCustomType.NumMethods() != len(expectedCustomType.methods) { + t.Errorf("the number of the methods of the custom type %s should be %d, but got %d", expectedCustomTypeName, len(expectedCustomType.methods), actualCustomType.NumMethods()) + } + assertFunctions(t, fmt.Sprintf("custom type %s", actualCustomType.Name()), actualCustomType.Methods(), expectedCustomType.methods) index++ } From f72e30a565653ff2749e954c5218e337d5186292 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Fri, 7 Apr 2023 23:19:22 +0300 Subject: [PATCH 09/16] Add more unit tests --- test/any/generics.go | 4 ++-- visitor/constant_test.go | 5 +++++ visitor/custom_type_test.go | 5 +++++ visitor/function_test.go | 5 +++++ visitor/import_test.go | 19 +++++++++++++++++++ visitor/interface_test.go | 7 ++++++- visitor/struct_test.go | 2 +- visitor/visitor.go | 7 ++++++- visitor/visitor_test.go | 16 ++++++++++++++++ 9 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 visitor/import_test.go diff --git a/test/any/generics.go b/test/any/generics.go index b0c08f4..da4e388 100644 --- a/test/any/generics.go +++ b/test/any/generics.go @@ -10,11 +10,11 @@ func GenericFunction[K []map[T]X, T int | bool, X ~string](x []K) T { return value } -type Repository[T, ID any] interface { +type Repository[T any, ID any | string] interface { Save(entity T) T } -type Controller[C context.Context, T any] struct { +type Controller[C context.Context, T any | int] struct { AnyField1 string AnyField2 int } diff --git a/visitor/constant_test.go b/visitor/constant_test.go index 1ccc404..9ec11f1 100644 --- a/visitor/constant_test.go +++ b/visitor/constant_test.go @@ -408,3 +408,8 @@ func assertConstants(t *testing.T, file *File, constants []constantInfo) bool { return true } + +func TestConstants_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + constants := &Constants{} + assert.Nil(t, constants.At(0)) +} diff --git a/visitor/custom_type_test.go b/visitor/custom_type_test.go index 8dd6f2b..24965e8 100644 --- a/visitor/custom_type_test.go +++ b/visitor/custom_type_test.go @@ -132,3 +132,8 @@ func assertCustomTypes(t *testing.T, file *File, customTypes map[string]customTy return true } + +func TestCustomTypes_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + customTypes := &CustomTypes{} + assert.Nil(t, customTypes.At(0)) +} diff --git a/visitor/function_test.go b/visitor/function_test.go index 658e113..fa4fa00 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -951,3 +951,8 @@ func assertFunctionResult(t *testing.T, expectedResults []variableInfo, actualRe } } + +func TestFunctions_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + functions := &Functions{} + assert.Nil(t, functions.At(0)) +} diff --git a/visitor/import_test.go b/visitor/import_test.go new file mode 100644 index 0000000..fa623d9 --- /dev/null +++ b/visitor/import_test.go @@ -0,0 +1,19 @@ +package visitor + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestImports_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + imports := &Imports{} + assert.Nil(t, imports.At(0)) +} + +func TestImports_AtShouldReturnNilIfPathIsNotFound(t *testing.T) { + imports := &Imports{} + imp, ok := imports.FindByPath("anyPath") + + assert.Nil(t, imp) + assert.False(t, ok) +} diff --git a/visitor/interface_test.go b/visitor/interface_test.go index 745026a..7d06d47 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -36,7 +36,7 @@ var ( methods: map[string]functionInfo{ "Save": saveFunction, }, - stringValue: "any.Repository[T any,ID any]", + stringValue: "any.Repository[T any,ID any|string]", } numberInterface = interfaceInfo{ name: "Number", @@ -309,3 +309,8 @@ func assertInterfaceEmbeddedTypes(t *testing.T, interfaceName string, actualEmbe return true } + +func TestInterfaces_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + interfaces := &Interfaces{} + assert.Nil(t, interfaces.At(0)) +} diff --git a/visitor/struct_test.go b/visitor/struct_test.go index 2358a35..068a51a 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -34,7 +34,7 @@ type structInfo struct { var ( controllerStruct = structInfo{ markers: markers.Values{}, - stringValue: "any.Controller[C context.Context,T any]", + stringValue: "any.Controller[C context.Context,T any|int]", fileName: "generics.go", isExported: true, position: Position{ diff --git a/visitor/visitor.go b/visitor/visitor.go index c1f8f49..b32f264 100644 --- a/visitor/visitor.go +++ b/visitor/visitor.go @@ -94,7 +94,12 @@ func EachFile(collector *markers.Collector, pkgs []*packages.Package, callback F markerValues, err := collector.Collect(pkg) if err != nil { - errs = append(errs, err.(markers.ErrorList)...) + switch typedErr := err.(type) { + case markers.ErrorList: + errs = append(errs, err.(markers.ErrorList)...) + default: + errs = append(errs, typedErr) + } continue } diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index 2a71d9c..be93a70 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -1,6 +1,7 @@ package visitor import ( + "errors" "fmt" "github.com/procyon-projects/markers" "github.com/procyon-projects/markers/packages" @@ -51,6 +52,21 @@ func (v variableInfo) String() string { return v.typeName } +func TestEachFile_ShouldReturnErrorIfCollectorIsNil(t *testing.T) { + err := EachFile(nil, nil, nil) + assert.Equal(t, "collector cannot be nil", err.Error()) +} + +func TestEachFile_ShouldReturnErrorIfPkgsIsNil(t *testing.T) { + err := EachFile(&markers.Collector{}, nil, nil) + assert.Equal(t, "packages cannot be nil", err.Error()) +} + +func TestEachFile_ShouldReturnErrorIfTraversedPkgIsNil(t *testing.T) { + err := EachFile(&markers.Collector{}, []*packages.Package{nil}, nil) + assert.Equal(t, markers.ErrorList{errors.New("pkg(package) cannot be nil")}, err.(markers.ErrorList)) +} + func TestVisitor_VisitPackage(t *testing.T) { markerList := []struct { Name string From 4e6b16dbb0be9494fbc5106f75f22dad3c2659c6 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Sat, 8 Apr 2023 00:06:09 +0300 Subject: [PATCH 10/16] Add more unit tests --- argument_type.go | 2 +- argument_type_test.go | 74 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/argument_type.go b/argument_type.go index bd0acd5..1f1ac0b 100644 --- a/argument_type.go +++ b/argument_type.go @@ -223,7 +223,7 @@ func (typeInfo ArgumentTypeInfo) parseInteger(scanner *Scanner, out reflect.Valu } if !scanner.Expect(IntegerValue, "Integer") { - return nil + return fmt.Errorf("expected integer, got %q", scanner.Token()) } text := scanner.Token() diff --git a/argument_type_test.go b/argument_type_test.go index 15a2361..b86c81c 100644 --- a/argument_type_test.go +++ b/argument_type_test.go @@ -226,6 +226,33 @@ func TestArgumentTypeInfo_ParseBoolean(t *testing.T) { assert.False(t, boolValue) } +func TestArgumentTypeInfo_ParseBooleanShouldReturnErrorIfScannedTokenIsNotBooleanValue(t *testing.T) { + typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(true)) + assert.Nil(t, err) + assert.Equal(t, BoolType, typeInfo.ActualType) + + boolValue := false + + scanner := NewScanner("test") + scanner.Peek() + + err = typeInfo.parseBoolean(scanner, reflect.ValueOf(&boolValue)) + assert.NotNil(t, err) + assert.Equal(t, "expected true or false, got \"test\"", err.Error()) +} + +func TestArgumentTypeInfo_ParseBooleanShouldReturnErrorIfScannerIsNil(t *testing.T) { + typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(true)) + assert.Nil(t, err) + assert.Equal(t, BoolType, typeInfo.ActualType) + + boolValue := false + + err = typeInfo.parseBoolean(nil, reflect.ValueOf(&boolValue)) + assert.NotNil(t, err) + assert.Equal(t, "scanner cannot be nil", err.Error()) +} + func TestArgumentTypeInfo_ParseInteger(t *testing.T) { typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(0)) assert.Nil(t, err) @@ -300,6 +327,33 @@ func TestArgumentTypeInfo_ParseInteger(t *testing.T) { assert.Equal(t, uint(70519), unsignedIntegerValue) } +func TestArgumentTypeInfo_parseIntegerShouldReturnErrorIfScannedTokenIsNotIntegerValue(t *testing.T) { + typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(2)) + assert.Nil(t, err) + assert.Equal(t, SignedIntegerType, typeInfo.ActualType) + + signedIntegerValue := 0 + + scanner := NewScanner("test") + scanner.Peek() + + err = typeInfo.parseInteger(scanner, reflect.ValueOf(&signedIntegerValue)) + assert.NotNil(t, err) + assert.Equal(t, "expected integer, got \"test\"", err.Error()) +} + +func TestArgumentTypeInfo_parseIntegerShouldReturnErrorIfScannerIsNil(t *testing.T) { + typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(5)) + assert.Nil(t, err) + assert.Equal(t, SignedIntegerType, typeInfo.ActualType) + + signedIntegerValue := 0 + + err = typeInfo.parseInteger(nil, reflect.ValueOf(&signedIntegerValue)) + assert.NotNil(t, err) + assert.Equal(t, "scanner cannot be nil", err.Error()) +} + func TestArgumentTypeInfo_ParseMap(t *testing.T) { m := make(map[string]any) typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(&m)) @@ -336,6 +390,16 @@ func TestArgumentTypeInfo_ParseMap(t *testing.T) { assert.Equal(t, "anyValue2", m["anyKey4"]) } +func TestArgumentTypeInfo_parseMapShouldReturnErrorIfScannerIsNil(t *testing.T) { + typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(map[string]any{})) + assert.Nil(t, err) + assert.Equal(t, MapType, typeInfo.ActualType) + + err = typeInfo.parseMap(nil, reflect.ValueOf("")) + assert.NotNil(t, err) + assert.Equal(t, "scanner cannot be nil", err.Error()) +} + func TestArgumentTypeInfo_ParseSlice(t *testing.T) { s := make([]int, 0) typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(&s)) @@ -372,6 +436,16 @@ func TestArgumentTypeInfo_ParseSlice(t *testing.T) { assert.Equal(t, []int{1, 2, 3, 4, 5}, s) } +func TestArgumentTypeInfo_parseSliceShouldReturnErrorIfScannerIsNil(t *testing.T) { + typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf([]string{})) + assert.Nil(t, err) + assert.Equal(t, SliceType, typeInfo.ActualType) + + err = typeInfo.parseSlice(nil, reflect.ValueOf("")) + assert.NotNil(t, err) + assert.Equal(t, "scanner cannot be nil", err.Error()) +} + func TestArgumentTypeInfo_TypeInference(t *testing.T) { var value any typeInfo, err := ArgumentTypeInfoFromType(reflect.TypeOf(&value)) From f26f1e71823b59f99f3b01f2686f66e34a03a9fb Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Sun, 9 Apr 2023 00:15:32 +0300 Subject: [PATCH 11/16] Add more unit tests --- test/any/error.go | 2 -- test/any/generics.go | 4 +-- test/any/method.go | 2 +- test/any/other.go | 3 ++ test/menu/coffee.go | 2 ++ test/menu/dessert.go | 2 +- visitor/custom_type_test.go | 14 ++++++--- visitor/function.go | 8 +++-- visitor/function_test.go | 58 ++++++++++++++++++++++++++++++++++--- visitor/interface_test.go | 2 +- visitor/struct_test.go | 5 ++++ visitor/type.go | 8 +++-- visitor/visitor_test.go | 7 +++-- 13 files changed, 96 insertions(+), 21 deletions(-) create mode 100644 test/any/other.go diff --git a/test/any/error.go b/test/any/error.go index 7eae971..40cadf4 100644 --- a/test/any/error.go +++ b/test/any/error.go @@ -1,7 +1,5 @@ package any -type errorList []error - func (e errorList) Print() { } diff --git a/test/any/generics.go b/test/any/generics.go index da4e388..2955e21 100644 --- a/test/any/generics.go +++ b/test/any/generics.go @@ -10,7 +10,7 @@ func GenericFunction[K []map[T]X, T int | bool, X ~string](x []K) T { return value } -type Repository[T any, ID any | string] interface { +type Repository[T any, ID any | string | constraints.Ordered] interface { Save(entity T) T } @@ -32,7 +32,7 @@ type Number interface { ToString() } -type HttpHandler[C context.Context, K string | int] func(ctx C) K +type HttpHandler[C context.Context, K string | int, V constraints.Ordered | constraints.Complex] func(ctx C, value V) K type EventPublisher[E any] interface { Publish(e E) diff --git a/test/any/method.go b/test/any/method.go index 4037682..72ad6b4 100644 --- a/test/any/method.go +++ b/test/any/method.go @@ -1,5 +1,5 @@ package any -func (HttpHandler[C, K]) Print(ctx C) { +func (HttpHandler[C, K, V]) Print(ctx C, value V) { } diff --git a/test/any/other.go b/test/any/other.go new file mode 100644 index 0000000..2404fe4 --- /dev/null +++ b/test/any/other.go @@ -0,0 +1,3 @@ +package any + +type errorList []error diff --git a/test/menu/coffee.go b/test/menu/coffee.go index 25b394e..ace1cae 100644 --- a/test/menu/coffee.go +++ b/test/menu/coffee.go @@ -15,3 +15,5 @@ const ( func (c *cookie) PrintCookie(v interface{}) []string { return nil } + +type CustomBakeryShop BakeryShop diff --git a/test/menu/dessert.go b/test/menu/dessert.go index c5d0ada..abc4a76 100644 --- a/test/menu/dessert.go +++ b/test/menu/dessert.go @@ -85,7 +85,7 @@ type Dessert interface { // CupCake is a method // +marker:interface-method-level:Name=CupCake - CupCake(a []int, b bool) float32 + CupCake([]int, bool) float32 // Tart is a method // +marker:interface-method-level:Name=Tart diff --git a/visitor/custom_type_test.go b/visitor/custom_type_test.go index 24965e8..a37888f 100644 --- a/visitor/custom_type_test.go +++ b/visitor/custom_type_test.go @@ -54,6 +54,12 @@ var ( isExported: true, stringValue: "menu.Coffee", }, + "CustomBakeryShop": { + name: "CustomBakeryShop", + underlyingTypeName: "BakeryShop", + isExported: true, + stringValue: "menu.CustomBakeryShop", + }, } freshCustomTypes = map[string]customTypeInfo{ "Lemonade": { @@ -66,12 +72,12 @@ var ( genericsCustomTypes = map[string]customTypeInfo{ "HttpHandler": { name: "HttpHandler", - underlyingTypeName: "func (ctx C) K", + underlyingTypeName: "func (ctx C,value V) K", isExported: true, methods: map[string]functionInfo{ "Print": printHttpHandlerMethod, }, - stringValue: "any.HttpHandler[C context.Context,K string|int]", + stringValue: "any.HttpHandler[C context.Context,K string|int,V constraints.Ordered|constraints.Complex]", }, } ) @@ -108,8 +114,8 @@ func assertCustomTypes(t *testing.T, file *File, customTypes map[string]customTy t.Errorf("custom type name in file %s shoud be %s, but got %s", file.name, expectedCustomTypeName, actualCustomType.Name()) } - if expectedCustomType.underlyingTypeName != actualCustomType.Underlying().String() { - t.Errorf("underlying type of custom type %s in file %s shoud be %s, but got %s", file.name, expectedCustomType.name, expectedCustomType.underlyingTypeName, actualCustomType.Underlying().String()) + if expectedCustomType.underlyingTypeName != actualCustomType.Underlying().Name() { + t.Errorf("underlying type of custom type %s in file %s shoud be %s, but got %s", file.name, expectedCustomType.name, expectedCustomType.underlyingTypeName, actualCustomType.Underlying().Name()) } if actualCustomType.IsExported() && !expectedCustomType.isExported { diff --git a/visitor/function.go b/visitor/function.go index b2b6e81..2033a2f 100644 --- a/visitor/function.go +++ b/visitor/function.go @@ -223,8 +223,8 @@ func (f *Function) receiverType(receiverExpr ast.Expr) Type { case *ast.Ident: if typedReceiver.Obj == nil { receiverTypeName = typedReceiver.Name - unprocessedype := getTypeFromScope(receiverTypeName, f.visitor) - _, isStructMethod = unprocessedype.(*Struct) + unprocessedType := getTypeFromScope(receiverTypeName, f.visitor) + _, isStructMethod = unprocessedType.(*Struct) } else { receiverTypeSpec = typedReceiver.Obj.Decl.(*ast.TypeSpec) receiverTypeName = receiverTypeSpec.Name.Name @@ -270,6 +270,10 @@ func (f *Function) receiverType(receiverExpr ast.Expr) Type { } func (f *Function) Name() string { + if f.name == "" { + return f.String() + } + return f.name } diff --git a/visitor/function_test.go b/visitor/function_test.go index fa4fa00..54eab6e 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -274,6 +274,10 @@ var ( name: "ctx", typeName: "C", }, + { + name: "value", + typeName: "V", + }, }, results: []variableInfo{}, typeParams: []variableInfo{ @@ -285,6 +289,10 @@ var ( name: "K", typeName: "", }, + { + name: "V", + typeName: "", + }, }, } printErrorMethod = functionInfo{ @@ -292,7 +300,7 @@ var ( name: "Print", fileName: "error.go", position: Position{ - Line: 5, + Line: 3, Column: 1, }, isVariadic: false, @@ -316,7 +324,7 @@ var ( name: "ToErrors", fileName: "error.go", position: Position{ - Line: 12, + Line: 10, Column: 1, }, isVariadic: false, @@ -585,11 +593,11 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "a", + name: "", typeName: "[]int", }, { - name: "b", + name: "", typeName: "bool", }, }, @@ -928,6 +936,17 @@ func assertFunctionParameters(t *testing.T, expectedParams []variableInfo, actua if expectedFunctionParam.String() != actualFunctionParam.Type().Name() { t.Errorf("at index %d, the parameter type name of the %s should be %s, but got %s", index, msg, expectedFunctionParam.typeName, actualFunctionParam.Type().Name()) } + + var expectedFunctionParamString string + if expectedFunctionParam.name == "" { + expectedFunctionParamString = expectedFunctionParam.typeName + } else { + expectedFunctionParamString = fmt.Sprintf("%s %s", expectedFunctionParam.name, expectedFunctionParam.typeName) + } + + if expectedFunctionParamString != actualFunctionParam.String() { + t.Errorf("at index %d parameter, the String() method should return %s, but got %s", index, expectedFunctionParamString, actualFunctionParam.String()) + } } } @@ -956,3 +975,34 @@ func TestFunctions_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { functions := &Functions{} assert.Nil(t, functions.At(0)) } + +func TestParameters_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + parameters := &Parameters{} + assert.Nil(t, parameters.At(0)) +} + +func TestResults_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + results := &Results{} + assert.Nil(t, results.At(0)) +} + +func TestParameters_FindByNameShouldReturnFalseIfParameterNameDoesNotExist(t *testing.T) { + parameters := &Parameters{} + parameter, ok := parameters.FindByName("anyName") + assert.Nil(t, parameter) + assert.False(t, ok) +} + +func TestResults_FindByNameShouldReturnFalseIfResultNameDoesNotExist(t *testing.T) { + results := &Results{} + result, ok := results.FindByName("anyName") + assert.Nil(t, result) + assert.False(t, ok) +} + +func TestFunctions_FindByNameShouldReturnFalseIfFunctionNameDoesNotExist(t *testing.T) { + functions := &Functions{} + function, ok := functions.FindByName("anyName") + assert.Nil(t, function) + assert.False(t, ok) +} diff --git a/visitor/interface_test.go b/visitor/interface_test.go index 7d06d47..8642830 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -36,7 +36,7 @@ var ( methods: map[string]functionInfo{ "Save": saveFunction, }, - stringValue: "any.Repository[T any,ID any|string]", + stringValue: "any.Repository[T any,ID any|string|constraints.Ordered]", } numberInterface = interfaceInfo{ name: "Number", diff --git a/visitor/struct_test.go b/visitor/struct_test.go index 068a51a..0cbddeb 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -301,3 +301,8 @@ func assertStructFields(t *testing.T, structName string, actualFields *Fields, e return true } + +func TestStructs_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + structs := &Structs{} + assert.Nil(t, structs.At(0)) +} diff --git a/visitor/type.go b/visitor/type.go index 3c39cbb..806394b 100644 --- a/visitor/type.go +++ b/visitor/type.go @@ -69,13 +69,17 @@ func getTypeFromScope(name string, visitor *packageVisitor) Type { pkg := visitor.pkg typ := pkg.Types.Scope().Lookup(name) - typedName, ok := typ.Type().(*types.Named) + typedName, isNamedType := typ.Type().(*types.Named) if _, ok := visitor.collector.unprocessedTypes[pkg.ID]; !ok { visitor.collector.unprocessedTypes[pkg.ID] = make(map[string]Type) } - if ok { + if unprocessedType, ok := visitor.collector.unprocessedTypes[pkg.ID][name]; ok { + return unprocessedType + } + + if isNamedType { switch typedName.Underlying().(type) { case *types.Struct: structType := newStruct(nil, nil, nil, pkg, visitor, nil) diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index be93a70..5d835bf 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -134,13 +134,16 @@ func TestVisitor_VisitPackage(t *testing.T) { }, "github.com/procyon-projects/markers/test/any": { "error.go": { - constants: []constantInfo{}, - customTypes: errorCustomTypes, + constants: []constantInfo{}, functions: map[string]functionInfo{ "Print": printErrorMethod, "ToErrors": toErrorsMethod, }, }, + "other.go": { + constants: []constantInfo{}, + customTypes: errorCustomTypes, + }, "permission.go": { constants: permissionConstants, customTypes: permissionCustomTypes, From 2b9fe9419be621ad93fe81cee706a13f76d07bc8 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Thu, 20 Apr 2023 19:56:33 +0300 Subject: [PATCH 12/16] Add more unit tests --- test/any/permission.go | 2 +- visitor/file_test.go | 53 +++++++++++++++++++++++++++++++++++------ visitor/visitor_test.go | 34 +++++++++++++++++++++++++- 3 files changed, 80 insertions(+), 9 deletions(-) diff --git a/test/any/permission.go b/test/any/permission.go index 851fac9..d37fc77 100644 --- a/test/any/permission.go +++ b/test/any/permission.go @@ -1,4 +1,4 @@ -// +import=marker, Pkg=github.com/procyon-projects/markers +// +import=marker, Pkg=github.com/procyon-projects/markers, Alias=test // +marker:package-level:Name=permission.go package any diff --git a/visitor/file_test.go b/visitor/file_test.go index 707777d..341f754 100644 --- a/visitor/file_test.go +++ b/visitor/file_test.go @@ -6,12 +6,15 @@ import ( ) type testFile struct { - constants []constantInfo - interfaces map[string]interfaceInfo - structs map[string]structInfo - functions map[string]functionInfo - imports []importInfo - customTypes map[string]customTypeInfo + path string + constants []constantInfo + interfaces map[string]interfaceInfo + structs map[string]structInfo + functions map[string]functionInfo + imports []importInfo + customTypes map[string]customTypeInfo + importMarkers []importMarkerInfo + fileMarkers []fileMarkerInfo } type importInfo struct { @@ -22,6 +25,14 @@ type importInfo struct { position Position } +type importMarkerInfo struct { + value string + pkg string + alias string +} + +type fileMarkerInfo any + func sideEffects(imports []importInfo) []importInfo { result := make([]importInfo, 0) for _, importItem := range imports { @@ -33,7 +44,7 @@ func sideEffects(imports []importInfo) []importInfo { return result } -func assertImports(t *testing.T, file *File, expectedImports []importInfo) bool { +func assertImports(t *testing.T, file *File, expectedImports []importInfo, expectedImportMarkers []importMarkerInfo, fileMarkers []fileMarkerInfo) bool { if file.Imports().Len() != len(expectedImports) { t.Errorf("the number of the imports in file %s should be %d, but got %d", file.name, len(expectedImports), file.Imports().Len()) } @@ -82,5 +93,33 @@ func assertImports(t *testing.T, file *File, expectedImports []importInfo) bool assert.Equal(t, expectedImport.position, actualImport.Position(), "position for import with path %s in file %s should be %w, but got %w", expectedImport.name, "", expectedImport.position, fileImport.Position()) } + if file.Markers().Count() != len(fileMarkers) { + t.Errorf("the number of the file markers in file %s should be %d, but got %d", file.name, len(fileMarkers), file.Markers().Count()) + } + + assertImportMarkers(t, file, expectedImportMarkers) + return true } + +func assertImportMarkers(t *testing.T, file *File, expectedImportMarkers []importMarkerInfo) { + + if file.NumImportMarkers() != len(expectedImportMarkers) { + t.Errorf("the number of the import markers in file %s should be %d, but got %d", file.name, len(expectedImportMarkers), len(file.ImportMarkers())) + } + + for index, importMarker := range file.ImportMarkers() { + expectedImportMarker := expectedImportMarkers[index] + if importMarker.Pkg != expectedImportMarker.pkg { + t.Errorf("the Pkg attribute of the import marker in file %s shoud be %s, but got %s", file.name, expectedImportMarker.pkg, importMarker.Pkg) + } + + if importMarker.Value != expectedImportMarker.value { + t.Errorf("the Value attribute of the import marker in file %s shoud be %s, but got %s", file.name, expectedImportMarker.value, importMarker.Value) + } + + if importMarker.Alias != expectedImportMarker.alias { + t.Errorf("the Alias attribute of the import marker in file %s shoud be %s, but got %s", file.name, expectedImportMarker.alias, importMarker.Alias) + } + } +} diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index 5d835bf..cda81e1 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -85,17 +85,32 @@ func TestVisitor_VisitPackage(t *testing.T) { testCasePkgs := map[string]map[string]testFile{ "github.com/procyon-projects/markers/test/menu": { "coffee.go": { + path: "github.com/procyon-projects/markers/test/menu/coffee.go", constants: coffeeConstants, customTypes: coffeeCustomTypes, functions: map[string]functionInfo{ "PrintCookie": printCookieMethod, }, + importMarkers: []importMarkerInfo{ + { + pkg: "github.com/procyon-projects/markers", + value: "marker", + }, + }, }, "fresh.go": { + path: "github.com/procyon-projects/markers/test/menu/fresh.go", constants: freshConstants, customTypes: freshCustomTypes, + importMarkers: []importMarkerInfo{ + { + pkg: "github.com/procyon-projects/markers", + value: "marker", + }, + }, }, "dessert.go": { + path: "github.com/procyon-projects/markers/test/menu/dessert.go", imports: []importInfo{ { name: "", @@ -130,6 +145,12 @@ func TestVisitor_VisitPackage(t *testing.T) { "FriedCookie": friedCookieStruct, "cookie": cookieStruct, }, + importMarkers: []importMarkerInfo{ + { + pkg: "github.com/procyon-projects/markers", + value: "marker", + }, + }, }, }, "github.com/procyon-projects/markers/test/any": { @@ -147,6 +168,13 @@ func TestVisitor_VisitPackage(t *testing.T) { "permission.go": { constants: permissionConstants, customTypes: permissionCustomTypes, + importMarkers: []importMarkerInfo{ + { + pkg: "github.com/procyon-projects/markers", + value: "marker", + alias: "test", + }, + }, }, "math.go": { constants: mathConstants, @@ -229,7 +257,11 @@ func TestVisitor_VisitPackage(t *testing.T) { return nil } - if !assertImports(t, file, testCase.imports) { + if testCase.path != file.Path() { + t.Errorf("file path %s shoud be %s, but got %s", file.name, testCase.path, file.Path()) + } + + if !assertImports(t, file, testCase.imports, testCase.importMarkers, testCase.fileMarkers) { return nil } From 5070d33fec01b6602935c5c4b1f9c80444864532 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Thu, 20 Apr 2023 22:32:49 +0300 Subject: [PATCH 13/16] Add more unit tests --- target.go | 2 +- test/any/permission.go | 3 +- test/menu/coffee.go | 3 +- test/menu/dessert.go | 63 ++++++++++++++++---------------- test/menu/fresh.go | 3 +- visitor/constant_test.go | 38 +++++++++---------- visitor/file_test.go | 40 ++++++++++++++++++-- visitor/function_test.go | 68 +++++++++++++++++----------------- visitor/interface_test.go | 16 ++++---- visitor/struct_test.go | 8 ++-- visitor/visitor_test.go | 77 ++++++++++++++++++++++++++++++++------- 11 files changed, 204 insertions(+), 117 deletions(-) diff --git a/target.go b/target.go index 13061b3..3aa91a1 100644 --- a/target.go +++ b/target.go @@ -57,7 +57,7 @@ func FindTargetLevel(node ast.Node) TargetLevel { } else { return FunctionLevel } - case *ast.Package: + case *ast.Package, *ast.File: return PackageLevel } diff --git a/test/any/permission.go b/test/any/permission.go index d37fc77..2f873ad 100644 --- a/test/any/permission.go +++ b/test/any/permission.go @@ -1,5 +1,6 @@ // +import=marker, Pkg=github.com/procyon-projects/markers, Alias=test -// +marker:package-level:Name=permission.go +// +import=test-marker, Pkg=github.com/procyon-projects/test-markers +// +test-marker:package-level:Name=permission.go package any diff --git a/test/menu/coffee.go b/test/menu/coffee.go index ace1cae..702cc9b 100644 --- a/test/menu/coffee.go +++ b/test/menu/coffee.go @@ -1,5 +1,6 @@ // +import=marker, Pkg=github.com/procyon-projects/markers -// +marker:package-level:Name=coffee.go +// +import=test-marker, Pkg=github.com/procyon-projects/test-markers +// +test-marker:package-level:Name=coffee.go package menu diff --git a/test/menu/dessert.go b/test/menu/dessert.go index abc4a76..e9d3224 100644 --- a/test/menu/dessert.go +++ b/test/menu/dessert.go @@ -1,5 +1,6 @@ // +import=marker, Pkg=github.com/procyon-projects/markers -// +marker:package-level:Name=dessert.go +// +import=test-marker, Pkg=github.com/procyon-projects/test-markers +// +test-marker:package-level:Name=dessert.go package menu @@ -9,130 +10,130 @@ import ( ) // BakeryShop is an interface -// +marker:interface-type-level:Name=BakeryShop +// +test-marker:interface-type-level:Name=BakeryShop type BakeryShop interface { // Bread is a method - // +marker:interface-method-level:Name=Bread + // +test-marker:interface-method-level:Name=Bread Bread(i, k float64) struct{} // Dessert is an embedded interface - // +marker:interface-method-level:Name=Dessert + // +test-marker:interface-method-level:Name=Dessert Dessert } // Eat is a method -// +marker:struct-method-level:Name=Eat +// +test-marker:struct-method-level:Name=Eat func (c *FriedCookie) Eat() bool { return true } // FriedCookie is a struct -// +marker:struct-type-level:Name=FriedCookie +// +test-marker:struct-type-level:Name=FriedCookie type FriedCookie struct { // Cookie is an embedded struct - // +marker:interface-method-level:Name=Cookie + // +test-marker:interface-method-level:Name=Cookie cookie // ChocolateChip is a field - // +marker:struct-field-level:Name=CookieDough + // +test-marker:struct-field-level:Name=CookieDough cookieDough any } // Buy is a method -// +marker:struct-method-level:Name=Buy +// +test-marker:struct-method-level:Name=Buy func (c *FriedCookie) Buy(i int) { } // NewYearsEveCookie is an interface -// +marker:interface-type-level:Name=newYearsEveCookie +// +test-marker:interface-type-level:Name=newYearsEveCookie type newYearsEveCookie interface { // Funfetti is a method - // +marker:interface-method-level:Name=Funfetti + // +test-marker:interface-method-level:Name=Funfetti Funfetti(v rune) byte } // Cookie is a struct -// +marker:struct-type-level:Name=cookie, Any={key:"value"} +// +test-marker:struct-type-level:Name=cookie, Any={key:"value"} type cookie struct { // ChocolateChip is a field - // +marker:struct-field-level:Name=ChocolateChip + // +test-marker:struct-field-level:Name=ChocolateChip ChocolateChip string // tripleChocolateCookie is a field - // +marker:struct-field-level:Name=tripleChocolateCookie + // +test-marker:struct-field-level:Name=tripleChocolateCookie tripleChocolateCookie map[string]error } // FortuneCookie is a method -// +marker:struct-method-level:Name=FortuneCookie +// +test-marker:struct-method-level:Name=FortuneCookie func (c *cookie) FortuneCookie(v interface{}) []string { return nil } // Oreo is a method -// +marker:struct-method-level:Name=Oreo +// +test-marker:struct-method-level:Name=Oreo func (c *cookie) Oreo(a []interface{}, v ...string) error { return nil } // Dessert is an interface -// +marker:interface-type-level:Name=Dessert +// +test-marker:interface-type-level:Name=Dessert type Dessert interface { // IceCream is a method - // +marker:interface-method-level:Name=IceCream - // +marker:interface-type-level:Name=IceCream + // +test-marker:interface-method-level:Name=IceCream + // +test-marker:interface-type-level:Name=IceCream IceCream(s string, v ...bool) (r string) // CupCake is a method - // +marker:interface-method-level:Name=CupCake + // +test-marker:interface-method-level:Name=CupCake CupCake([]int, bool) float32 // Tart is a method - // +marker:interface-method-level:Name=Tart + // +test-marker:interface-method-level:Name=Tart Tart(s interface{}) // Donut is a method - // +marker:interface-method-level:Name=Donut + // +test-marker:interface-method-level:Name=Donut Donut() error // Pudding is a method - // +marker:interface-method-level:Name=Pudding + // +test-marker:interface-method-level:Name=Pudding Pudding() [5]string // Pie is a method - // +marker:interface-method-level:Name=Pie + // +test-marker:interface-method-level:Name=Pie Pie() interface{} // muffin is a method - // +marker:interface-method-level:Name=muffin + // +test-marker:interface-method-level:Name=muffin muffin() (*string, error) } // MakeACake is a function -// +marker:function-level:Name=MakeACake +// +test-marker:function-level:Name=MakeACake func MakeACake(s interface{}) error { return nil } // BiscuitCake is a function -// +marker:function-level:Name=BiscuitCake +// +test-marker:function-level:Name=BiscuitCake func BiscuitCake(s string, arr []int, v ...int16) (i int, b bool) { return } // SweetShop is an interface -// +marker:interface-type-level:Name=SweetShop +// +test-marker:interface-type-level:Name=SweetShop type SweetShop interface { // NewYearsEveCookie is an embedded interface - // +marker:interface-method-level:Name=NewYearsEveCookie + // +test-marker:interface-method-level:Name=NewYearsEveCookie newYearsEveCookie // Macaron is a method - // +marker:interface-method-level:Name=Macaron + // +test-marker:interface-method-level:Name=Macaron Macaron(c complex128) (chan string, fmt.Stringer) // Dessert is an embedded interface - // +marker:interface-method-level:Name=Dessert + // +test-marker:interface-method-level:Name=Dessert Dessert } diff --git a/test/menu/fresh.go b/test/menu/fresh.go index 052bdc9..e9f0092 100644 --- a/test/menu/fresh.go +++ b/test/menu/fresh.go @@ -1,5 +1,6 @@ // +import=marker, Pkg=github.com/procyon-projects/markers -// +marker:package-level:Name=fresh.go +// +import=test-marker, Pkg=github.com/procyon-projects/test-markers +// +test-marker:package-level:Name=fresh.go package menu diff --git a/visitor/constant_test.go b/visitor/constant_test.go index 9ec11f1..4cd6fd4 100644 --- a/visitor/constant_test.go +++ b/visitor/constant_test.go @@ -18,7 +18,7 @@ var ( { name: "Cappuccino", position: Position{ - Line: 9, + Line: 10, Column: 2, }, value: -1, @@ -28,7 +28,7 @@ var ( { name: "Americano", position: Position{ - Line: 10, + Line: 11, Column: 2, }, value: -2, @@ -38,7 +38,7 @@ var ( { name: "Latte", position: Position{ - Line: 11, + Line: 12, Column: 2, }, value: -3, @@ -48,7 +48,7 @@ var ( { name: "TurkishCoffee", position: Position{ - Line: 12, + Line: 13, Column: 2, }, value: -4, @@ -60,7 +60,7 @@ var ( { name: "ClassicLemonade", position: Position{ - Line: 9, + Line: 10, Column: 2, }, value: 0, @@ -70,7 +70,7 @@ var ( { name: "BlueberryLemonade", position: Position{ - Line: 10, + Line: 11, Column: 2, }, value: 1, @@ -80,7 +80,7 @@ var ( { name: "WatermelonLemonade", position: Position{ - Line: 11, + Line: 12, Column: 2, }, value: 2, @@ -90,7 +90,7 @@ var ( { name: "MangoLemonade", position: Position{ - Line: 12, + Line: 13, Column: 2, }, value: 3, @@ -100,7 +100,7 @@ var ( { name: "StrawberryLemonade", position: Position{ - Line: 13, + Line: 14, Column: 2, }, value: 4, @@ -134,7 +134,7 @@ var ( { name: "Read", position: Position{ - Line: 9, + Line: 10, Column: 2, }, value: 1, @@ -144,7 +144,7 @@ var ( { name: "Write", position: Position{ - Line: 10, + Line: 11, Column: 2, }, value: 2, @@ -154,7 +154,7 @@ var ( { name: "ReadWrite", position: Position{ - Line: 11, + Line: 12, Column: 2, }, value: 3, @@ -164,7 +164,7 @@ var ( { name: "RequestGet", position: Position{ - Line: 17, + Line: 18, Column: 2, }, value: "GET", @@ -174,7 +174,7 @@ var ( { name: "RequestPost", position: Position{ - Line: 18, + Line: 19, Column: 2, }, value: "POST", @@ -184,7 +184,7 @@ var ( { name: "RequestPatch", position: Position{ - Line: 19, + Line: 20, Column: 2, }, value: "PATCH", @@ -194,7 +194,7 @@ var ( { name: "RequestDelete", position: Position{ - Line: 20, + Line: 21, Column: 2, }, value: "DELETE", @@ -204,7 +204,7 @@ var ( { name: "SendDir", position: Position{ - Line: 26, + Line: 27, Column: 2, }, value: 2, @@ -214,7 +214,7 @@ var ( { name: "ReceiveDir", position: Position{ - Line: 27, + Line: 28, Column: 2, }, value: 1, @@ -224,7 +224,7 @@ var ( { name: "BothDir", position: Position{ - Line: 28, + Line: 29, Column: 2, }, value: 3, diff --git a/visitor/file_test.go b/visitor/file_test.go index 341f754..6b3e28b 100644 --- a/visitor/file_test.go +++ b/visitor/file_test.go @@ -93,19 +93,31 @@ func assertImports(t *testing.T, file *File, expectedImports []importInfo, expec assert.Equal(t, expectedImport.position, actualImport.Position(), "position for import with path %s in file %s should be %w, but got %w", expectedImport.name, "", expectedImport.position, fileImport.Position()) } - if file.Markers().Count() != len(fileMarkers) { - t.Errorf("the number of the file markers in file %s should be %d, but got %d", file.name, len(fileMarkers), file.Markers().Count()) - } - + assertFileMarkers(t, file, fileMarkers) assertImportMarkers(t, file, expectedImportMarkers) return true } +func assertFileMarkers(t *testing.T, file *File, expectedFileMarkers []fileMarkerInfo) { + + if file.Markers().Count() != len(expectedFileMarkers) { + t.Errorf("the number of the file markers in file %s should be %d, but got %d", file.name, len(expectedFileMarkers), file.Markers().Count()) + } + + index := 0 + for _, actualMarker := range file.Markers() { + expectedMarker := expectedFileMarkers[index].(PackageLevel) + assert.Equal(t, expectedMarker, actualMarker[0]) + index++ + } +} + func assertImportMarkers(t *testing.T, file *File, expectedImportMarkers []importMarkerInfo) { if file.NumImportMarkers() != len(expectedImportMarkers) { t.Errorf("the number of the import markers in file %s should be %d, but got %d", file.name, len(expectedImportMarkers), len(file.ImportMarkers())) + return } for index, importMarker := range file.ImportMarkers() { @@ -123,3 +135,23 @@ func assertImportMarkers(t *testing.T, file *File, expectedImportMarkers []impor } } } + +func TestFiles_FindByNameShouldReturnFileIfItExists(t *testing.T) { + files := &Files{ + elements: []*File{ + { + name: "test", + }, + }, + } + + file, ok := files.FindByName("test") + assert.True(t, ok) + assert.NotNil(t, file) + assert.Equal(t, "test", file.name) +} + +func TestFiles_AtShouldReturnNilIfIndexIsOutOfBound(t *testing.T) { + files := &Files{} + assert.Nil(t, files.At(-1)) +} diff --git a/visitor/function_test.go b/visitor/function_test.go index 54eab6e..042a9d1 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -234,7 +234,7 @@ var ( name: "PrintCookie", fileName: "coffee.go", position: Position{ - Line: 15, + Line: 16, Column: 1, }, isVariadic: false, @@ -380,7 +380,7 @@ var ( } breadFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "Bread", }, @@ -389,7 +389,7 @@ var ( name: "Bread", fileName: "dessert.go", position: Position{ - Line: 16, + Line: 17, Column: 7, }, isVariadic: false, @@ -413,7 +413,7 @@ var ( macaronFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "Macaron", }, @@ -422,7 +422,7 @@ var ( name: "Macaron", fileName: "dessert.go", position: Position{ - Line: 133, + Line: 134, Column: 9, }, isVariadic: false, @@ -446,7 +446,7 @@ var ( makeACakeFunction = functionInfo{ markers: markers.Values{ - "marker:function-level": { + "test-marker:function-level": { FunctionLevel{ Name: "MakeACake", }, @@ -455,7 +455,7 @@ var ( name: "MakeACake", fileName: "dessert.go", position: Position{ - Line: 113, + Line: 114, Column: 1, }, isVariadic: false, @@ -475,7 +475,7 @@ var ( biscuitCakeFunction = functionInfo{ markers: markers.Values{ - "marker:function-level": { + "test-marker:function-level": { FunctionLevel{ Name: "BiscuitCake", }, @@ -484,7 +484,7 @@ var ( name: "BiscuitCake", fileName: "dessert.go", position: Position{ - Line: 119, + Line: 120, Column: 1, }, isVariadic: true, @@ -516,7 +516,7 @@ var ( funfettiFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "Funfetti", }, @@ -525,7 +525,7 @@ var ( name: "Funfetti", fileName: "dessert.go", position: Position{ - Line: 51, + Line: 52, Column: 10, }, isVariadic: false, @@ -545,7 +545,7 @@ var ( iceCreamFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "IceCream", }, @@ -554,7 +554,7 @@ var ( name: "IceCream", fileName: "dessert.go", position: Position{ - Line: 84, + Line: 85, Column: 10, }, isVariadic: true, @@ -578,7 +578,7 @@ var ( cupCakeFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "CupCake", }, @@ -587,7 +587,7 @@ var ( name: "CupCake", fileName: "dessert.go", position: Position{ - Line: 88, + Line: 89, Column: 9, }, isVariadic: false, @@ -611,7 +611,7 @@ var ( tartFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "Tart", }, @@ -620,7 +620,7 @@ var ( name: "Tart", fileName: "dessert.go", position: Position{ - Line: 92, + Line: 93, Column: 6, }, isVariadic: false, @@ -635,7 +635,7 @@ var ( donutFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "Donut", }, @@ -644,7 +644,7 @@ var ( name: "Donut", fileName: "dessert.go", position: Position{ - Line: 96, + Line: 97, Column: 7, }, isVariadic: false, @@ -659,7 +659,7 @@ var ( puddingFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "Pudding", }, @@ -668,7 +668,7 @@ var ( name: "Pudding", fileName: "dessert.go", position: Position{ - Line: 100, + Line: 101, Column: 9, }, isVariadic: false, @@ -683,7 +683,7 @@ var ( pieFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "Pie", }, @@ -692,7 +692,7 @@ var ( name: "Pie", fileName: "dessert.go", position: Position{ - Line: 104, + Line: 105, Column: 5, }, isVariadic: false, @@ -707,7 +707,7 @@ var ( muffinFunction = functionInfo{ markers: markers.Values{ - "marker:interface-method-level": { + "test-marker:interface-method-level": { InterfaceMethodLevel{ Name: "muffin", }, @@ -716,7 +716,7 @@ var ( name: "muffin", fileName: "dessert.go", position: Position{ - Line: 108, + Line: 109, Column: 8, }, isVariadic: false, @@ -736,7 +736,7 @@ var ( eatMethod = functionInfo{ markers: markers.Values{ - "marker:struct-method-level": { + "test-marker:struct-method-level": { StructMethodLevel{ Name: "Eat", }, @@ -745,7 +745,7 @@ var ( name: "Eat", fileName: "dessert.go", position: Position{ - Line: 24, + Line: 25, Column: 1, }, receiver: &receiverInfo{ @@ -765,7 +765,7 @@ var ( buyMethod = functionInfo{ markers: markers.Values{ - "marker:struct-method-level": { + "test-marker:struct-method-level": { StructMethodLevel{ Name: "Buy", }, @@ -774,7 +774,7 @@ var ( name: "Buy", fileName: "dessert.go", position: Position{ - Line: 42, + Line: 43, Column: 1, }, receiver: &receiverInfo{ @@ -794,7 +794,7 @@ var ( fortuneCookieMethod = functionInfo{ markers: markers.Values{ - "marker:struct-method-level": { + "test-marker:struct-method-level": { StructMethodLevel{ Name: "FortuneCookie", }, @@ -803,7 +803,7 @@ var ( name: "FortuneCookie", fileName: "dessert.go", position: Position{ - Line: 67, + Line: 68, Column: 1, }, receiver: &receiverInfo{ @@ -828,7 +828,7 @@ var ( oreoMethod = functionInfo{ markers: markers.Values{ - "marker:struct-method-level": { + "test-marker:struct-method-level": { StructMethodLevel{ Name: "Oreo", }, @@ -837,7 +837,7 @@ var ( name: "Oreo", fileName: "dessert.go", position: Position{ - Line: 73, + Line: 74, Column: 1, }, receiver: &receiverInfo{ @@ -898,7 +898,7 @@ func assertFunctions(t *testing.T, descriptor string, actualMethods *Functions, t.Errorf("the function %s should not be a variadic function for %s", expectedMethodName, descriptor) } - // TODO Type Params + // TODO: Type Params typeParam := actualMethod.TypeParameters() if typeParam != nil { typeParam.Len() diff --git a/visitor/interface_test.go b/visitor/interface_test.go index 8642830..c83cf0d 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -74,7 +74,7 @@ var ( } bakeryShopInterface = interfaceInfo{ markers: markers.Values{ - "marker:interface-type-level": { + "test-marker:interface-type-level": { InterfaceTypeLevel{ Name: "BakeryShop", }, @@ -84,7 +84,7 @@ var ( fileName: "dessert.go", isExported: true, position: Position{ - Line: 13, + Line: 14, Column: 6, }, explicitMethods: map[string]functionInfo{ @@ -107,7 +107,7 @@ var ( dessertInterface = interfaceInfo{ markers: markers.Values{ - "marker:interface-type-level": { + "test-marker:interface-type-level": { InterfaceTypeLevel{ Name: "Dessert", }, @@ -117,7 +117,7 @@ var ( fileName: "dessert.go", isExported: true, position: Position{ - Line: 79, + Line: 80, Column: 6, }, explicitMethods: map[string]functionInfo{ @@ -143,7 +143,7 @@ var ( newYearsEveCookieInterface = interfaceInfo{ markers: markers.Values{ - "marker:interface-type-level": { + "test-marker:interface-type-level": { InterfaceTypeLevel{ Name: "newYearsEveCookie", }, @@ -153,7 +153,7 @@ var ( fileName: "dessert.go", isExported: false, position: Position{ - Line: 48, + Line: 49, Column: 6, }, methods: map[string]functionInfo{ @@ -167,7 +167,7 @@ var ( sweetShopInterface = interfaceInfo{ markers: markers.Values{ - "marker:interface-type-level": { + "test-marker:interface-type-level": { InterfaceTypeLevel{ Name: "SweetShop", }, @@ -177,7 +177,7 @@ var ( fileName: "dessert.go", isExported: true, position: Position{ - Line: 125, + Line: 126, Column: 6, }, explicitMethods: map[string]functionInfo{ diff --git a/visitor/struct_test.go b/visitor/struct_test.go index 0cbddeb..6115016 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -92,7 +92,7 @@ var ( friedCookieStruct = structInfo{ markers: markers.Values{ - "marker:struct-type-level": { + "test-marker:struct-type-level": { StructTypeLevel{ Name: "FriedCookie", }, @@ -102,7 +102,7 @@ var ( fileName: "dessert.go", isExported: true, position: Position{ - Line: 30, + Line: 31, Column: 6, }, methods: map[string]functionInfo{ @@ -142,7 +142,7 @@ var ( cookieStruct = structInfo{ markers: markers.Values{ - "marker:struct-type-level": { + "test-marker:struct-type-level": { StructTypeLevel{ Name: "cookie", Any: map[string]interface{}{ @@ -155,7 +155,7 @@ var ( fileName: "dessert.go", isExported: false, position: Position{ - Line: 56, + Line: 57, Column: 6, }, methods: map[string]functionInfo{ diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index cda81e1..103e961 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -6,6 +6,9 @@ import ( "github.com/procyon-projects/markers" "github.com/procyon-projects/markers/packages" "github.com/stretchr/testify/assert" + "path/filepath" + "runtime" + "strings" "testing" ) @@ -68,24 +71,29 @@ func TestEachFile_ShouldReturnErrorIfTraversedPkgIsNil(t *testing.T) { } func TestVisitor_VisitPackage(t *testing.T) { + _, file, _, _ := runtime.Caller(0) + path := filepath.Dir(file) + lastSlashIndex := strings.LastIndex(path, "/") + path = path[:lastSlashIndex] + markerList := []struct { Name string Level markers.TargetLevel Output interface{} }{ - {Name: "marker:package-level", Level: markers.PackageLevel, Output: &PackageLevel{}}, - {Name: "marker:interface-type-level", Level: markers.InterfaceTypeLevel, Output: &InterfaceTypeLevel{}}, - {Name: "marker:interface-method-level", Level: markers.InterfaceMethodLevel, Output: &InterfaceMethodLevel{}}, - {Name: "marker:function-level", Level: markers.FunctionLevel, Output: &FunctionLevel{}}, - {Name: "marker:struct-type-level", Level: markers.StructTypeLevel, Output: &StructTypeLevel{}}, - {Name: "marker:struct-method-level", Level: markers.StructMethodLevel, Output: &StructMethodLevel{}}, - {Name: "marker:struct-field-level", Level: markers.FieldLevel, Output: &StructFieldLevel{}}, + {Name: "test-marker:package-level", Level: markers.PackageLevel, Output: &PackageLevel{}}, + {Name: "test-marker:interface-type-level", Level: markers.InterfaceTypeLevel, Output: &InterfaceTypeLevel{}}, + {Name: "test-marker:interface-method-level", Level: markers.InterfaceMethodLevel, Output: &InterfaceMethodLevel{}}, + {Name: "test-marker:function-level", Level: markers.FunctionLevel, Output: &FunctionLevel{}}, + {Name: "test-marker:struct-type-level", Level: markers.StructTypeLevel, Output: &StructTypeLevel{}}, + {Name: "test-marker:struct-method-level", Level: markers.StructMethodLevel, Output: &StructMethodLevel{}}, + {Name: "test-marker:struct-field-level", Level: markers.FieldLevel, Output: &StructFieldLevel{}}, } testCasePkgs := map[string]map[string]testFile{ "github.com/procyon-projects/markers/test/menu": { "coffee.go": { - path: "github.com/procyon-projects/markers/test/menu/coffee.go", + path: fmt.Sprintf("%s/test/menu/coffee.go", path), constants: coffeeConstants, customTypes: coffeeCustomTypes, functions: map[string]functionInfo{ @@ -96,10 +104,19 @@ func TestVisitor_VisitPackage(t *testing.T) { pkg: "github.com/procyon-projects/markers", value: "marker", }, + { + pkg: "github.com/procyon-projects/test-markers", + value: "test-marker", + }, + }, + fileMarkers: []fileMarkerInfo{ + PackageLevel{ + Name: "coffee.go", + }, }, }, "fresh.go": { - path: "github.com/procyon-projects/markers/test/menu/fresh.go", + path: fmt.Sprintf("%s/test/menu/fresh.go", path), constants: freshConstants, customTypes: freshCustomTypes, importMarkers: []importMarkerInfo{ @@ -107,24 +124,33 @@ func TestVisitor_VisitPackage(t *testing.T) { pkg: "github.com/procyon-projects/markers", value: "marker", }, + { + pkg: "github.com/procyon-projects/test-markers", + value: "test-marker", + }, + }, + fileMarkers: []fileMarkerInfo{ + PackageLevel{ + Name: "fresh.go", + }, }, }, "dessert.go": { - path: "github.com/procyon-projects/markers/test/menu/dessert.go", + path: fmt.Sprintf("%s/test/menu/dessert.go", path), imports: []importInfo{ { name: "", path: "fmt", sideEffect: false, file: "dessert.go", - position: Position{Line: 7, Column: 2}, + position: Position{Line: 8, Column: 2}, }, { name: "_", path: "strings", sideEffect: true, file: "dessert.go", - position: Position{Line: 8, Column: 2}, + position: Position{Line: 9, Column: 2}, }, }, functions: map[string]functionInfo{ @@ -150,11 +176,21 @@ func TestVisitor_VisitPackage(t *testing.T) { pkg: "github.com/procyon-projects/markers", value: "marker", }, + { + pkg: "github.com/procyon-projects/test-markers", + value: "test-marker", + }, + }, + fileMarkers: []fileMarkerInfo{ + PackageLevel{ + Name: "dessert.go", + }, }, }, }, "github.com/procyon-projects/markers/test/any": { "error.go": { + path: fmt.Sprintf("%s/test/any/error.go", path), constants: []constantInfo{}, functions: map[string]functionInfo{ "Print": printErrorMethod, @@ -162,10 +198,12 @@ func TestVisitor_VisitPackage(t *testing.T) { }, }, "other.go": { + path: fmt.Sprintf("%s/test/any/other.go", path), constants: []constantInfo{}, customTypes: errorCustomTypes, }, "permission.go": { + path: fmt.Sprintf("%s/test/any/permission.go", path), constants: permissionConstants, customTypes: permissionCustomTypes, importMarkers: []importMarkerInfo{ @@ -174,12 +212,23 @@ func TestVisitor_VisitPackage(t *testing.T) { value: "marker", alias: "test", }, + { + pkg: "github.com/procyon-projects/test-markers", + value: "test-marker", + }, + }, + fileMarkers: []fileMarkerInfo{ + PackageLevel{ + Name: "permission.go", + }, }, }, "math.go": { + path: fmt.Sprintf("%s/test/any/math.go", path), constants: mathConstants, }, "generics.go": { + path: fmt.Sprintf("%s/test/any/generics.go", path), constants: []constantInfo{}, functions: map[string]functionInfo{ "GenericFunction": genericFunction, @@ -213,11 +262,13 @@ func TestVisitor_VisitPackage(t *testing.T) { customTypes: genericsCustomTypes, }, "method.go": { + path: fmt.Sprintf("%s/test/any/method.go", path), functions: map[string]functionInfo{ "Print": printHttpHandlerMethod, }, }, "string.go": { + path: fmt.Sprintf("%s/test/any/string.go", path), imports: []importInfo{ { name: "", @@ -236,7 +287,7 @@ func TestVisitor_VisitPackage(t *testing.T) { registry := markers.NewRegistry() for _, m := range markerList { - err := registry.Register(m.Name, "github.com/procyon-projects/markers", m.Level, m.Output) + err := registry.Register(m.Name, "github.com/procyon-projects/test-markers", m.Level, m.Output) if err != nil { t.Errorf("marker %s could not be registered", m.Name) return From 5edcd5f51d565b3eca50cba501ba92d55699cb89 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Fri, 21 Apr 2023 16:20:37 +0300 Subject: [PATCH 14/16] Add more unit tests --- test/any/generics.go | 6 +-- test/menu/dessert.go | 9 +++++ visitor/collector.go | 26 ++---------- visitor/constant.go | 2 +- visitor/custom_type.go | 6 --- visitor/custom_type_test.go | 4 ++ visitor/function.go | 2 - visitor/function_test.go | 56 +++++++++++++++++++------- visitor/interface.go | 2 - visitor/interface_test.go | 25 ++++++++++-- visitor/struct.go | 5 +-- visitor/struct_test.go | 79 +++++++++++++++++++++++++++++++++---- visitor/type.go | 5 ++- visitor/type_constraint.go | 2 +- visitor/visitor.go | 2 +- visitor/visitor_test.go | 1 + 16 files changed, 162 insertions(+), 70 deletions(-) diff --git a/test/any/generics.go b/test/any/generics.go index 2955e21..67eac6d 100644 --- a/test/any/generics.go +++ b/test/any/generics.go @@ -14,17 +14,17 @@ type Repository[T any, ID any | string | constraints.Ordered] interface { Save(entity T) T } -type Controller[C context.Context, T any | int] struct { +type Controller[C context.Context, T any | int, Y ~int] struct { AnyField1 string AnyField2 int } -func (c Controller[K, C]) Index(ctx K, h C) { +func (c Controller[K, C, Y]) Index(ctx K, h C) { } type TestController struct { - Controller[context.Context, int16] + Controller[context.Context, int16, int] } type Number interface { diff --git a/test/menu/dessert.go b/test/menu/dessert.go index e9d3224..edf09db 100644 --- a/test/menu/dessert.go +++ b/test/menu/dessert.go @@ -36,6 +36,10 @@ type FriedCookie struct { // ChocolateChip is a field // +test-marker:struct-field-level:Name=CookieDough cookieDough any + // This is an anonymous struct + anonymousStruct struct{} + // This is an empty interface + emptyInterface interface{} } // Buy is a method @@ -137,3 +141,8 @@ type SweetShop interface { // +test-marker:interface-method-level:Name=Dessert Dessert } + +// Meal is an interface +type Meal interface { + Eat() bool +} diff --git a/visitor/collector.go b/visitor/collector.go index 5f8f194..8f767ec 100644 --- a/visitor/collector.go +++ b/visitor/collector.go @@ -3,10 +3,9 @@ package visitor import "github.com/procyon-projects/markers/packages" type packageCollector struct { - hasSeen map[string]bool - hasProcessed map[string]bool - files map[string]*Files - packages map[string]*packages.Package + hasSeen map[string]bool + files map[string]*Files + packages map[string]*packages.Package unprocessedTypes map[string]map[string]Type @@ -16,7 +15,6 @@ type packageCollector struct { func newPackageCollector() *packageCollector { return &packageCollector{ hasSeen: make(map[string]bool), - hasProcessed: make(map[string]bool), files: make(map[string]*Files), packages: make(map[string]*packages.Package), unprocessedTypes: make(map[string]map[string]Type), @@ -24,18 +22,10 @@ func newPackageCollector() *packageCollector { } } -func (collector *packageCollector) getPackage(pkgId string) *packages.Package { - return collector.packages[pkgId] -} - func (collector *packageCollector) markAsSeen(pkgId string) { collector.hasSeen[pkgId] = true } -func (collector *packageCollector) markAsProcessed(pkgId string) { - collector.hasProcessed[pkgId] = true -} - func (collector *packageCollector) isVisited(pkgId string) bool { visited, ok := collector.hasSeen[pkgId] @@ -46,16 +36,6 @@ func (collector *packageCollector) isVisited(pkgId string) bool { return visited } -func (collector *packageCollector) isProcessed(pkgId string) bool { - processed, ok := collector.hasProcessed[pkgId] - - if !ok { - return false - } - - return processed -} - func (collector *packageCollector) addFile(pkgId string, file *File) { if _, ok := collector.files[pkgId]; !ok { collector.files[pkgId] = &Files{ diff --git a/visitor/constant.go b/visitor/constant.go index 729a5c3..ca7252b 100644 --- a/visitor/constant.go +++ b/visitor/constant.go @@ -47,7 +47,6 @@ func (c *Constant) evaluateExpression() { return } - // TODO: There might be some issues with const expressions. defer func() { if r := recover(); r != nil { } @@ -82,6 +81,7 @@ func (c *Constant) Underlying() Type { } func (c *Constant) String() string { + // TODO: complete string implementation return "" } diff --git a/visitor/custom_type.go b/visitor/custom_type.go index 5e90ec1..c0860fd 100644 --- a/visitor/custom_type.go +++ b/visitor/custom_type.go @@ -97,10 +97,6 @@ func (c *CustomType) Markers() markers.Values { return c.markers } -func (c *CustomType) SpecType() *ast.TypeSpec { - return c.specType -} - func (c *CustomType) String() string { var builder strings.Builder @@ -155,8 +151,6 @@ func (c *CustomType) loadTypeParams() { for _, item := range typeSets { if constraint, isConstraint := item.(*TypeConstraint); isConstraint { constraints = append(constraints, constraint) - } else { - constraints = append(constraints, &TypeConstraint{typ: item}) } } } else { diff --git a/visitor/custom_type_test.go b/visitor/custom_type_test.go index a37888f..62816d4 100644 --- a/visitor/custom_type_test.go +++ b/visitor/custom_type_test.go @@ -2,6 +2,7 @@ package visitor import ( "fmt" + "github.com/procyon-projects/markers" "github.com/stretchr/testify/assert" "testing" ) @@ -12,6 +13,7 @@ type customTypeInfo struct { isExported bool methods map[string]functionInfo stringValue string + markers markers.Values } var ( @@ -133,6 +135,8 @@ func assertCustomTypes(t *testing.T, file *File, customTypes map[string]customTy } assertFunctions(t, fmt.Sprintf("custom type %s", actualCustomType.Name()), actualCustomType.Methods(), expectedCustomType.methods) + assertMarkers(t, expectedCustomType.markers, actualCustomType.Markers(), fmt.Sprintf("type %s %s", expectedCustomTypeName, expectedCustomType.name)) + index++ } diff --git a/visitor/function.go b/visitor/function.go index 2033a2f..1a9fbdb 100644 --- a/visitor/function.go +++ b/visitor/function.go @@ -562,8 +562,6 @@ func (f *Function) loadTypeParams() { for _, item := range typeSets { if constraint, isConstraint := item.(*TypeConstraint); isConstraint { constraints = append(constraints, constraint) - } else { - constraints = append(constraints, &TypeConstraint{typ: item}) } } } else { diff --git a/visitor/function_test.go b/visitor/function_test.go index 042a9d1..1c341c9 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -211,6 +211,10 @@ var ( name: "C", typeName: "", }, + { + name: "Y", + typeName: "", + }, }, } publishMethod = functionInfo{ @@ -422,7 +426,7 @@ var ( name: "Macaron", fileName: "dessert.go", position: Position{ - Line: 134, + Line: 138, Column: 9, }, isVariadic: false, @@ -455,7 +459,7 @@ var ( name: "MakeACake", fileName: "dessert.go", position: Position{ - Line: 114, + Line: 118, Column: 1, }, isVariadic: false, @@ -484,7 +488,7 @@ var ( name: "BiscuitCake", fileName: "dessert.go", position: Position{ - Line: 120, + Line: 124, Column: 1, }, isVariadic: true, @@ -525,7 +529,7 @@ var ( name: "Funfetti", fileName: "dessert.go", position: Position{ - Line: 52, + Line: 56, Column: 10, }, isVariadic: false, @@ -554,7 +558,7 @@ var ( name: "IceCream", fileName: "dessert.go", position: Position{ - Line: 85, + Line: 89, Column: 10, }, isVariadic: true, @@ -587,7 +591,7 @@ var ( name: "CupCake", fileName: "dessert.go", position: Position{ - Line: 89, + Line: 93, Column: 9, }, isVariadic: false, @@ -620,7 +624,7 @@ var ( name: "Tart", fileName: "dessert.go", position: Position{ - Line: 93, + Line: 97, Column: 6, }, isVariadic: false, @@ -644,7 +648,7 @@ var ( name: "Donut", fileName: "dessert.go", position: Position{ - Line: 97, + Line: 101, Column: 7, }, isVariadic: false, @@ -668,7 +672,7 @@ var ( name: "Pudding", fileName: "dessert.go", position: Position{ - Line: 101, + Line: 105, Column: 9, }, isVariadic: false, @@ -692,7 +696,7 @@ var ( name: "Pie", fileName: "dessert.go", position: Position{ - Line: 105, + Line: 109, Column: 5, }, isVariadic: false, @@ -716,7 +720,7 @@ var ( name: "muffin", fileName: "dessert.go", position: Position{ - Line: 109, + Line: 113, Column: 8, }, isVariadic: false, @@ -774,7 +778,7 @@ var ( name: "Buy", fileName: "dessert.go", position: Position{ - Line: 43, + Line: 47, Column: 1, }, receiver: &receiverInfo{ @@ -803,7 +807,7 @@ var ( name: "FortuneCookie", fileName: "dessert.go", position: Position{ - Line: 68, + Line: 72, Column: 1, }, receiver: &receiverInfo{ @@ -837,7 +841,7 @@ var ( name: "Oreo", fileName: "dessert.go", position: Position{ - Line: 74, + Line: 78, Column: 1, }, receiver: &receiverInfo{ @@ -863,6 +867,24 @@ var ( }, }, } + + mealEatMethod = functionInfo{ + markers: markers.Values{}, + name: "Eat", + fileName: "dessert.go", + position: Position{ + Line: 147, + Column: 5, + }, + isVariadic: false, + params: []variableInfo{}, + results: []variableInfo{ + { + name: "", + typeName: "bool", + }, + }, + } ) func assertFunctions(t *testing.T, descriptor string, actualMethods *Functions, expectedMethods map[string]functionInfo) bool { @@ -909,6 +931,12 @@ func assertFunctions(t *testing.T, descriptor string, actualMethods *Functions, assert.Equal(t, expectedMethod.position, actualMethod.Position(), "the position of the function %s for %s should be %w, but got %w", expectedMethodName, descriptor, expectedMethod.position, actualMethod.Position()) + if expectedMethod.receiver != nil && actualMethod.Receiver() == nil { + t.Errorf("the function %s should have a recevier named %s", actualMethod.Name(), expectedMethod.receiver.name) + } else if expectedMethod.receiver == nil && actualMethod.Receiver() != nil { + t.Errorf("the function %s should not have a recevier named %s", actualMethod.Name(), actualMethod.Name()) + } + assertFunctionParameters(t, expectedMethod.params, actualMethod.Parameters(), fmt.Sprintf("function %s (%s)", expectedMethodName, descriptor)) assertFunctionResult(t, expectedMethod.results, actualMethod.Results(), fmt.Sprintf("function %s (%s)", expectedMethodName, descriptor)) diff --git a/visitor/interface.go b/visitor/interface.go index a0e6867..c072fc9 100644 --- a/visitor/interface.go +++ b/visitor/interface.go @@ -328,8 +328,6 @@ func (i *Interface) loadTypeParams() { for _, item := range typeSets { if constraint, isConstraint := item.(*TypeConstraint); isConstraint { constraints = append(constraints, constraint) - } else { - constraints = append(constraints, &TypeConstraint{typ: item}) } } } else { diff --git a/visitor/interface_test.go b/visitor/interface_test.go index c83cf0d..e253cc8 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -117,7 +117,7 @@ var ( fileName: "dessert.go", isExported: true, position: Position{ - Line: 80, + Line: 84, Column: 6, }, explicitMethods: map[string]functionInfo{ @@ -153,7 +153,7 @@ var ( fileName: "dessert.go", isExported: false, position: Position{ - Line: 49, + Line: 53, Column: 6, }, methods: map[string]functionInfo{ @@ -177,7 +177,7 @@ var ( fileName: "dessert.go", isExported: true, position: Position{ - Line: 126, + Line: 130, Column: 6, }, explicitMethods: map[string]functionInfo{ @@ -198,12 +198,29 @@ var ( embeddedInterfaces: []string{"newYearsEveCookie", "Dessert"}, stringValue: "menu.SweetShop", } + mealInterface = interfaceInfo{ + markers: markers.Values{}, + name: "Meal", + fileName: "dessert.go", + isExported: true, + position: Position{ + Line: 146, + Column: 6, + }, + explicitMethods: map[string]functionInfo{ + "Eat": mealEatMethod, + }, + methods: map[string]functionInfo{ + "Eat": mealEatMethod, + }, + stringValue: "menu.Meal", + } ) func assertInterfaces(t *testing.T, file *File, interfaces map[string]interfaceInfo) bool { if len(interfaces) != file.Interfaces().Len() { - t.Errorf("the number of the interface should be %d, but got %d", len(interfaces), file.Interfaces().Len()) + t.Errorf("the number of the interface in file %s should be %d, but got %d", file.Name(), len(interfaces), file.Interfaces().Len()) return false } diff --git a/visitor/struct.go b/visitor/struct.go index c8958ee..a3c05d2 100644 --- a/visitor/struct.go +++ b/visitor/struct.go @@ -358,8 +358,7 @@ func (s *Struct) getFieldsFromFieldList() []*Field { field := &Field{ name: name, isExported: ast.IsExported(name), - // TODO set position - position: Position{}, + position: getPosition(s.pkg, rawField.Pos()), markers: markers[rawField], file: s.file, tags: tags, @@ -506,8 +505,6 @@ func (s *Struct) loadTypeParams() { for _, item := range typeSets { if constraint, isConstraint := item.(*TypeConstraint); isConstraint { constraints = append(constraints, constraint) - } else { - constraints = append(constraints, &TypeConstraint{typ: item}) } } } else { diff --git a/visitor/struct_test.go b/visitor/struct_test.go index 6115016..ebfc1dd 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -12,6 +12,7 @@ type fieldInfo struct { typeName string isExported bool isEmbeddedField bool + markers markers.Values } type structInfo struct { @@ -26,15 +27,16 @@ type structInfo struct { numFields int totalFields int numEmbeddedFields int - implements map[string]struct{} stringValue string + isAnonymous bool + interfaces []string } // structs var ( controllerStruct = structInfo{ markers: markers.Values{}, - stringValue: "any.Controller[C context.Context,T any|int]", + stringValue: "any.Controller[C context.Context,T any|int,Y ~int]", fileName: "generics.go", isExported: true, position: Position{ @@ -77,7 +79,13 @@ var ( allMethods: map[string]functionInfo{ "Index": indexMethod, }, - fields: map[string]fieldInfo{}, + fields: map[string]fieldInfo{ + "Controller": { + isExported: true, + isEmbeddedField: true, + typeName: "Controller", + }, + }, embeddedFields: map[string]fieldInfo{ "Controller": { isExported: true, @@ -126,6 +134,23 @@ var ( isExported: false, isEmbeddedField: false, typeName: "any", + markers: markers.Values{ + "test-marker:struct-field-level": { + StructFieldLevel{ + Name: "CookieDough", + }, + }, + }, + }, + "anonymousStruct": { + isExported: false, + isEmbeddedField: false, + typeName: "struct{}", + }, + "emptyInterface": { + isExported: false, + isEmbeddedField: false, + typeName: "interface{}", }, }, embeddedFields: map[string]fieldInfo{ @@ -135,8 +160,9 @@ var ( typeName: "cookie", }, }, - numFields: 2, - totalFields: 3, + interfaces: []string{"Meal"}, + numFields: 4, + totalFields: 5, numEmbeddedFields: 1, } @@ -155,7 +181,7 @@ var ( fileName: "dessert.go", isExported: false, position: Position{ - Line: 57, + Line: 61, Column: 6, }, methods: map[string]functionInfo{ @@ -173,11 +199,25 @@ var ( isExported: true, isEmbeddedField: false, typeName: "string", + markers: markers.Values{ + "test-marker:struct-field-level": { + StructFieldLevel{ + Name: "ChocolateChip", + }, + }, + }, }, "tripleChocolateCookie": { isExported: false, isEmbeddedField: false, typeName: "map[string]error", + markers: markers.Values{ + "test-marker:struct-field-level": { + StructFieldLevel{ + Name: "tripleChocolateCookie", + }, + }, + }, }, }, embeddedFields: map[string]fieldInfo{}, @@ -221,6 +261,12 @@ func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool t.Errorf("struct with name %s is not exported, but should be exported", actualStruct.Name()) } + if actualStruct.IsAnonymous() && !expectedStruct.isAnonymous { + t.Errorf("struct with name %s is anonymous, but should be anonymous", actualStruct.Name()) + } else if !actualStruct.IsAnonymous() && expectedStruct.isAnonymous { + t.Errorf("struct with name %s is not anonymous, but should be anonymous", actualStruct.Name()) + } + if actualStruct.NumFields() == 0 && !actualStruct.IsEmpty() { t.Errorf("the struct %s should be empty", actualStruct.Name()) } else if actualStruct.NumFields() != 0 && actualStruct.IsEmpty() { @@ -240,7 +286,7 @@ func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool } if actualStruct.NumFieldsInHierarchy() != expectedStruct.totalFields { - t.Errorf("the number of the all fields of the struct %s should be %d, but got %d", expectedStructName, expectedStruct.totalFields, actualStruct.NumFields()) + t.Errorf("the number of the all fields of the struct %s should be %d, but got %d", expectedStructName, expectedStruct.totalFields, actualStruct.NumFieldsInHierarchy()) } if actualStruct.NumEmbeddedFields() != expectedStruct.numEmbeddedFields { @@ -262,6 +308,19 @@ func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool assertStructFields(t, actualStruct.Name(), actualStruct.Fields(), expectedStruct.fields) assertMarkers(t, expectedStruct.markers, actualStruct.Markers(), fmt.Sprintf("struct %s", expectedStructName)) + for _, interfaceName := range expectedStruct.interfaces { + iface, exists := file.Interfaces().FindByName(interfaceName) + + if !exists { + t.Errorf("the interface %s should exists in file %s and the struct %s should implement it", interfaceName, file.Name(), actualStruct.Name()) + continue + } + + if !actualStruct.Implements(iface) { + t.Errorf(" the struct %s should implement the interface %s", actualStruct.Name(), interfaceName) + continue + } + } index++ } @@ -269,6 +328,10 @@ func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool } func assertStructFields(t *testing.T, structName string, actualFields *Fields, expectedFields map[string]fieldInfo) bool { + if actualFields.Len() != len(expectedFields) { + t.Errorf("the number of the fields of struct %s should be %d, but got %d", structName, len(expectedFields), actualFields.Len()) + return false + } for expectedFieldName, expectedField := range expectedFields { actualField, ok := actualFields.FindByName(expectedFieldName) @@ -297,6 +360,8 @@ func assertStructFields(t *testing.T, structName string, actualFields *Fields, e } else if !actualField.IsEmbedded() && expectedField.isEmbeddedField { t.Errorf("field with name %s for struct %s is not embedded, but should be embedded field", expectedFieldName, structName) } + + assertMarkers(t, expectedField.markers, actualField.Markers(), fmt.Sprintf("field %s in struct %s", expectedFieldName, structName)) } return true diff --git a/visitor/type.go b/visitor/type.go index 806394b..4f83647 100644 --- a/visitor/type.go +++ b/visitor/type.go @@ -135,6 +135,7 @@ func collectTypeFromTypeSpec(typeSpec *ast.TypeSpec, visitor *packageVisitor) Ty file.customTypes.elements = append(file.customTypes.elements, t) } t.initialize(typeSpec, file, pkg) + } t.markers = visitor.packageMarkers[typeSpec] return t @@ -188,7 +189,7 @@ func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor, o if field, isField := typed.Obj.Decl.(*ast.Field); isField { if typeParameters == nil { - // TODO return invalid type + //TODO: return invalid type return nil } @@ -196,7 +197,7 @@ func getTypeFromExpression(expr ast.Expr, file *File, visitor *packageVisitor, o return typeParameter } - // TODO return invalid type + //TODO: return invalid type return nil } diff --git a/visitor/type_constraint.go b/visitor/type_constraint.go index 35610cb..d7e856c 100644 --- a/visitor/type_constraint.go +++ b/visitor/type_constraint.go @@ -20,7 +20,7 @@ func (c *TypeConstraint) Underlying() Type { } func (c *TypeConstraint) Satisfy(t Type) bool { - // TODO implement this method + //TODO: implement this method return false } diff --git a/visitor/visitor.go b/visitor/visitor.go index b32f264..5fddd4f 100644 --- a/visitor/visitor.go +++ b/visitor/visitor.go @@ -113,7 +113,7 @@ func EachFile(collector *markers.Collector, pkgs []*packages.Package, callback F pkgCollector := newPackageCollector() for _, pkg := range pkgs { - if !pkgCollector.isVisited(pkg.ID) || !pkgCollector.isProcessed(pkg.ID) { + if !pkgCollector.isVisited(pkg.ID) { visitPackage(pkg, pkgCollector, packageMarkers) } } diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index 103e961..c08b77a 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -166,6 +166,7 @@ func TestVisitor_VisitPackage(t *testing.T) { "Dessert": dessertInterface, "newYearsEveCookie": newYearsEveCookieInterface, "SweetShop": sweetShopInterface, + "Meal": mealInterface, }, structs: map[string]structInfo{ "FriedCookie": friedCookieStruct, From 23c9c473c3d8bbce973b5e8b98451e8fbe8d5786 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Sat, 22 Apr 2023 18:07:35 +0300 Subject: [PATCH 15/16] Add more unit tests --- test/any/custom.go | 5 ++ test/any/generics.go | 4 +- test/any/method.go | 2 +- visitor/custom_type_test.go | 5 +- visitor/function_test.go | 103 ++++++++++++++++++++++++++++++++++++ visitor/interface_test.go | 4 +- visitor/struct_test.go | 18 ++++++- visitor/visitor_test.go | 7 +++ 8 files changed, 140 insertions(+), 8 deletions(-) create mode 100644 test/any/custom.go diff --git a/test/any/custom.go b/test/any/custom.go new file mode 100644 index 0000000..0e09fa1 --- /dev/null +++ b/test/any/custom.go @@ -0,0 +1,5 @@ +package any + +func (HttpHandler[Z, K, V, M]) CustomMethod(ctx Z, value V) { + +} diff --git a/test/any/generics.go b/test/any/generics.go index 67eac6d..df957e2 100644 --- a/test/any/generics.go +++ b/test/any/generics.go @@ -32,8 +32,8 @@ type Number interface { ToString() } -type HttpHandler[C context.Context, K string | int, V constraints.Ordered | constraints.Complex] func(ctx C, value V) K +type HttpHandler[C context.Context, K string | int, V constraints.Ordered | constraints.Complex, M ~string] func(ctx C, value V) K -type EventPublisher[E any] interface { +type EventPublisher[E any, ID ~int] interface { Publish(e E) } diff --git a/test/any/method.go b/test/any/method.go index 72ad6b4..cac0c7e 100644 --- a/test/any/method.go +++ b/test/any/method.go @@ -1,5 +1,5 @@ package any -func (HttpHandler[C, K, V]) Print(ctx C, value V) { +func (HttpHandler[C, K, V, M]) Print(ctx C, value V) { } diff --git a/visitor/custom_type_test.go b/visitor/custom_type_test.go index 62816d4..e975ff0 100644 --- a/visitor/custom_type_test.go +++ b/visitor/custom_type_test.go @@ -77,9 +77,10 @@ var ( underlyingTypeName: "func (ctx C,value V) K", isExported: true, methods: map[string]functionInfo{ - "Print": printHttpHandlerMethod, + "Print": printHttpHandlerMethod, + "CustomMethod": customHttpHandlerMethod, }, - stringValue: "any.HttpHandler[C context.Context,K string|int,V constraints.Ordered|constraints.Complex]", + stringValue: "any.HttpHandler[C context.Context,K string|int,V constraints.Ordered|constraints.Complex,M ~string]", }, } ) diff --git a/visitor/function_test.go b/visitor/function_test.go index 1c341c9..7a22d86 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -297,6 +297,10 @@ var ( name: "V", typeName: "", }, + { + name: "M", + typeName: "", + }, }, } printErrorMethod = functionInfo{ @@ -885,6 +889,50 @@ var ( }, }, } + + customHttpHandlerMethod = functionInfo{ + markers: markers.Values{}, + name: "CustomMethod", + fileName: "custom.go", + position: Position{ + Line: 3, + Column: 1, + }, + isVariadic: false, + receiver: &receiverInfo{ + isPointer: false, + typeName: "HttpHandler", + }, + params: []variableInfo{ + { + name: "ctx", + typeName: "Z", + }, + { + name: "value", + typeName: "V", + }, + }, + results: []variableInfo{}, + typeParams: []variableInfo{ + { + name: "Z", + typeName: "", + }, + { + name: "K", + typeName: "", + }, + { + name: "V", + typeName: "", + }, + { + name: "M", + typeName: "", + }, + }, + } ) func assertFunctions(t *testing.T, descriptor string, actualMethods *Functions, expectedMethods map[string]functionInfo) bool { @@ -1021,6 +1069,20 @@ func TestParameters_FindByNameShouldReturnFalseIfParameterNameDoesNotExist(t *te assert.False(t, ok) } +func TestParameters_FindByNameShouldReturnIfParameterWithGivenNameExist(t *testing.T) { + parameters := &Parameters{ + elements: []*Parameter{ + { + name: "anyName", + }, + }, + } + parameter, ok := parameters.FindByName("anyName") + assert.NotNil(t, parameter) + assert.True(t, ok) + assert.Equal(t, "anyName", parameter.Name()) +} + func TestResults_FindByNameShouldReturnFalseIfResultNameDoesNotExist(t *testing.T) { results := &Results{} result, ok := results.FindByName("anyName") @@ -1028,9 +1090,50 @@ func TestResults_FindByNameShouldReturnFalseIfResultNameDoesNotExist(t *testing. assert.False(t, ok) } +func TestResults_FindByNameShouldReturnIfResultWithGivenNameExist(t *testing.T) { + results := &Results{ + elements: []*Result{ + { + name: "anyName", + }, + }, + } + result, ok := results.FindByName("anyName") + assert.NotNil(t, result) + assert.True(t, ok) + assert.Equal(t, "anyName", result.Name()) +} + func TestFunctions_FindByNameShouldReturnFalseIfFunctionNameDoesNotExist(t *testing.T) { functions := &Functions{} function, ok := functions.FindByName("anyName") assert.Nil(t, function) assert.False(t, ok) } + +func TestFunctions_AtShouldReturnIfFunctionWithGivenIndexExist(t *testing.T) { + functions := &Functions{ + elements: []*Function{ + { + name: "anyName", + }, + }, + } + function := functions.At(0) + assert.NotNil(t, function) + assert.Equal(t, "anyName", function.Name()) +} + +func TestFunctions_FindByNameShouldReturnIfFunctionWithGivenNameExist(t *testing.T) { + functions := &Functions{ + elements: []*Function{ + { + name: "anyName", + }, + }, + } + function, ok := functions.FindByName("anyName") + assert.NotNil(t, function) + assert.True(t, ok) + assert.Equal(t, "anyName", function.Name()) +} diff --git a/visitor/interface_test.go b/visitor/interface_test.go index e253cc8..3a8eeb7 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -70,7 +70,7 @@ var ( methods: map[string]functionInfo{ "Publish": publishMethod, }, - stringValue: "any.EventPublisher[E any]", + stringValue: "any.EventPublisher[E any,ID ~int]", } bakeryShopInterface = interfaceInfo{ markers: markers.Values{ @@ -279,7 +279,7 @@ func assertInterfaces(t *testing.T, file *File, interfaces map[string]interfaceI assert.Equal(t, expectedInterface.position, actualInterface.Position(), "the position of the interface %s should be %w, but got %w", expectedInterfaceName, expectedInterface.position, actualInterface.Position()) - // TODO fix + //TODO: fix actualInterface.IsConstraint() actualInterface.EmbeddedInterfaces() actualInterface.EmbeddedTypes() diff --git a/visitor/struct_test.go b/visitor/struct_test.go index ebfc1dd..065ff73 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -10,6 +10,7 @@ import ( type fieldInfo struct { name string typeName string + stringValue string isExported bool isEmbeddedField bool markers markers.Values @@ -54,11 +55,13 @@ var ( isExported: true, isEmbeddedField: false, typeName: "string", + stringValue: "string", }, "AnyField2": { isExported: true, isEmbeddedField: false, typeName: "int", + stringValue: "int", }, }, embeddedFields: map[string]fieldInfo{}, @@ -84,6 +87,7 @@ var ( isExported: true, isEmbeddedField: true, typeName: "Controller", + stringValue: "Controller[context.Context,int16,int]", }, }, embeddedFields: map[string]fieldInfo{ @@ -91,6 +95,7 @@ var ( isExported: true, isEmbeddedField: true, typeName: "Controller", + stringValue: "Controller[context.Context,int16,int]", }, }, numFields: 1, @@ -129,11 +134,13 @@ var ( isExported: false, isEmbeddedField: true, typeName: "cookie", + stringValue: "menu.cookie", }, "cookieDough": { isExported: false, isEmbeddedField: false, typeName: "any", + stringValue: "any", markers: markers.Values{ "test-marker:struct-field-level": { StructFieldLevel{ @@ -146,11 +153,13 @@ var ( isExported: false, isEmbeddedField: false, typeName: "struct{}", + stringValue: "struct{}", }, "emptyInterface": { isExported: false, isEmbeddedField: false, typeName: "interface{}", + stringValue: "interface{}", }, }, embeddedFields: map[string]fieldInfo{ @@ -158,6 +167,7 @@ var ( isExported: false, isEmbeddedField: true, typeName: "cookie", + stringValue: "menu.cookie", }, }, interfaces: []string{"Meal"}, @@ -199,6 +209,7 @@ var ( isExported: true, isEmbeddedField: false, typeName: "string", + stringValue: "string", markers: markers.Values{ "test-marker:struct-field-level": { StructFieldLevel{ @@ -211,6 +222,7 @@ var ( isExported: false, isEmbeddedField: false, typeName: "map[string]error", + stringValue: "map[string]error", markers: markers.Values{ "test-marker:struct-field-level": { StructFieldLevel{ @@ -346,7 +358,11 @@ func assertStructFields(t *testing.T, structName string, actualFields *Fields, e } if actualField.Type().Name() != expectedField.typeName { - t.Errorf("type of field with name %s for struct %s shoud be %s, but got %s", actualField.Name(), structName, expectedField.typeName, actualField.Type().Name()) + t.Errorf("type of field with name %s for struct %s shoud be '%s', but got %s", actualField.Name(), structName, expectedField.typeName, actualField.Type().Name()) + } + + if actualField.Type().String() != expectedField.stringValue { + t.Errorf("String() result shoud be '%s' for the field with name %s in struct %s, but got %s", expectedField.stringValue, actualField.Name(), structName, actualField.Type().String()) } if actualField.IsExported() && !expectedField.isExported { diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index c08b77a..3a71e15 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -190,6 +190,13 @@ func TestVisitor_VisitPackage(t *testing.T) { }, }, "github.com/procyon-projects/markers/test/any": { + "custom.go": { + path: fmt.Sprintf("%s/test/any/custom.go", path), + constants: []constantInfo{}, + functions: map[string]functionInfo{ + "CustomMethod": customHttpHandlerMethod, + }, + }, "error.go": { path: fmt.Sprintf("%s/test/any/error.go", path), constants: []constantInfo{}, From e78b2309c7778759a4125a0f00e313d0d38c1e77 Mon Sep 17 00:00:00 2001 From: Burak Koken Date: Sun, 23 Apr 2023 21:55:11 +0300 Subject: [PATCH 16/16] Add more unit tests --- test/any/custom.go | 4 +- test/any/generics.go | 6 +- visitor/function.go | 43 +++++---- visitor/function_test.go | 193 ++++++++++++++++++++++++-------------- visitor/interface_test.go | 6 +- visitor/struct_test.go | 106 ++++++++++++++++----- visitor/type_test.go | 47 ++++++++++ visitor/visitor_test.go | 20 +++- 8 files changed, 305 insertions(+), 120 deletions(-) create mode 100644 visitor/type_test.go diff --git a/test/any/custom.go b/test/any/custom.go index 0e09fa1..21d1140 100644 --- a/test/any/custom.go +++ b/test/any/custom.go @@ -1,5 +1,7 @@ package any -func (HttpHandler[Z, K, V, M]) CustomMethod(ctx Z, value V) { +import "net/http" +func (HttpHandler[Z, K, V, M]) CustomMethod(ctx Z, value V, req http.Request) http.Response { + return http.Response{} } diff --git a/test/any/generics.go b/test/any/generics.go index df957e2..c1e677b 100644 --- a/test/any/generics.go +++ b/test/any/generics.go @@ -10,7 +10,7 @@ func GenericFunction[K []map[T]X, T int | bool, X ~string](x []K) T { return value } -type Repository[T any, ID any | string | constraints.Ordered] interface { +type Repository[T any, ID any | string | constraints.Ordered | int | float32] interface { Save(entity T) T } @@ -24,6 +24,7 @@ func (c Controller[K, C, Y]) Index(ctx K, h C) { } type TestController struct { + BaseController[int] Controller[context.Context, int16, int] } @@ -37,3 +38,6 @@ type HttpHandler[C context.Context, K string | int, V constraints.Ordered | cons type EventPublisher[E any, ID ~int] interface { Publish(e E) } + +type BaseController[M any] struct { +} diff --git a/visitor/function.go b/visitor/function.go index 1a9fbdb..fb4030a 100644 --- a/visitor/function.go +++ b/visitor/function.go @@ -23,11 +23,21 @@ func (p *Parameter) Type() Type { } func (p *Parameter) String() string { + _, isTypeParameter := p.Type().(*TypeParameter) + if p.name == "" { - return p.typ.Name() + if isTypeParameter { + return p.typ.Name() + } + + return p.typ.String() + } + + if isTypeParameter { + return fmt.Sprintf("%s %s", p.name, p.typ.Name()) } - return fmt.Sprintf("%s %s", p.name, p.typ.Name()) + return fmt.Sprintf("%s %s", p.name, p.typ.String()) } type Parameters struct { @@ -70,11 +80,21 @@ func (r *Result) Type() Type { } func (r *Result) String() string { + _, isTypeParameter := r.Type().(*TypeParameter) + if r.name == "" { - return r.typ.Name() + if isTypeParameter { + return r.typ.Name() + } + + return r.typ.String() } - return fmt.Sprintf("%s %s", r.name, r.typ.Name()) + if isTypeParameter { + return fmt.Sprintf("%s %s", r.name, r.typ.Name()) + } + + return fmt.Sprintf("%s %s", r.name, r.typ.String()) } type Results struct { @@ -438,20 +458,7 @@ func (f *Function) String() string { if f.Parameters().Len() != 0 { for i := 0; i < f.Parameters().Len(); i++ { param := f.Parameters().At(i) - builder.WriteString(param.Name()) - - if i == f.params.Len()-1 && f.IsVariadic() { - if param.name != "" { - builder.WriteString(" ") - } - builder.WriteString("...") - builder.WriteString(param.Type().Name()) - } else { - if param.name != "" { - builder.WriteString(" ") - } - builder.WriteString(param.Type().Name()) - } + builder.WriteString(param.String()) if i != f.Parameters().Len()-1 { builder.WriteString(",") diff --git a/visitor/function_test.go b/visitor/function_test.go index 7a22d86..c4c784f 100644 --- a/visitor/function_test.go +++ b/visitor/function_test.go @@ -76,6 +76,11 @@ func (f functionInfo) String() string { builder.WriteString(param.name + " ") } + if param.importPackage != "" { + builder.WriteString(param.importPackage) + builder.WriteByte('.') + } + builder.WriteString(param.typeName) if i != len(f.typeParams)-1 { @@ -101,6 +106,12 @@ func (f functionInfo) String() string { if param.isPointer { builder.WriteString("*") } + + if param.importPackage != "" { + builder.WriteString(param.importPackage) + builder.WriteByte('.') + } + builder.WriteString(param.typeName) if i != len(f.params)-1 { @@ -128,6 +139,12 @@ func (f functionInfo) String() string { if result.isPointer { builder.WriteString("*") } + + if result.importPackage != "" { + builder.WriteString(result.importPackage) + builder.WriteByte('.') + } + builder.WriteString(result.typeName) if i != len(f.results)-1 { @@ -156,8 +173,9 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "entity", - typeName: "T", + name: "entity", + typeName: "T", + stringValue: "entity T", }, }, results: []variableInfo{ @@ -171,7 +189,7 @@ var ( name: "ToString", fileName: "generics.go", position: Position{ - Line: 32, + Line: 33, Column: 10, }, isVariadic: false, @@ -194,12 +212,14 @@ var ( }, params: []variableInfo{ { - name: "ctx", - typeName: "K", + name: "ctx", + typeName: "K", + stringValue: "ctx K", }, { - name: "h", - typeName: "C", + name: "h", + typeName: "C", + stringValue: "h C", }, }, typeParams: []variableInfo{ @@ -222,14 +242,15 @@ var ( name: "Publish", fileName: "generics.go", position: Position{ - Line: 38, + Line: 39, Column: 9, }, isVariadic: false, params: []variableInfo{ { - name: "e", - typeName: "E", + name: "e", + typeName: "E", + stringValue: "e E", }, }, } @@ -249,8 +270,9 @@ var ( }, params: []variableInfo{ { - name: "v", - typeName: "interface{}", + name: "v", + typeName: "interface{}", + stringValue: "v interface{}", }, }, results: []variableInfo{ @@ -275,12 +297,14 @@ var ( }, params: []variableInfo{ { - name: "ctx", - typeName: "C", + name: "ctx", + typeName: "C", + stringValue: "ctx C", }, { - name: "value", - typeName: "V", + name: "value", + typeName: "V", + stringValue: "value V", }, }, results: []variableInfo{}, @@ -361,8 +385,9 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "x", - typeName: "[]K", + name: "x", + typeName: "[]K", + stringValue: "x []K", }, }, results: []variableInfo{ @@ -403,12 +428,14 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "i", - typeName: "float64", + name: "i", + typeName: "float64", + stringValue: "i float64", }, { - name: "k", - typeName: "float64", + name: "k", + typeName: "float64", + stringValue: "k float64", }, }, results: []variableInfo{ @@ -436,8 +463,9 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "c", - typeName: "complex128", + name: "c", + typeName: "complex128", + stringValue: "c complex128", }, }, results: []variableInfo{ @@ -446,8 +474,9 @@ var ( typeName: "chan string", }, { - name: "", - typeName: "Stringer", + name: "", + importPackage: "fmt", + typeName: "Stringer", }, }, } @@ -469,8 +498,9 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "s", - typeName: "interface{}", + name: "s", + typeName: "interface{}", + stringValue: "s interface{}", }, }, results: []variableInfo{ @@ -498,16 +528,19 @@ var ( isVariadic: true, params: []variableInfo{ { - name: "s", - typeName: "string", + name: "s", + typeName: "string", + stringValue: "s string", }, { - name: "arr", - typeName: "[]int", + name: "arr", + typeName: "[]int", + stringValue: "arr []int", }, { - name: "v", - typeName: "int16", + name: "v", + typeName: "int16", + stringValue: "v ...int16", }, }, results: []variableInfo{ @@ -539,8 +572,9 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "v", - typeName: "rune", + name: "v", + typeName: "rune", + stringValue: "v rune", }, }, results: []variableInfo{ @@ -568,12 +602,14 @@ var ( isVariadic: true, params: []variableInfo{ { - name: "s", - typeName: "string", + name: "s", + typeName: "string", + stringValue: "s string", }, { - name: "v", - typeName: "bool", + name: "v", + typeName: "bool", + stringValue: "v ...bool", }, }, results: []variableInfo{ @@ -601,12 +637,14 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "", - typeName: "[]int", + name: "", + typeName: "[]int", + stringValue: "[]int", }, { - name: "", - typeName: "bool", + name: "", + typeName: "bool", + stringValue: "bool", }, }, results: []variableInfo{ @@ -634,8 +672,9 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "s", - typeName: "interface{}", + name: "s", + typeName: "interface{}", + stringValue: "s interface{}", }, }, results: []variableInfo{}, @@ -793,8 +832,9 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "i", - typeName: "int", + name: "i", + typeName: "int", + stringValue: "i int", }, }, results: []variableInfo{}, @@ -822,8 +862,9 @@ var ( isVariadic: false, params: []variableInfo{ { - name: "v", - typeName: "interface{}", + name: "v", + typeName: "interface{}", + stringValue: "v interface{}", }, }, results: []variableInfo{ @@ -856,12 +897,14 @@ var ( isVariadic: true, params: []variableInfo{ { - name: "a", - typeName: "[]interface{}", + name: "a", + typeName: "[]interface{}", + stringValue: "a []interface{}", }, { - name: "v", - typeName: "string", + name: "v", + typeName: "string", + stringValue: "v ...string", }, }, results: []variableInfo{ @@ -895,7 +938,7 @@ var ( name: "CustomMethod", fileName: "custom.go", position: Position{ - Line: 3, + Line: 5, Column: 1, }, isVariadic: false, @@ -905,15 +948,30 @@ var ( }, params: []variableInfo{ { - name: "ctx", - typeName: "Z", + name: "ctx", + typeName: "Z", + stringValue: "ctx Z", }, { - name: "value", - typeName: "V", + name: "value", + typeName: "V", + stringValue: "value V", + }, + { + name: "req", + importPackage: "http", + typeName: "Request", + stringValue: "req http.Request", + }, + }, + results: []variableInfo{ + { + name: "", + importPackage: "http", + typeName: "Response", + stringValue: "http.Response", }, }, - results: []variableInfo{}, typeParams: []variableInfo{ { name: "Z", @@ -1009,19 +1067,12 @@ func assertFunctionParameters(t *testing.T, expectedParams []variableInfo, actua t.Errorf("at index %d, the parameter name of the %s should be %s, but got %s", index, msg, expectedFunctionParam.name, actualFunctionParam.name) } - if expectedFunctionParam.String() != actualFunctionParam.Type().Name() { + if expectedFunctionParam.TypeName() != actualFunctionParam.Type().Name() { t.Errorf("at index %d, the parameter type name of the %s should be %s, but got %s", index, msg, expectedFunctionParam.typeName, actualFunctionParam.Type().Name()) } - var expectedFunctionParamString string - if expectedFunctionParam.name == "" { - expectedFunctionParamString = expectedFunctionParam.typeName - } else { - expectedFunctionParamString = fmt.Sprintf("%s %s", expectedFunctionParam.name, expectedFunctionParam.typeName) - } - - if expectedFunctionParamString != actualFunctionParam.String() { - t.Errorf("at index %d parameter, the String() method should return %s, but got %s", index, expectedFunctionParamString, actualFunctionParam.String()) + if expectedFunctionParam.stringValue != actualFunctionParam.String() { + t.Errorf("at index %d parameter, the String() method of the %s should return '%s', but got '%s'", index, msg, expectedFunctionParam.stringValue, actualFunctionParam.String()) } } } @@ -1040,7 +1091,7 @@ func assertFunctionResult(t *testing.T, expectedResults []variableInfo, actualRe t.Errorf("at index %d, the parameter result of the %s should be %s, but got %s", index, msg, expectedFunctionParam.name, actualFunctionParam.name) } - if expectedFunctionParam.String() != actualFunctionParam.Type().Name() { + if expectedFunctionParam.TypeName() != actualFunctionParam.Type().Name() { t.Errorf("at index %d, the parameter result type of the %s should be %s, but got %s", index, msg, expectedFunctionParam.typeName, actualFunctionParam.Type().Name()) } } diff --git a/visitor/interface_test.go b/visitor/interface_test.go index 3a8eeb7..2b73e68 100644 --- a/visitor/interface_test.go +++ b/visitor/interface_test.go @@ -36,14 +36,14 @@ var ( methods: map[string]functionInfo{ "Save": saveFunction, }, - stringValue: "any.Repository[T any,ID any|string|constraints.Ordered]", + stringValue: "any.Repository[T any,ID any|string|constraints.Ordered|int|float32]", } numberInterface = interfaceInfo{ name: "Number", fileName: "generics.go", isExported: true, position: Position{ - Line: 30, + Line: 31, Column: 6, }, explicitMethods: map[string]functionInfo{ @@ -61,7 +61,7 @@ var ( fileName: "generics.go", isExported: true, position: Position{ - Line: 37, + Line: 38, Column: 6, }, explicitMethods: map[string]functionInfo{ diff --git a/visitor/struct_test.go b/visitor/struct_test.go index 065ff73..a3fb400 100644 --- a/visitor/struct_test.go +++ b/visitor/struct_test.go @@ -17,20 +17,21 @@ type fieldInfo struct { } type structInfo struct { - fileName string - isExported bool - position Position - markers markers.Values - methods map[string]functionInfo - allMethods map[string]functionInfo - fields map[string]fieldInfo - embeddedFields map[string]fieldInfo - numFields int - totalFields int - numEmbeddedFields int - stringValue string - isAnonymous bool - interfaces []string + fileName string + isExported bool + position Position + markers markers.Values + methods map[string]functionInfo + allMethods map[string]functionInfo + fields map[string]fieldInfo + embeddedFields map[string]fieldInfo + numFields int + totalFields int + numEmbeddedFields int + stringValue string + isAnonymous bool + implementsInterfaces []string + noImplementsInterfaces []string } // structs @@ -89,6 +90,12 @@ var ( typeName: "Controller", stringValue: "Controller[context.Context,int16,int]", }, + "BaseController": { + isExported: true, + isEmbeddedField: true, + typeName: "BaseController", + stringValue: "BaseController[int]", + }, }, embeddedFields: map[string]fieldInfo{ "Controller": { @@ -97,10 +104,16 @@ var ( typeName: "Controller", stringValue: "Controller[context.Context,int16,int]", }, + "BaseController": { + isExported: true, + isEmbeddedField: true, + typeName: "BaseController", + stringValue: "BaseController[int]", + }, }, - numFields: 1, + numFields: 2, totalFields: 2, - numEmbeddedFields: 1, + numEmbeddedFields: 2, } friedCookieStruct = structInfo{ @@ -170,10 +183,11 @@ var ( stringValue: "menu.cookie", }, }, - interfaces: []string{"Meal"}, - numFields: 4, - totalFields: 5, - numEmbeddedFields: 1, + implementsInterfaces: []string{"Meal"}, + noImplementsInterfaces: []string{"Dessert"}, + numFields: 4, + totalFields: 5, + numEmbeddedFields: 1, } cookieStruct = structInfo{ @@ -237,12 +251,29 @@ var ( totalFields: 2, numEmbeddedFields: 0, } + baseControllerStruct = structInfo{ + markers: markers.Values{}, + stringValue: "any.BaseController[M any]", + fileName: "generics.go", + isExported: true, + position: Position{ + Line: 42, + Column: 6, + }, + methods: map[string]functionInfo{}, + allMethods: map[string]functionInfo{}, + fields: map[string]fieldInfo{}, + embeddedFields: map[string]fieldInfo{}, + numFields: 0, + totalFields: 0, + numEmbeddedFields: 0, + } ) func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool { if len(structs) != file.Structs().Len() { - t.Errorf("the number of the structs should be %d, but got %d", len(structs), file.Structs().Len()) + t.Errorf("the number of the structs in file %s should be %d, but got %d", file.Name(), len(structs), file.Structs().Len()) return false } @@ -320,7 +351,7 @@ func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool assertStructFields(t, actualStruct.Name(), actualStruct.Fields(), expectedStruct.fields) assertMarkers(t, expectedStruct.markers, actualStruct.Markers(), fmt.Sprintf("struct %s", expectedStructName)) - for _, interfaceName := range expectedStruct.interfaces { + for _, interfaceName := range expectedStruct.implementsInterfaces { iface, exists := file.Interfaces().FindByName(interfaceName) if !exists { @@ -333,6 +364,20 @@ func assertStructs(t *testing.T, file *File, structs map[string]structInfo) bool continue } } + + for _, interfaceName := range expectedStruct.noImplementsInterfaces { + iface, exists := file.Interfaces().FindByName(interfaceName) + + if !exists { + t.Errorf("the interface %s should exists in file %s and the struct %s should implement it", interfaceName, file.Name(), actualStruct.Name()) + continue + } + + if actualStruct.Implements(iface) { + t.Errorf(" the struct %s should not implement the interface %s", actualStruct.Name(), interfaceName) + continue + } + } index++ } @@ -387,3 +432,20 @@ func TestStructs_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { structs := &Structs{} assert.Nil(t, structs.At(0)) } + +func TestFields_AtShouldReturnNilIfIndexIsOutOfRange(t *testing.T) { + fields := &Fields{} + assert.Nil(t, fields.At(0)) +} +func TestFields_AtShouldReturnFieldIfIndexIsBetweenRange(t *testing.T) { + fields := &Fields{ + elements: []*Field{ + { + name: "anyField", + }, + }, + } + field := fields.At(0) + assert.NotNil(t, field) + assert.Equal(t, "anyField", field.Name()) +} diff --git a/visitor/type_test.go b/visitor/type_test.go new file mode 100644 index 0000000..a2cdf0c --- /dev/null +++ b/visitor/type_test.go @@ -0,0 +1,47 @@ +package visitor + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestTypes_FindByNameShouldReturnTypeIfItExists(t *testing.T) { + types := &Types{ + elements: []Type{ + &Struct{ + name: "test", + }, + }, + } + + typ, ok := types.FindByName("test") + assert.True(t, ok) + assert.NotNil(t, typ) + assert.Equal(t, "test", typ.Name()) +} + +func TestTypes_FindByNameShouldReturnNilIfItDoesExist(t *testing.T) { + types := &Types{ + elements: []Type{}, + } + + typ, ok := types.FindByName("test") + assert.False(t, ok) + assert.Nil(t, typ) +} + +func TestTypes_AtShouldReturnNilIfGivenIndexIsOutOfRange(t *testing.T) { + types := &Types{ + elements: []Type{}, + } + + typ := types.At(-1) + assert.Nil(t, typ) +} + +func TestTypeSets_AtShouldReturnNilIfGivenIndexIsOutOfRange(t *testing.T) { + typeSets := &TypeSets{} + + typ := typeSets.At(-1) + assert.Nil(t, typ) +} diff --git a/visitor/visitor_test.go b/visitor/visitor_test.go index 3a71e15..24145c5 100644 --- a/visitor/visitor_test.go +++ b/visitor/visitor_test.go @@ -42,12 +42,14 @@ type FunctionLevel struct { } type variableInfo struct { - name string - typeName string - isPointer bool + name string + typeName string + stringValue string + importPackage string + isPointer bool } -func (v variableInfo) String() string { +func (v variableInfo) TypeName() string { if v.isPointer { return fmt.Sprintf("*%s", v.typeName) } @@ -196,6 +198,15 @@ func TestVisitor_VisitPackage(t *testing.T) { functions: map[string]functionInfo{ "CustomMethod": customHttpHandlerMethod, }, + imports: []importInfo{ + { + name: "", + path: "net/http", + sideEffect: false, + file: "custom.go", + position: Position{Line: 3, Column: 8}, + }, + }, }, "error.go": { path: fmt.Sprintf("%s/test/any/error.go", path), @@ -250,6 +261,7 @@ func TestVisitor_VisitPackage(t *testing.T) { structs: map[string]structInfo{ "Controller": controllerStruct, "TestController": testControllerStruct, + "BaseController": baseControllerStruct, }, imports: []importInfo{ {