diff --git a/scope.go b/scope.go index ef970565..289f9e14 100644 --- a/scope.go +++ b/scope.go @@ -396,6 +396,8 @@ func (s *scope) decorateRequest(req *http.Request) (*http.Request, url.Values) { if req.RequestURI == pingEndpoint { req.URL.Scheme = s.host.Scheme() req.URL.Host = s.host.Host() + // Set the Host header to match the target URL host + req.Host = req.URL.Host return req, req.URL.Query() } @@ -446,6 +448,9 @@ func (s *scope) decorateRequest(req *http.Request) (*http.Request, url.Values) { req.URL.Scheme = s.host.Scheme() req.URL.Host = s.host.Host() + // Set the Host header to match the target URL host + req.Host = req.URL.Host + // Extend ua with additional info, so it may be queried // via system.query_log.http_user_agent. ua := fmt.Sprintf("RemoteAddr: %s; LocalAddr: %s; CHProxy-User: %s; CHProxy-ClusterUser: %s; %s", diff --git a/scope_test.go b/scope_test.go index 438966c6..0afdf800 100644 --- a/scope_test.go +++ b/scope_test.go @@ -523,3 +523,46 @@ func testGetScope(c *cluster, u *user, cu *clusterUser, sessionId string) *scope } return s } + +func TestDecorateRequestHostHeader(t *testing.T) { + testCases := []struct { + name string + requestURI string + targetHost string + }{ + { + name: "ping endpoint", + requestURI: "/ping", + targetHost: "127.0.0.2:8123", + }, + { + name: "regular query", + requestURI: "/?query=SELECT%201", + targetHost: "127.0.0.3:9000", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest("GET", "http://127.0.0.1"+tc.requestURI, nil) + if err != nil { + t.Fatalf("unexpected error while creating request: %s", err) + } + s := &scope{ + id: newScopeID(), + clusterUser: &clusterUser{name: "default"}, + user: &user{}, + host: topology.NewNode(&url.URL{Host: tc.targetHost}, nil, "", ""), + } + + req, _ = s.decorateRequest(req) + + if req.Host != tc.targetHost { + t.Fatalf("expected Host header %q; got %q", tc.targetHost, req.Host) + } + if req.URL.Host != tc.targetHost { + t.Fatalf("expected URL.Host %q; got %q", tc.targetHost, req.URL.Host) + } + }) + } +}