|
| 1 | +//go:build !lambdahttpadapter.partial || (lambdahttpadapter.partial && lambdahttpadapter.apigwv2) |
| 2 | + |
1 | 3 | package handler |
2 | 4 |
|
3 | 5 | import ( |
| 6 | + "bytes" |
4 | 7 | "context" |
5 | 8 | "encoding/base64" |
6 | 9 | "github.com/aws/aws-lambda-go/events" |
7 | 10 | "net/http" |
| 11 | + "strconv" |
8 | 12 | "strings" |
9 | 13 | "unicode/utf8" |
10 | 14 | ) |
11 | 15 |
|
12 | | -func apiGwV2RequestConverter(ctx context.Context, event events.APIGatewayV2HTTPRequest) (*http.Request, error) { |
| 16 | +func convertApiGwV2Request(ctx context.Context, event events.APIGatewayV2HTTPRequest) (*http.Request, error) { |
13 | 17 | url := buildFullRequestURL(event.RequestContext.DomainName, event.RawPath, event.RequestContext.HTTP.Path, buildQuery(event.RawQueryString, event.QueryStringParameters)) |
14 | 18 | req, err := http.NewRequestWithContext(ctx, event.RequestContext.HTTP.Method, url, getBody(event.Body, event.IsBase64Encoded)) |
15 | 19 | if err != nil { |
@@ -40,46 +44,89 @@ func apiGwV2RequestConverter(ctx context.Context, event events.APIGatewayV2HTTPR |
40 | 44 | return req, nil |
41 | 45 | } |
42 | 46 |
|
43 | | -func apiGwV2ResponseInitializer(ctx context.Context) *ResponseWriterProxy { |
44 | | - return NewResponseWriterProxy() |
| 47 | +type apiGwV2ResponseWriter struct { |
| 48 | + headersWritten bool |
| 49 | + contentTypeSet bool |
| 50 | + contentLengthSet bool |
| 51 | + headers http.Header |
| 52 | + body bytes.Buffer |
| 53 | + res events.APIGatewayV2HTTPResponse |
45 | 54 | } |
46 | 55 |
|
47 | | -func apiGwV2ResponseFinalizer(ctx context.Context, w *ResponseWriterProxy) (events.APIGatewayV2HTTPResponse, error) { |
48 | | - out := events.APIGatewayV2HTTPResponse{ |
49 | | - StatusCode: w.Status, |
50 | | - Headers: make(map[string]string), |
51 | | - Cookies: make([]string, 0), |
52 | | - } |
| 56 | +func (w *apiGwV2ResponseWriter) Header() http.Header { |
| 57 | + return w.headers |
| 58 | +} |
| 59 | + |
| 60 | +func (w *apiGwV2ResponseWriter) Write(p []byte) (int, error) { |
| 61 | + w.WriteHeader(http.StatusOK) |
| 62 | + return w.body.Write(p) |
| 63 | +} |
| 64 | + |
| 65 | +func (w *apiGwV2ResponseWriter) WriteHeader(statusCode int) { |
| 66 | + if !w.headersWritten { |
| 67 | + w.headersWritten = true |
| 68 | + w.res.StatusCode = statusCode |
53 | 69 |
|
54 | | - for k, values := range w.Headers { |
55 | | - if strings.EqualFold("set-cookie", k) { |
56 | | - out.Cookies = values |
57 | | - } else { |
58 | | - if len(values) == 0 { |
59 | | - out.Headers[k] = "" |
60 | | - } else if len(values) == 1 { |
61 | | - out.Headers[k] = values[0] |
| 70 | + for k, values := range w.headers { |
| 71 | + if strings.EqualFold("set-cookie", k) { |
| 72 | + w.res.Cookies = values |
62 | 73 | } else { |
63 | | - if out.MultiValueHeaders == nil { |
64 | | - out.MultiValueHeaders = make(map[string][]string) |
65 | | - } |
| 74 | + if len(values) == 0 { |
| 75 | + w.res.Headers[k] = "" |
| 76 | + } else if len(values) == 1 { |
| 77 | + w.res.Headers[k] = values[0] |
| 78 | + } else { |
| 79 | + if w.res.MultiValueHeaders == nil { |
| 80 | + w.res.MultiValueHeaders = make(map[string][]string) |
| 81 | + } |
66 | 82 |
|
67 | | - out.MultiValueHeaders[k] = values |
| 83 | + w.res.MultiValueHeaders[k] = values |
| 84 | + } |
68 | 85 | } |
69 | 86 | } |
70 | 87 | } |
| 88 | +} |
| 89 | + |
| 90 | +func handleApiGwV2(ctx context.Context, event events.APIGatewayV2HTTPRequest, adapter AdapterFunc) (events.APIGatewayV2HTTPResponse, error) { |
| 91 | + req, err := convertApiGwV2Request(ctx, event) |
| 92 | + if err != nil { |
| 93 | + var def events.APIGatewayV2HTTPResponse |
| 94 | + return def, err |
| 95 | + } |
| 96 | + |
| 97 | + w := apiGwV2ResponseWriter{ |
| 98 | + headers: make(http.Header), |
| 99 | + res: events.APIGatewayV2HTTPResponse{ |
| 100 | + Headers: make(map[string]string), |
| 101 | + Cookies: make([]string, 0), |
| 102 | + }, |
| 103 | + } |
| 104 | + |
| 105 | + if err = adapter(ctx, req, &w); err != nil { |
| 106 | + var def events.APIGatewayV2HTTPResponse |
| 107 | + return def, err |
| 108 | + } |
| 109 | + |
| 110 | + b := w.body.Bytes() |
| 111 | + |
| 112 | + if !w.contentTypeSet { |
| 113 | + w.res.Headers["Content-Type"] = http.DetectContentType(b) |
| 114 | + } |
| 115 | + |
| 116 | + if !w.contentLengthSet { |
| 117 | + w.res.Headers["Content-Length"] = strconv.Itoa(len(b)) |
| 118 | + } |
71 | 119 |
|
72 | | - b := w.Body.Bytes() |
73 | 120 | if utf8.Valid(b) { |
74 | | - out.Body = string(b) |
| 121 | + w.res.Body = string(b) |
75 | 122 | } else { |
76 | | - out.IsBase64Encoded = true |
77 | | - out.Body = base64.StdEncoding.EncodeToString(b) |
| 123 | + w.res.IsBase64Encoded = true |
| 124 | + w.res.Body = base64.StdEncoding.EncodeToString(b) |
78 | 125 | } |
79 | 126 |
|
80 | | - return out, nil |
| 127 | + return w.res, nil |
81 | 128 | } |
82 | 129 |
|
83 | 130 | func NewAPIGatewayV2Handler(adapter AdapterFunc) func(context.Context, events.APIGatewayV2HTTPRequest) (events.APIGatewayV2HTTPResponse, error) { |
84 | | - return NewHandler(apiGwV2RequestConverter, apiGwV2ResponseInitializer, apiGwV2ResponseFinalizer, adapter) |
| 131 | + return NewHandler(handleApiGwV2, adapter) |
85 | 132 | } |
0 commit comments