Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion mains/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,17 @@ func generateExample(genpkg string, roots []eval.Root, files []*codegen.File) ([
fsCounts[sd.Service.Name] = len(sd.FileServers)
}
}
hasWS := httpcodegen.NeedDialer(httpSvcs)
// Detect WebSocket usage from HTTP endpoints (streaming without SSE).
// NeedDialer() is for client dialers, but here we're generating server mains
// and need to know whether to import gorilla/websocket for the upgrader.
hasWS := false
wsBySvc := httpWebSocketByService(roots)
for _, sd := range httpSvcs {
if sd != nil && sd.Service != nil && wsBySvc[sd.Service.Name] {
hasWS = true
break
}
}
apipkg := apiPkgAlias(genpkg, roots)
if info, ok := srvMap[dir]; ok {
info.HasWS = hasWS
Expand Down
21 changes: 13 additions & 8 deletions mains/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,19 @@ func TestWebSocketMainIncludesUpgrader(t *testing.T) {
}
require.NotNil(t, mainFile)

// Extract code for our section and assert WS import and upgrader usage
sections := mainFile.Section("mains-main")
require.Greater(t, len(sections), 0)
var buf bytes.Buffer
require.NoError(t, sections[0].Write(&buf))
code := buf.String()
assert.Contains(t, code, "github.com/gorilla/websocket")
assert.Contains(t, code, "websocket.Upgrader")
// Assert WS import is present in the generated header.
header := mainFile.Section("source-header")
require.Greater(t, len(header), 0)
var hbuf bytes.Buffer
require.NoError(t, header[0].Write(&hbuf))
assert.Contains(t, hbuf.String(), "github.com/gorilla/websocket")

// Assert upgrader usage is present in the main body.
body := mainFile.Section("mains-main")
require.Greater(t, len(body), 0)
var bbuf bytes.Buffer
require.NoError(t, body[0].Write(&bbuf))
assert.Contains(t, bbuf.String(), "websocket.Upgrader")
}

func TestMainsAddsFileServerNils(t *testing.T) {
Expand Down
74 changes: 42 additions & 32 deletions testing/codegen/scenarios.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
"os"
"path/filepath"

"gopkg.in/yaml.v3"
"goa.design/goa/v3/codegen"
"goa.design/goa/v3/codegen/service"
"goa.design/goa/v3/expr"
"gopkg.in/yaml.v3"
)

// generateScenarios generates the scenario runner for a service.
Expand All @@ -33,7 +33,7 @@ func generateScenarios(genpkg string, svcData *service.Data, root *expr.RootExpr
{Path: "gopkg.in/yaml.v3", Name: "yaml"},
{Path: filepath.Join(genpkg, codegen.SnakeCase(svc.Name)), Name: svcData.PkgName},
}

// Add validator package import if specified in YAML
// Only add if it's different from the current package
currentPkgPath := filepath.Join(genpkg, codegen.SnakeCase(svc.Name), codegen.SnakeCase(svc.Name)+"test")
Expand Down Expand Up @@ -101,34 +101,38 @@ type (
*service.MethodData
// Transports lists valid transport strings for YAML
Transports []string
// ResultTypeRef is the fully qualified Go reference to the result type
// as seen from the generated test package (e.g. "*svc.Foo", "[]*svc.Bar").
// This is used in generated type assertions for custom validators.
ResultTypeRef string
}
)

