diff --git a/mains/generate.go b/mains/generate.go index f0afc01c..a7bdfcac 100644 --- a/mains/generate.go +++ b/mains/generate.go @@ -125,7 +125,17 @@ func generateExample(genpkg string, roots []eval.Root, files []*codegen.File) ([ fsCounts[sd.Service.Name] = len(sd.FileServers) } } - hasWS := httpcodegen.NeedDialer(httpSvcs) + // Detect WebSocket usage from HTTP endpoints (streaming without SSE). + // NeedDialer() is for client dialers, but here we're generating server mains + // and need to know whether to import gorilla/websocket for the upgrader. + hasWS := false + wsBySvc := httpWebSocketByService(roots) + for _, sd := range httpSvcs { + if sd != nil && sd.Service != nil && wsBySvc[sd.Service.Name] { + hasWS = true + break + } + } apipkg := apiPkgAlias(genpkg, roots) if info, ok := srvMap[dir]; ok { info.HasWS = hasWS diff --git a/mains/generate_test.go b/mains/generate_test.go index 80bebc2d..aca43857 100644 --- a/mains/generate_test.go +++ b/mains/generate_test.go @@ -86,14 +86,19 @@ func TestWebSocketMainIncludesUpgrader(t *testing.T) { } require.NotNil(t, mainFile) - // Extract code for our section and assert WS import and upgrader usage - sections := mainFile.Section("mains-main") - require.Greater(t, len(sections), 0) - var buf bytes.Buffer - require.NoError(t, sections[0].Write(&buf)) - code := buf.String() - assert.Contains(t, code, "github.com/gorilla/websocket") - assert.Contains(t, code, "websocket.Upgrader") + // Assert WS import is present in the generated header. + header := mainFile.Section("source-header") + require.Greater(t, len(header), 0) + var hbuf bytes.Buffer + require.NoError(t, header[0].Write(&hbuf)) + assert.Contains(t, hbuf.String(), "github.com/gorilla/websocket") + + // Assert upgrader usage is present in the main body. + body := mainFile.Section("mains-main") + require.Greater(t, len(body), 0) + var bbuf bytes.Buffer + require.NoError(t, body[0].Write(&bbuf)) + assert.Contains(t, bbuf.String(), "websocket.Upgrader") } func TestMainsAddsFileServerNils(t *testing.T) { diff --git a/testing/codegen/scenarios.go b/testing/codegen/scenarios.go index bdece2ba..ac3b1175 100644 --- a/testing/codegen/scenarios.go +++ b/testing/codegen/scenarios.go @@ -5,10 +5,10 @@ import ( "os" "path/filepath" - "gopkg.in/yaml.v3" "goa.design/goa/v3/codegen" "goa.design/goa/v3/codegen/service" "goa.design/goa/v3/expr" + "gopkg.in/yaml.v3" ) // generateScenarios generates the scenario runner for a service. @@ -33,7 +33,7 @@ func generateScenarios(genpkg string, svcData *service.Data, root *expr.RootExpr {Path: "gopkg.in/yaml.v3", Name: "yaml"}, {Path: filepath.Join(genpkg, codegen.SnakeCase(svc.Name)), Name: svcData.PkgName}, } - + // Add validator package import if specified in YAML // Only add if it's different from the current package currentPkgPath := filepath.Join(genpkg, codegen.SnakeCase(svc.Name), codegen.SnakeCase(svc.Name)+"test") @@ -101,26 +101,30 @@ type ( *service.MethodData // Transports lists valid transport strings for YAML Transports []string + // ResultTypeRef is the fully qualified Go reference to the result type + // as seen from the generated test package (e.g. "*svc.Foo", "[]*svc.Bar"). + // This is used in generated type assertions for custom validators. + ResultTypeRef string } ) // generateExampleScenarios generates an example scenarios.yaml file for a service. func generateExampleScenarios(_ string, root *expr.RootExpr, svc *expr.ServiceExpr) *codegen.File { path := filepath.Join(codegen.Gendir, "..", "scenarios.yaml") - + svcData := service.NewServicesData(root).Get(svc.Name) if svcData == nil { return nil } - + data := buildScenariosData(svcData, root, svc) - + // For YAML files, we need to read the template directly since it's not a .go.tpl file tmplContent, err := templateFS.ReadFile("templates/example_scenarios.yaml.tpl") if err != nil { panic(fmt.Sprintf("failed to read example_scenarios.yaml.tpl: %v", err)) } - + sections := []*codegen.SectionTemplate{ { Name: "example-scenarios", @@ -128,7 +132,7 @@ func generateExampleScenarios(_ string, root *expr.RootExpr, svc *expr.ServiceEx Data: data, }, } - + return &codegen.File{ Path: path, SectionTemplates: sections, @@ -139,20 +143,20 @@ func generateExampleScenarios(_ string, root *expr.RootExpr, svc *expr.ServiceEx func buildScenariosData(svcData *service.Data, root *expr.RootExpr, svc *expr.ServiceExpr) *scenariosData { // Extract validator info from YAML validatorInfo := ExtractValidatorsFromYAML() - + data := &scenariosData{ - Data: svcData, - ServiceExpr: svc, - Methods: make([]*scenarioMethodData, 0), - HasHTTP: hasHTTPTransport(root, svc), - HasGRPC: hasGRPCTransport(root, svc), - HasJSONRPC: hasJSONRPCTransport(root, svc), + Data: svcData, + ServiceExpr: svc, + Methods: make([]*scenarioMethodData, 0), + HasHTTP: hasHTTPTransport(root, svc), + HasGRPC: hasGRPCTransport(root, svc), + HasJSONRPC: hasJSONRPCTransport(root, svc), ValidTransports: make([]string, 0), - Validators: validatorInfo.Validators, - ValidatorPkg: "", // Will be set below if it's a different package - ValidatorPath: validatorInfo.Path, + Validators: validatorInfo.Validators, + ValidatorPkg: "", // Will be set below if it's a different package + ValidatorPath: validatorInfo.Path, } - + // Build list of valid transports transportSet := make(map[string]bool) transportSet["auto"] = true @@ -176,15 +180,21 @@ func buildScenariosData(svcData *service.Data, root *expr.RootExpr, svc *expr.Se // Build method data with available transports for i, m := range svc.Methods { methodData := svcData.Methods[i] - + // Build targets for this method using shared function targets := buildMethodTargets(root, svc, m, methodData) - + md := &scenarioMethodData{ MethodData: methodData, Transports: make([]string, 0), } - + // Compute fully qualified result type reference for type assertions. + // This properly handles composite types like ArrayOf(...) without producing + // invalid Go like "svc.[]T" (see issue #234). + if m.Result != nil && m.Result.Type != expr.Empty { + md.ResultTypeRef = svcData.Scope.GoFullTypeRef(m.Result, svcData.PkgName) + } + // Build list of valid transport strings based on targets transportSet := make(map[string]bool) for _, target := range targets { @@ -205,15 +215,15 @@ func buildScenariosData(svcData *service.Data, root *expr.RootExpr, svc *expr.Se transportSet["jsonrpc-ws"] = true } } - + // Convert set to sorted list for transport := range transportSet { md.Transports = append(md.Transports, transport) } - + data.Methods = append(data.Methods, md) } - + return data } @@ -230,19 +240,19 @@ func ExtractValidatorsFromYAML() ValidatorInfo { Validators: make(map[string][]string), Package: "", // Empty means use current package } - + // Try to read scenarios.yaml from current directory data, err := os.ReadFile("scenarios.yaml") if err != nil { // File doesn't exist or can't be read, that's OK return info } - + if len(data) == 0 { // Empty file return info } - + var config struct { Validators struct { Package string `yaml:"package"` @@ -257,16 +267,16 @@ func ExtractValidatorsFromYAML() ValidatorInfo { } `yaml:"steps"` } `yaml:"scenarios"` } - + if err := yaml.Unmarshal(data, &config); err != nil { // Invalid YAML, skip return info } - + // Extract package info info.Package = config.Validators.Package info.Path = config.Validators.Path - + // Extract unique validators per method for _, scenario := range config.Scenarios { if scenario.Steps == nil { @@ -289,6 +299,6 @@ func ExtractValidatorsFromYAML() ValidatorInfo { } } } - + return info -} \ No newline at end of file +} diff --git a/testing/codegen/scenarios_test.go b/testing/codegen/scenarios_test.go index ffed302d..bf01a8de 100644 --- a/testing/codegen/scenarios_test.go +++ b/testing/codegen/scenarios_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "goa.design/goa/v3/codegen" "goa.design/goa/v3/codegen/service" httpcodegen "goa.design/goa/v3/http/codegen" "goa.design/plugins/v3/testing/codegen/testdata" @@ -52,3 +53,27 @@ func TestGenerateScenarios(t *testing.T) { }) } } + +func TestGenerateScenarios_ArrayResultTypeAssertion(t *testing.T) { + root := httpcodegen.RunHTTPDSL(t, testdata.WithArrayResultDSL) + services := service.NewServicesData(root) + svc := root.Services[0] + svcData := services.Get(svc.Name) + fs := generateScenarios("", svcData, root, svc) + f := fs[0] + + sections := f.Section("scenario-runner") + if len(sections) != 1 { + t.Fatalf("expected 1 scenario-runner section, got %d", len(sections)) + } + code := codegen.SectionCode(t, sections[0]) + + // This is the canonical fully-qualified Go type reference that should be used + // in the generated type assertion. + wantRef := svcData.Scope.GoFullTypeRef(svc.Methods[0].Result, svcData.PkgName) + assert.Contains(t, code, "typedResult := result.("+wantRef+")") + + // Regression guard for invalid formatting like "pkg.[]T" (issue #234). + assert.NotContains(t, code, svcData.PkgName+".[]") + assert.NotContains(t, code, "*"+svcData.PkgName+".[]") +} diff --git a/testing/codegen/templates/scenario_runner.go.tpl b/testing/codegen/templates/scenario_runner.go.tpl index 4e07ffc2..825550de 100644 --- a/testing/codegen/templates/scenario_runner.go.tpl +++ b/testing/codegen/templates/scenario_runner.go.tpl @@ -222,8 +222,8 @@ func (r *ScenarioRunner) callValidator(t *testing.T, method string, result any, switch method { {{- range .Methods }} case "{{ .Name }}": - {{- if .ResultRef }} - typedResult := result.(*{{ $.PkgName }}.{{ .Result }}) + {{- if .ResultTypeRef }} + typedResult := result.({{ .ResultTypeRef }}) {{- $validators := index $.Validators .Name }} {{- if $validators }} diff --git a/testing/codegen/testdata/dsls.go b/testing/codegen/testdata/dsls.go index da0514ce..edd6fe2f 100644 --- a/testing/codegen/testdata/dsls.go +++ b/testing/codegen/testdata/dsls.go @@ -46,6 +46,21 @@ var WithoutPayloadResultDSL = func() { }) } +var WithArrayResultDSL = func() { + var AccessControl = ResultType("AccessControl", func() { + Attribute("id", String) + Required("id") + }) + Service("WithArrayResultService", func() { + Method("ListAccessControl", func() { + Result(ArrayOf(AccessControl)) + HTTP(func() { + GET("/") + }) + }) + }) +} + var WithStreamDSL = func() { Service("WithStreamService", func() { Method("WithStreamMethod", func() {