// generateExampleScenarios generates an example scenarios.yaml file for a service.
func generateExampleScenarios(_ string, root *expr.RootExpr, svc *expr.ServiceExpr) *codegen.File {
path := filepath.Join(codegen.Gendir, "..", "scenarios.yaml")

svcData := service.NewServicesData(root).Get(svc.Name)
if svcData == nil {
return nil
}

data := buildScenariosData(svcData, root, svc)

// For YAML files, we need to read the template directly since it's not a .go.tpl file
tmplContent, err := templateFS.ReadFile("templates/example_scenarios.yaml.tpl")
if err != nil {
panic(fmt.Sprintf("failed to read example_scenarios.yaml.tpl: %v", err))
}

sections := []*codegen.SectionTemplate{
{
Name: "example-scenarios",
Source: string(tmplContent),
Data: data,
},
}

return &codegen.File{
Path: path,
SectionTemplates: sections,
Expand All @@ -139,20 +143,20 @@ func generateExampleScenarios(_ string, root *expr.RootExpr, svc *expr.ServiceEx
func buildScenariosData(svcData *service.Data, root *expr.RootExpr, svc *expr.ServiceExpr) *scenariosData {
// Extract validator info from YAML
validatorInfo := ExtractValidatorsFromYAML()

data := &scenariosData{
Data: svcData,
ServiceExpr: svc,
Methods: make([]*scenarioMethodData, 0),
HasHTTP: hasHTTPTransport(root, svc),
HasGRPC: hasGRPCTransport(root, svc),
HasJSONRPC: hasJSONRPCTransport(root, svc),
Data: svcData,
ServiceExpr: svc,
Methods: make([]*scenarioMethodData, 0),
HasHTTP: hasHTTPTransport(root, svc),
HasGRPC: hasGRPCTransport(root, svc),
HasJSONRPC: hasJSONRPCTransport(root, svc),
ValidTransports: make([]string, 0),
Validators: validatorInfo.Validators,
ValidatorPkg: "", // Will be set below if it's a different package
ValidatorPath: validatorInfo.Path,
Validators: validatorInfo.Validators,
ValidatorPkg: "", // Will be set below if it's a different package
ValidatorPath: validatorInfo.Path,
}

// Build list of valid transports
transportSet := make(map[string]bool)
transportSet["auto"] = true
Expand All @@ -176,15 +180,21 @@ func buildScenariosData(svcData *service.Data, root *expr.RootExpr, svc *expr.Se
// Build method data with available transports
for i, m := range svc.Methods {
methodData := svcData.Methods[i]

// Build targets for this method using shared function
targets := buildMethodTargets(root, svc, m, methodData)

md := &scenarioMethodData{
MethodData: methodData,
Transports: make([]string, 0),
}

// Compute fully qualified result type reference for type assertions.
// This properly handles composite types like ArrayOf(...) without producing
// invalid Go like "svc.[]T" (see issue #234).
if m.Result != nil && m.Result.Type != expr.Empty {
md.ResultTypeRef = svcData.Scope.GoFullTypeRef(m.Result, svcData.PkgName)
}

// Build list of valid transport strings based on targets
transportSet := make(map[string]bool)
for _, target := range targets {
Expand All @@ -205,15 +215,15 @@ func buildScenariosData(svcData *service.Data, root *expr.RootExpr, svc *expr.Se
transportSet["jsonrpc-ws"] = true
}
}

// Convert set to sorted list
for transport := range transportSet {
md.Transports = append(md.Transports, transport)
}

data.Methods = append(data.Methods, md)
}

return data
}

Expand All @@ -230,19 +240,19 @@ func ExtractValidatorsFromYAML() ValidatorInfo {
Validators: make(map[string][]string),
Package: "", // Empty means use current package
}

// Try to read scenarios.yaml from current directory
data, err := os.ReadFile("scenarios.yaml")
if err != nil {
// File doesn't exist or can't be read, that's OK
return info
}

if len(data) == 0 {
// Empty file
return info
}

var config struct {
Validators struct {
Package string `yaml:"package"`
Expand All @@ -257,16 +267,16 @@ func ExtractValidatorsFromYAML() ValidatorInfo {
} `yaml:"steps"`
} `yaml:"scenarios"`
}

if err := yaml.Unmarshal(data, &config); err != nil {
// Invalid YAML, skip
return info
}

// Extract package info
info.Package = config.Validators.Package
info.Path = config.Validators.Path

// Extract unique validators per method
for _, scenario := range config.Scenarios {
if scenario.Steps == nil {
Expand All @@ -289,6 +299,6 @@ func ExtractValidatorsFromYAML() ValidatorInfo {
}
}
}

return info
}
}
25 changes: 25 additions & 0 deletions testing/codegen/scenarios_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"goa.design/goa/v3/codegen"
"goa.design/goa/v3/codegen/service"
httpcodegen "goa.design/goa/v3/http/codegen"
"goa.design/plugins/v3/testing/codegen/testdata"
Expand Down Expand Up @@ -52,3 +53,27 @@ func TestGenerateScenarios(t *testing.T) {
})
}
}

func TestGenerateScenarios_ArrayResultTypeAssertion(t *testing.T) {
root := httpcodegen.RunHTTPDSL(t, testdata.WithArrayResultDSL)
services := service.NewServicesData(root)
svc := root.Services[0]
svcData := services.Get(svc.Name)
fs := generateScenarios("", svcData, root, svc)
f := fs[0]

sections := f.Section("scenario-runner")
if len(sections) != 1 {
t.Fatalf("expected 1 scenario-runner section, got %d", len(sections))
}
code := codegen.SectionCode(t, sections[0])

// This is the canonical fully-qualified Go type reference that should be used
// in the generated type assertion.
wantRef := svcData.Scope.GoFullTypeRef(svc.Methods[0].Result, svcData.PkgName)
assert.Contains(t, code, "typedResult := result.("+wantRef+")")

// Regression guard for invalid formatting like "pkg.[]T" (issue #234).
assert.NotContains(t, code, svcData.PkgName+".[]")
assert.NotContains(t, code, "*"+svcData.PkgName+".[]")
}
4 changes: 2 additions & 2 deletions testing/codegen/templates/scenario_runner.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ func (r *ScenarioRunner) callValidator(t *testing.T, method string, result any,
switch method {
{{- range .Methods }}
case "{{ .Name }}":
{{- if .ResultRef }}
typedResult := result.(*{{ $.PkgName }}.{{ .Result }})
{{- if .ResultTypeRef }}
typedResult := result.({{ .ResultTypeRef }})
{{- $validators := index $.Validators .Name }}
{{- if $validators }}

Expand Down
15 changes: 15 additions & 0 deletions testing/codegen/testdata/dsls.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ var WithoutPayloadResultDSL = func() {
})
}

var WithArrayResultDSL = func() {
var AccessControl = ResultType("AccessControl", func() {
Attribute("id", String)
Required("id")
})
Service("WithArrayResultService", func() {
Method("ListAccessControl", func() {
Result(ArrayOf(AccessControl))
HTTP(func() {
GET("/")
})
})
})
}

var WithStreamDSL = func() {
Service("WithStreamService", func() {
Method("WithStreamMethod", func() {
Expand Down
Loading