diff --git a/internal/auth/authenticator.go b/internal/auth/authenticator.go
index 169709d1..dc4a5bf6 100644
--- a/internal/auth/authenticator.go
+++ b/internal/auth/authenticator.go
@@ -134,36 +134,6 @@ type signInResp struct {
Version string
}
-// SignInPage directs the user to the sign in page
-func (p *Authenticator) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
- rw.WriteHeader(code)
-
- // We construct this URL based on the known callback URL that we send to Google.
- // We don't want to rely on req.Host, as that can be attacked via Host header injection
- // This ends up looking like:
- // https://sso-auth.example.com/sign_in?client_id=...&redirect_uri=...
- path := strings.TrimPrefix(req.URL.Path, "/")
- redirectURL := p.redirectURL.ResolveReference(
- &url.URL{
- Path: path,
- RawQuery: req.URL.RawQuery,
- },
- )
-
- // validateRedirectURI middleware already ensures that this is a valid URL
- destinationURL, _ := url.Parse(redirectURL.Query().Get("redirect_uri"))
-
- t := signInResp{
- ProviderName: p.provider.Data().ProviderName,
- ProviderSlug: p.provider.Data().ProviderSlug,
- EmailDomains: p.EmailDomains,
- Redirect: redirectURL.String(),
- Destination: destinationURL.Host,
- Version: VERSION,
- }
- p.templates.ExecuteTemplate(rw, "sign_in.html", t)
-}
-
func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request) (*sessions.SessionState, error) {
logger := log.NewLogEntry()
remoteAddr := getRemoteAddr(req)
@@ -239,15 +209,10 @@ func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request)
return session, nil
}
-// SignIn handles the /sign_in endpoint. It attempts to authenticate the user, and if the user is not authenticated, it renders
-// a sign in page.
+// SignIn handles the /sign_in endpoint. It attempts to authenticate the user, and if the user is not authenticated,
+// it starts the authentication process.
+// If the user is authenticated, we redirect back to the proxy application at the `redirect_uri`, with a temporary token.
func (p *Authenticator) SignIn(rw http.ResponseWriter, req *http.Request) {
- // We attempt to authenticate the user. If they cannot be authenticated, we render a sign-in
- // page.
- //
- // If the user is authenticated, we redirect back to the proxy application
- // at the `redirect_uri`, with a temporary token.
- //
// TODO: It is possible for a user to visit this page without a redirect destination.
// Should we allow the user to authenticate? If not, what should be the proposed workflow?
@@ -257,6 +222,22 @@ func (p *Authenticator) SignIn(rw http.ResponseWriter, req *http.Request) {
"action:sign_in",
fmt.Sprintf("proxy_host:%s", proxyHost),
}
+
+ // We construct this URL based on the known callback URL that we send to Google.
+ // We don't want to rely on req.Host, as that can be attacked via Host header injection
+ // This ends up looking like:
+ // https://sso-auth.example.com/sign_in?client_id=...&redirect_uri=...
+ //
+ // The validateRedirectURI middleware ensures that this is a valid URL
+ path := strings.TrimPrefix(req.URL.Path, "/")
+ redirectURL := p.redirectURL.ResolveReference(
+ &url.URL{
+ Path: path,
+ RawQuery: req.URL.RawQuery,
+ },
+ )
+ req.URL = redirectURL
+
session, err := p.authenticate(rw, req)
switch err {
case nil:
@@ -264,13 +245,13 @@ func (p *Authenticator) SignIn(rw http.ResponseWriter, req *http.Request) {
// with the necessary state
p.ProxyOAuthRedirect(rw, req, session, tags)
case http.ErrNoCookie:
- p.SignInPage(rw, req, http.StatusOK)
+ p.OAuthStart(rw, req)
case providers.ErrTokenRevoked:
p.sessionStore.ClearSession(rw, req)
- p.SignInPage(rw, req, http.StatusOK)
+ p.OAuthStart(rw, req)
case sessions.ErrLifetimeExpired, sessions.ErrInvalidSession:
p.sessionStore.ClearSession(rw, req)
- p.SignInPage(rw, req, http.StatusOK)
+ p.OAuthStart(rw, req)
default:
tags = append(tags, "error:sign_in_error")
p.StatsdClient.Incr("application_error", tags, 1.0)
@@ -373,17 +354,13 @@ func (p *Authenticator) SignOut(rw http.ResponseWriter, req *http.Request) {
fmt.Sprintf("proxy_host:%s", proxyHost),
}
- if req.Method == "GET" {
- p.SignOutPage(rw, req, "")
- return
- }
-
session, err := p.sessionStore.LoadSession(req)
switch err {
case nil:
+ // no error - we were able to load the session. continue onwards.
break
- // if there's no cookie in the session we can just redirect
case http.ErrNoCookie:
+ // if there's no session, we can just redirect back.
http.Redirect(rw, req, redirectURI, http.StatusFound)
return
default:
@@ -399,96 +376,63 @@ func (p *Authenticator) SignOut(rw http.ResponseWriter, req *http.Request) {
tags = append(tags, "error:revoke_session")
p.StatsdClient.Incr("provider_error", tags, 1.0)
logger.Error(err, "error revoking session")
- p.SignOutPage(rw, req, "An error occurred during sign out. Please try again.")
+ //TODO: This used to return a sign out page with an error.
+ //TODO: http.StatusInternalServerError or codeForError(err)
+ p.ErrorResponse(rw, req, err.Error(), http.StatusInternalServerError)
return
}
+ // if we reach here, the session has been revoked on the identity providers end,
+ // clear our session and redirect.
p.sessionStore.ClearSession(rw, req)
http.Redirect(rw, req, redirectURI, http.StatusFound)
}
-type signOutResp struct {
- ProviderSlug string
- Version string
- Redirect string
- Signature string
- Timestamp string
- Message string
- Destination string
- Email string
-}
-
-// SignOutPage renders a sign out page with a message
-func (p *Authenticator) SignOutPage(rw http.ResponseWriter, req *http.Request, message string) {
- // validateRedirectURI middleware already ensures that this is a valid URL
- redirectURI := req.Form.Get("redirect_uri")
-
- session, err := p.sessionStore.LoadSession(req)
- if err != nil {
- http.Redirect(rw, req, redirectURI, http.StatusFound)
- return
- }
-
- signature := req.Form.Get("sig")
- timestamp := req.Form.Get("ts")
- destinationURL, _ := url.Parse(redirectURI)
-
- // An error message indicates that an internal server error occurred
- if message != "" {
- rw.WriteHeader(http.StatusInternalServerError)
- }
-
- t := signOutResp{
- ProviderSlug: p.provider.Data().ProviderSlug,
- Version: VERSION,
- Redirect: redirectURI,
- Signature: signature,
- Timestamp: timestamp,
- Message: message,
- Destination: destinationURL.Host,
- Email: session.Email,
- }
- p.templates.ExecuteTemplate(rw, "sign_out.html", t)
- return
-}
-
// OAuthStart starts the authentication process by redirecting to the provider. It provides a
-// `redirectURI`, allowing the provider to redirect back to the sso proxy after authentication.
+// `redirectURI`, allowing the provider to redirect back to sso_auth after authentication.
func (p *Authenticator) OAuthStart(rw http.ResponseWriter, req *http.Request) {
tags := []string{"action:start"}
nonce := fmt.Sprintf("%x", aead.GenerateKey())
p.csrfStore.SetCSRF(rw, req, nonce)
- authRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))
- if err != nil || !validRedirectURI(authRedirectURL.String(), p.ProxyRootDomains) {
- tags = append(tags, "error:invalid_redirect_parameter")
- p.StatsdClient.Incr("application_error", tags, 1.0)
- p.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
- return
- }
- // Here we validate the redirect that is nested within the redirect_uri.
- // `authRedirectURL` points to step D, `proxyRedirectURL` points to step E.
+
+ // Here we validate the redirects and signatures that are nested within the request. Each lettered step in the below diagram
+ // represents a step in the request flow.
//
- // A* B C D E
- // /start -> Google -> auth /callback -> /sign_in -> proxy /callback
+ // `authRedirectURL` points to step C, `proxyRedirectURL` points to step E.
//
// * you are here
- proxyRedirectURL, err := url.Parse(authRedirectURL.Query().Get("redirect_uri"))
+ //
+ // A
+ // sso_auth:/sign_in -> user already authenticated?
+ // |
+ // | * B C D E
+ // -> no -> (OAuthStart) Google/Okta -> sso_auth:/callback -> sso_auth:/sign_in -> (now authenticated) sso_proxy:/callback
+ // |
+ // | F
+ // -> yes -> sso_proxy:/callback
+
+ proxyRedirectURL, err := url.Parse(req.URL.Query().Get("redirect_uri"))
if err != nil || !validRedirectURI(proxyRedirectURL.String(), p.ProxyRootDomains) {
tags = append(tags, "error:invalid_redirect_parameter")
p.StatsdClient.Incr("application_error", tags, 1.0)
- p.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
+ p.ErrorResponse(rw, req, "Invalid proxy redirect parameter", http.StatusBadRequest)
return
}
- proxyRedirectSig := authRedirectURL.Query().Get("sig")
- ts := authRedirectURL.Query().Get("ts")
+
+ proxyRedirectSig := req.URL.Query().Get("sig")
+ ts := req.URL.Query().Get("ts")
if !validSignature(proxyRedirectURL.String(), proxyRedirectSig, ts, p.ProxyClientSecret) {
p.ErrorResponse(rw, req, "Invalid redirect parameter", http.StatusBadRequest)
return
}
- redirectURI := p.GetRedirectURI(req.Host)
- state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, authRedirectURL.String())))
- signInURL := p.provider.GetSignInURL(redirectURI, state)
+
+ state := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%v:%v", nonce, req.URL.String())))
+
+ authRedirectURL := p.GetRedirectURI(req.Host)
+ //TODO: do we want to validate this redirect URI again, even though a little redundant?
+ signInURL := p.provider.GetSignInURL(authRedirectURL, state)
+
http.Redirect(rw, req, signInURL, http.StatusFound)
}
diff --git a/internal/auth/authenticator_test.go b/internal/auth/authenticator_test.go
index 998cb998..e84fa17c 100644
--- a/internal/auth/authenticator_test.go
+++ b/internal/auth/authenticator_test.go
@@ -121,41 +121,26 @@ type errResponse struct {
}
func TestSignIn(t *testing.T) {
+
+ const (
+ SaveCookie = iota
+ ClearCookie
+ KeepCookie
+ )
+
testCases := []struct {
- name string
- paramsMap map[string]string
- mockCSRFStore *sessions.MockCSRFStore
- mockSessionStore *sessions.MockSessionStore
- mockAuthCodeCipher *aead.MockCipher
- refreshResponse providerRefreshResponse
- providerValidToken bool
- validEmail bool
- expectedSignInPage bool
- expectedDestinationURL string
- expectedCode int
- expectedErrorResponse *errResponse
+ name string
+ paramsMap map[string]string
+ mockCSRFStore *sessions.MockCSRFStore
+ mockSessionStore *sessions.MockSessionStore
+ mockAuthCodeCipher *aead.MockCipher
+ refreshResponse providerRefreshResponse
+ providerValidToken bool
+ validEmail bool
+ expectedCode int
+ expectedErrorResponse *errResponse
+ cookieExpectation int // One of: {SaveCookie, ClearCookie, KeepCookie}
}{
- {
- name: "err no cookie calls proxy oauth redirect, no params map redirects to sign in page",
- mockSessionStore: &sessions.MockSessionStore{
- LoadError: http.ErrNoCookie,
- },
- expectedSignInPage: true,
- expectedDestinationURL: "",
- expectedCode: http.StatusOK,
- },
- {
- name: "err no cookie calls proxy oauth redirect, with redirect url redirects to sign in page",
- mockSessionStore: &sessions.MockSessionStore{
- LoadError: http.ErrNoCookie,
- },
- paramsMap: map[string]string{
- "redirect_uri": "http://foo.example.com",
- },
- expectedSignInPage: true,
- expectedDestinationURL: "foo.example.com",
- expectedCode: http.StatusOK,
- },
{
name: "another error that isn't no cookie",
mockSessionStore: &sessions.MockSessionStore{
@@ -166,21 +151,7 @@ func TestSignIn(t *testing.T) {
},
expectedCode: http.StatusInternalServerError,
expectedErrorResponse: &errResponse{"another error"},
- },
- {
- name: "expired lifetime of session clears session and redirects to sign in",
- mockSessionStore: &sessions.MockSessionStore{
- Session: &sessions.SessionState{
- Email: "email",
- AccessToken: "accesstoken",
- RefreshToken: "refresh",
- LifetimeDeadline: time.Now().Add(-time.Hour),
- RefreshDeadline: time.Now().Add(time.Hour),
- },
- },
- expectedSignInPage: true,
- expectedDestinationURL: "",
- expectedCode: http.StatusOK,
+ cookieExpectation: ClearCookie,
},
{
name: "refresh period expired, provider error",
@@ -198,6 +169,7 @@ func TestSignIn(t *testing.T) {
},
expectedCode: http.StatusInternalServerError,
expectedErrorResponse: &errResponse{"provider error"},
+ cookieExpectation: ClearCookie,
},
{
name: "refresh period expired, not refreshed - unauthorized user",
@@ -213,6 +185,7 @@ func TestSignIn(t *testing.T) {
refreshResponse: providerRefreshResponse{},
expectedCode: http.StatusUnauthorized,
expectedErrorResponse: &errResponse{ErrUserNotAuthorized.Error()},
+ cookieExpectation: ClearCookie,
},
{
name: "refresh period expired, refresh ok, save session error",
@@ -231,6 +204,7 @@ func TestSignIn(t *testing.T) {
},
expectedCode: http.StatusInternalServerError,
expectedErrorResponse: &errResponse{"save error"},
+ cookieExpectation: ClearCookie,
},
{
name: "refresh period expired, successful refresh, invalid email",
@@ -248,6 +222,7 @@ func TestSignIn(t *testing.T) {
},
expectedCode: http.StatusUnauthorized,
expectedErrorResponse: &errResponse{ErrUserNotAuthorized.Error()},
+ cookieExpectation: KeepCookie,
},
{
name: "valid session state, save session error",
@@ -264,6 +239,7 @@ func TestSignIn(t *testing.T) {
providerValidToken: true,
expectedCode: http.StatusInternalServerError,
expectedErrorResponse: &errResponse{"save error"},
+ cookieExpectation: ClearCookie,
},
{
name: "invalid session state, invalid email",
@@ -278,6 +254,7 @@ func TestSignIn(t *testing.T) {
},
expectedCode: http.StatusUnauthorized,
expectedErrorResponse: &errResponse{ErrUserNotAuthorized.Error()},
+ cookieExpectation: ClearCookie,
},
{
name: "refresh period expired, successful refresh, no state in params",
@@ -296,6 +273,7 @@ func TestSignIn(t *testing.T) {
validEmail: true,
expectedCode: http.StatusForbidden,
expectedErrorResponse: &errResponse{"no state parameter supplied"},
+ cookieExpectation: KeepCookie,
},
{
name: "refresh period expired, successful refresh, no redirect in params",
@@ -317,6 +295,7 @@ func TestSignIn(t *testing.T) {
validEmail: true,
expectedCode: http.StatusForbidden,
expectedErrorResponse: &errResponse{"no redirect_uri parameter supplied"},
+ cookieExpectation: KeepCookie,
},
{
name: "refresh period expired, successful refresh, malformed redirect in params",
@@ -339,6 +318,7 @@ func TestSignIn(t *testing.T) {
validEmail: true,
expectedCode: http.StatusBadRequest,
expectedErrorResponse: &errResponse{"malformed redirect_uri parameter passed"},
+ cookieExpectation: KeepCookie,
},
{
name: "refresh period expired, unsuccessful marshal",
@@ -364,6 +344,7 @@ func TestSignIn(t *testing.T) {
validEmail: true,
expectedCode: http.StatusInternalServerError,
expectedErrorResponse: &errResponse{"error marshal"},
+ cookieExpectation: KeepCookie,
},
{
name: "refresh period expired, successful refresh",
@@ -386,8 +367,9 @@ func TestSignIn(t *testing.T) {
mockAuthCodeCipher: &aead.MockCipher{
MarshalString: "abcdefg",
},
- validEmail: true,
- expectedCode: http.StatusFound,
+ validEmail: true,
+ expectedCode: http.StatusFound,
+ cookieExpectation: SaveCookie,
},
{
name: "valid session state, successful save",
@@ -409,8 +391,8 @@ func TestSignIn(t *testing.T) {
mockAuthCodeCipher: &aead.MockCipher{
MarshalString: "abcdefg",
},
-
- expectedCode: http.StatusFound,
+ expectedCode: http.StatusFound,
+ cookieExpectation: SaveCookie,
},
}
@@ -420,6 +402,7 @@ func TestSignIn(t *testing.T) {
auth, err := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}),
setMockSessionStore(tc.mockSessionStore),
+ setMockCSRFStore(&sessions.MockCSRFStore{}),
setMockTempl(),
setMockRedirectURL(),
setMockAuthCodeCipher(tc.mockAuthCodeCipher, nil),
@@ -449,20 +432,6 @@ func TestSignIn(t *testing.T) {
resp := rw.Result()
respBytes, err := ioutil.ReadAll(resp.Body)
testutil.Ok(t, err)
- if tc.expectedSignInPage {
- expectedSignInResp := &signInResp{
- ProviderName: provider.Data().ProviderName,
- ProviderSlug: "test",
- EmailDomains: auth.EmailDomains,
- Redirect: u.String(),
- Destination: tc.expectedDestinationURL,
- Version: VERSION,
- }
- actualSignInResp := &signInResp{}
- err := json.Unmarshal(respBytes, actualSignInResp)
- testutil.Ok(t, err)
- testutil.Equal(t, expectedSignInResp, actualSignInResp)
- }
if tc.expectedErrorResponse != nil {
actualErrorResponse := &errResponse{}
@@ -470,46 +439,30 @@ func TestSignIn(t *testing.T) {
testutil.Ok(t, err)
testutil.Equal(t, tc.expectedErrorResponse, actualErrorResponse)
}
- // TODO: add a cleared session cookie check for errored stuff
+ switch tc.cookieExpectation {
+ case SaveCookie:
+ testutil.NotEqual(t, tc.mockSessionStore.ResponseSession, "")
+ case KeepCookie:
+ testutil.NotEqual(t, tc.mockSessionStore.ResponseSession, "")
+ case ClearCookie:
+ testutil.Equal(t, tc.mockSessionStore.ResponseSession, "")
+ }
})
}
}
-func TestSignOutPage(t *testing.T) {
+func TestSignOut(t *testing.T) {
testCases := []struct {
- Name string
- ExpectedStatusCode int
- paramsMap map[string]string
- RedirectURI string
- Method string
- mockSessionStore *sessions.MockSessionStore
- RevokeError error
- expectedSignOutResp *signOutResp
+ Name string
+ ExpectedStatusCode int
+ paramsMap map[string]string
+ RedirectURI string
+ Method string
+ mockSessionStore *sessions.MockSessionStore
+ RevokeError error
+ SuccessfulRevoke bool
}{
- {
- Name: "successful sign out page",
- paramsMap: map[string]string{
- "redirect_uri": "http://service.example.com",
- },
- mockSessionStore: &sessions.MockSessionStore{
- Session: &sessions.SessionState{
- Email: "test@example.com",
- RefreshDeadline: time.Now().Add(time.Hour),
- AccessToken: "accessToken",
- RefreshToken: "refreshToken",
- },
- },
- ExpectedStatusCode: http.StatusOK,
- Method: "GET",
- expectedSignOutResp: &signOutResp{
- ProviderSlug: "test",
- Version: VERSION,
- Redirect: "http://service.example.com",
- Destination: "service.example.com",
- Email: "test@example.com",
- },
- },
{
Name: "redirect if no session exists on GET",
ExpectedStatusCode: http.StatusFound,
@@ -533,14 +486,19 @@ func TestSignOutPage(t *testing.T) {
Method: "POST",
},
{
- Name: "sign out page also used to POST",
+ Name: "clear session and redirect if unexpected error occurs loading session",
ExpectedStatusCode: http.StatusFound,
- mockSessionStore: &sessions.MockSessionStore{},
- RedirectURI: "http://service.example.com",
- Method: "POST",
+ SuccessfulRevoke: true,
+ mockSessionStore: &sessions.MockSessionStore{
+ LoadError: sessions.ErrInvalidSession,
+ },
+ paramsMap: map[string]string{
+ "redirect_uri": "http://service.example.com",
+ },
+ Method: "POST",
},
{
- Name: "sign out page shows error message if revoke fails",
+ Name: "sign out returns error if revoke fails",
ExpectedStatusCode: http.StatusInternalServerError,
mockSessionStore: &sessions.MockSessionStore{
Session: &sessions.SessionState{
@@ -554,6 +512,23 @@ func TestSignOutPage(t *testing.T) {
RedirectURI: "http://service.example.com",
Method: "POST",
},
+ {
+ Name: "successful revoke and redirect on POST",
+ ExpectedStatusCode: http.StatusFound,
+ SuccessfulRevoke: true,
+ mockSessionStore: &sessions.MockSessionStore{
+ Session: &sessions.SessionState{
+ Email: "test@exampknafsadle.com",
+ RefreshDeadline: time.Now().Add(time.Hour),
+ AccessToken: "accessToken",
+ RefreshToken: "refreshToken",
+ },
+ },
+ paramsMap: map[string]string{
+ "redirect_uri": "http://service.example.com",
+ },
+ Method: "POST",
+ },
}
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
@@ -584,16 +559,21 @@ func TestSignOutPage(t *testing.T) {
p.SignOut(rw, req)
testutil.Equal(t, tc.ExpectedStatusCode, rw.Code)
+
resp := rw.Result()
respBytes, err := ioutil.ReadAll(resp.Body)
- testutil.Ok(t, err)
-
- if tc.expectedSignOutResp != nil {
- actualSignOutResp := &signOutResp{}
- err := json.Unmarshal(respBytes, actualSignOutResp)
- testutil.Ok(t, err)
- testutil.Equal(t, tc.expectedSignOutResp, actualSignOutResp)
+ if tc.RevokeError != nil {
+ if !strings.Contains(string(respBytes), tc.RevokeError.Error()) {
+ t.Logf("expected response to contain: %q", tc.RevokeError.Error())
+ t.Logf(" got response: %q", string(respBytes))
+ t.Error("unexpected response after revoke error")
+ }
+ }
+ if tc.SuccessfulRevoke {
+ testutil.Equal(t, tc.mockSessionStore.ResponseSession, "")
+ //TODO: test the session has been cleared?
}
+ testutil.Ok(t, err)
})
}
@@ -1516,9 +1496,9 @@ func TestGlobalHeaders(t *testing.T) {
func TestOAuthStart(t *testing.T) {
testCases := []struct {
Name string
- RedirectURI string
ProxyRedirectURI string
ExpectedStatusCode int
+ InvalidSignature bool
}{
{
Name: "reject requests without a redirect",
@@ -1526,23 +1506,22 @@ func TestOAuthStart(t *testing.T) {
},
{
Name: "reject requests with a malicious auth",
- RedirectURI: "https://auth.evil.com/sign_in",
+ ProxyRedirectURI: "https://auth.evil.com/sign_in",
ExpectedStatusCode: http.StatusBadRequest,
},
{
- Name: "reject requests without a nested redirect",
- RedirectURI: "https://auth.example.com/sign_in",
+ Name: "reject requests with a malicious proxy",
+ ProxyRedirectURI: "https://proxy.evil.com/path/to/badness",
ExpectedStatusCode: http.StatusBadRequest,
},
{
- Name: "reject requests with a malicious proxy",
- RedirectURI: "https://auth.example.com/sign_in",
- ProxyRedirectURI: "https://proxy.evil.com/path/to/badness",
+ Name: "reject requests with invalid signature",
+ ProxyRedirectURI: "https://proxy.example.com/oauth/callback",
+ InvalidSignature: true,
ExpectedStatusCode: http.StatusBadRequest,
},
{
Name: "accept requests with good redirect_uris",
- RedirectURI: "https://auth.example.com/sign_in",
ProxyRedirectURI: "https://proxy.example.com/oauth/callback",
ExpectedStatusCode: http.StatusFound,
},
@@ -1560,20 +1539,19 @@ func TestOAuthStart(t *testing.T) {
)
params := url.Values{}
- if tc.RedirectURI != "" {
- redirectURL, _ := url.Parse(tc.RedirectURI)
- if tc.ProxyRedirectURI != "" {
- // NOTE: redirect signatures tested in middleware_test.go
- now := time.Now()
- sig := redirectURLSignature(tc.ProxyRedirectURI, now, config.ClientConfigs["proxy"].Secret)
- b64sig := base64.URLEncoding.EncodeToString(sig)
- redirectParams := url.Values{}
- redirectParams.Add("redirect_uri", tc.ProxyRedirectURI)
- redirectParams.Add("sig", b64sig)
- redirectParams.Add("ts", fmt.Sprint(now.Unix()))
- redirectURL.RawQuery = redirectParams.Encode()
+ if tc.ProxyRedirectURI != "" {
+ // NOTE: redirect signatures tested in middleware_test.go
+ sig := []byte("")
+ now := time.Now()
+ if tc.InvalidSignature {
+ sig = redirectURLSignature(tc.ProxyRedirectURI, now, "badSecret")
+ } else {
+ sig = redirectURLSignature(tc.ProxyRedirectURI, now, config.ClientConfigs["proxy"].Secret)
}
- params.Add("redirect_uri", redirectURL.String())
+ b64sig := base64.URLEncoding.EncodeToString(sig)
+ params.Add("redirect_uri", tc.ProxyRedirectURI)
+ params.Add("sig", b64sig)
+ params.Add("ts", fmt.Sprint(now.Unix()))
}
req := httptest.NewRequest("GET", "/start?"+params.Encode(), nil)
diff --git a/internal/pkg/templates/templates.go b/internal/pkg/templates/templates.go
index 33b2960e..3059fdad 100644
--- a/internal/pkg/templates/templates.go
+++ b/internal/pkg/templates/templates.go
@@ -34,13 +34,13 @@ Secured by SSO{{end}}`))
t = template.Must(t.Parse(`{{define "sign_in_message.html"}}
{{if eq (len .EmailDomains) 1}}
{{if eq (index .EmailDomains 0) "@*"}}
-
You may sign in with any {{.ProviderName}} account.
+ You may sign in with any {{.ProviderSlug}} account.
{{else}}
- You may sign in with your {{index .EmailDomains 0}} {{.ProviderName}} account.
+ You may sign in with your {{index .EmailDomains 0}} {{.ProviderSlug}} account.
{{end}}
{{else if gt (len .EmailDomains) 1}}
- You may sign in with any of these {{.ProviderName}} accounts:
+ You may sign in with any of these {{.ProviderSlug}} accounts:
{{range $i, $e := .EmailDomains}}{{if $i}}, {{end}}{{$e}}{{end}}
{{end}}
@@ -64,9 +64,15 @@ Secured by SSO{{end}}`))
{{template "sign_in_message.html" .}}
-
@@ -108,19 +114,16 @@ Secured by SSO{{end}}`))
- {{ if .Message }}
-
{{.Message}}
- {{ end}}
Sign out of {{.Destination}}
You're currently signed in as {{.Email}}. This will also sign you out of other internal apps.
-
diff --git a/internal/pkg/templates/templates_test.go b/internal/pkg/templates/templates_test.go
index 1da4adf9..167ffcab 100644
--- a/internal/pkg/templates/templates_test.go
+++ b/internal/pkg/templates/templates_test.go
@@ -58,10 +58,10 @@ func TestSignInMessage(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
buf := &bytes.Buffer{}
ctx := struct {
- ProviderName string
+ ProviderSlug string
EmailDomains []string
}{
- ProviderName: "Google",
+ ProviderSlug: "Google",
EmailDomains: tc.emailDomains,
}
templates.ExecuteTemplate(buf, "sign_in_message.html", ctx)
diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go
index 0944c95e..b425edba 100644
--- a/internal/proxy/oauthproxy.go
+++ b/internal/proxy/oauthproxy.go
@@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
- "html/template"
"net/http"
"net/url"
"reflect"
@@ -15,6 +14,7 @@ import (
log "github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/pkg/options"
"github.com/buzzfeed/sso/internal/pkg/sessions"
+ "github.com/buzzfeed/sso/internal/pkg/templates"
"github.com/buzzfeed/sso/internal/proxy/providers"
"github.com/datadog/datadog-go/statsd"
@@ -59,7 +59,7 @@ type OAuthProxy struct {
cookieSecure bool
Validators []options.Validator
redirectURL *url.URL // the url to receive requests at
- templates *template.Template
+ templates templates.Template
skipAuthPreflight bool
passAccessToken bool
@@ -179,7 +179,7 @@ func NewOAuthProxy(opts *Options, optFuncs ...func(*OAuthProxy) error) (*OAuthPr
Validators: []options.Validator{},
redirectURL: &url.URL{Path: "/oauth2/callback"},
- templates: getTemplates(),
+ templates: templates.NewHTMLTemplate(),
skipAuthPreflight: opts.SkipAuthPreflight,
passAccessToken: opts.PassAccessToken,
@@ -229,15 +229,24 @@ func NewOAuthProxy(opts *Options, optFuncs ...func(*OAuthProxy) error) (*OAuthPr
// Handler returns a http handler for an OAuthProxy
func (p *OAuthProxy) Handler() http.Handler {
+ logger := log.NewLogEntry()
+
mux := http.NewServeMux()
mux.HandleFunc("/favicon.ico", p.Favicon)
mux.HandleFunc("/robots.txt", p.RobotsTxt)
mux.HandleFunc("/oauth2/v1/certs", p.Certs)
- mux.HandleFunc("/oauth2/sign_out", p.SignOut)
+ mux.HandleFunc("/oauth2/sign_out", p.SignOutPage)
mux.HandleFunc("/oauth2/callback", p.OAuthCallback)
mux.HandleFunc("/oauth2/auth", p.AuthenticateOnly)
mux.HandleFunc("/", p.Proxy)
+ // load static files
+ fsHandler, err := loadFSHandler()
+ if err != nil {
+ logger.Fatal(err)
+ }
+ mux.Handle("/static/", http.StripPrefix("/static/", fsHandler))
+
// Global middleware, which will be applied to each request in reverse
// order as applied here (i.e., we want to validate the host _first_ when
// processing a request)
@@ -334,6 +343,103 @@ func (p *OAuthProxy) XHRError(rw http.ResponseWriter, req *http.Request, code in
rw.Write(jsonBytes)
}
+type signInResp struct {
+ ProviderSlug string
+ EmailDomains []string
+ ClientID string
+ Action string
+ Redirect string
+ Destination string
+ Version string
+ SignInParams providers.SignInParams
+}
+
+// SignInPage renders a sign in page stating the in-use provider and email domains,
+// and redirects to sso_auth.
+func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, state string) {
+ rw.WriteHeader(http.StatusFound)
+
+ // this forms req.Host + /oauth2/callback
+ callbackURL := p.GetRedirectURL(req.Host)
+
+ // this forms the sso-auth signin URL + callbackURL as the redirect URL,
+ // + other parameters (client id, scope etc)
+ signInURL, signInParams := p.provider.GetSignInURL(callbackURL, state)
+
+ // validateRedirectURI middleware already ensures that this is a valid URL
+
+ t := signInResp{
+ ProviderSlug: strings.Title(p.provider.Data().ProviderSlug),
+ EmailDomains: p.upstreamConfig.AllowedEmailDomains,
+ Action: signInURL.String(),
+ Destination: callbackURL.Host,
+ Version: VERSION,
+ SignInParams: signInParams,
+ }
+ p.templates.ExecuteTemplate(rw, "sign_in.html", t)
+}
+
+type signOutResp struct {
+ ProviderSlug string
+ Version string
+ Action string
+ Message string
+ Destination string
+ Email string
+ SignOutParams providers.SignOutParams
+}
+
+// SignOutPage renders a sign out page
+func (p *OAuthProxy) SignOutPage(rw http.ResponseWriter, req *http.Request) {
+ // Build redirect URI from request host
+
+ scheme := req.URL.Scheme
+ if req.URL.Scheme == "" {
+ if p.cookieSecure {
+ scheme = "https"
+ } else {
+ scheme = "http"
+ }
+ }
+
+ redirectURL := &url.URL{
+ Scheme: scheme,
+ Host: req.Host,
+ Path: "/",
+ }
+
+ signOutURL, signOutParams := p.provider.GetSignOutURL(redirectURL)
+
+ session, err := p.sessionStore.LoadSession(req)
+ if err != nil {
+ // If no session exists on sso_proxy, we just redirect
+ // straight to sso_auth so any session there can be properly
+ // cleared
+ params, _ := url.ParseQuery(signOutURL.RawQuery)
+ params.Set("redirect_uri", signOutParams.RedirectURL)
+ params.Set("ts", signOutParams.TimeStamp)
+ params.Set("sig", signOutParams.Signature)
+ signOutURL.RawQuery = params.Encode()
+
+ p.sessionStore.ClearSession(rw, req)
+ http.Redirect(rw, req, signOutURL.String(), http.StatusFound)
+ return
+ }
+
+ // else, if there is a session we render a sign out page
+ t := signOutResp{
+ ProviderSlug: strings.Title(p.provider.Data().ProviderSlug),
+ Version: VERSION,
+ Action: signOutURL.String(),
+ Destination: redirectURL.Host,
+ Email: session.Email,
+ SignOutParams: signOutParams,
+ }
+
+ p.sessionStore.ClearSession(rw, req)
+ p.templates.ExecuteTemplate(rw, "sign_out.html", t)
+}
+
// ErrorPage renders an error page with a given status code, title, and message.
func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code int, title string, message string) {
if p.isXHR(req) {
@@ -379,30 +485,6 @@ func (p *OAuthProxy) isXHR(req *http.Request) bool {
return req.Header.Get("X-Requested-With") == "XMLHttpRequest"
}
-// SignOut redirects the request to the provider's sign out url.
-func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
- p.sessionStore.ClearSession(rw, req)
-
- var scheme string
-
- // Build redirect URI from request host
- if req.URL.Scheme == "" {
- if p.cookieSecure {
- scheme = "https"
- } else {
- scheme = "http"
- }
- }
-
- redirectURL := &url.URL{
- Scheme: scheme,
- Host: req.Host,
- Path: "/",
- }
- fullURL := p.provider.GetSignOutURL(redirectURL)
- http.Redirect(rw, req, fullURL.String(), http.StatusFound)
-}
-
// OAuthStart begins the authentication flow, encrypting the redirect url in a request to the provider's sign in endpoint.
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request, tags []string) {
// The proxy redirects to the authenticator, and provides it with redirectURI (which points
@@ -417,7 +499,6 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request, tags
}
requestURI := req.URL.String()
- callbackURL := p.GetRedirectURL(req.Host)
// We redirect the browser to the authenticator with a 302 status code. The target URL is
// constructed using the GetSignInURL() method, which encodes the following data:
@@ -470,10 +551,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request, tags
p.ErrorPage(rw, req, http.StatusInternalServerError, "Internal Error", err.Error())
return
}
-
- signinURL := p.provider.GetSignInURL(callbackURL, encryptedState)
- logger.WithSignInURL(signinURL).Info("starting OAuth flow")
- http.Redirect(rw, req, signinURL.String(), http.StatusFound)
+ p.SignInPage(rw, req, encryptedState)
}
// OAuthCallback validates the cookie sent back from the provider, then validates
diff --git a/internal/proxy/oauthproxy_test.go b/internal/proxy/oauthproxy_test.go
index d21aced8..51031372 100644
--- a/internal/proxy/oauthproxy_test.go
+++ b/internal/proxy/oauthproxy_test.go
@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"io/ioutil"
"log"
"net/http"
@@ -69,6 +70,15 @@ func setCookieSecure(cookieSecure bool) func(*OAuthProxy) error {
}
}
+type TemplateSpy struct {
+ executeTemplate func(io.Writer, string, interface{})
+}
+
+// TODO: we can probably use the mock_template package...
+func (s TemplateSpy) ExecuteTemplate(rw io.Writer, tpl string, data interface{}) {
+ s.executeTemplate(rw, tpl, data)
+}
+
func testSession() *sessions.SessionState {
theFuture := time.Now().AddDate(100, 100, 100)
@@ -914,6 +924,7 @@ func TestOAuthStart(t *testing.T) {
}
proxy, close := testNewOAuthProxy(t,
+
setSessionStore(&sessions.MockSessionStore{}),
setCSRFStore(csrfStore),
setCookieCipher(cookieCipher),
@@ -925,45 +936,56 @@ func TestOAuthStart(t *testing.T) {
if tc.isXHR {
req.Header.Add("X-Requested-With", "XMLHttpRequest")
}
-
+ // proxy.OAuthStart results in a html page being rendered with some data
+ // in it that we want test.
+ // We take advantage of proxy.templates being an interface and use it to
+ // test the 'state' in the struct that is used to render the html page, rather
+ // than parsing the html page itself for the state.
+ proxy.templates = &TemplateSpy{
+ executeTemplate: func(rw io.Writer, tpl string, data interface{}) {
+ signInResp, ok := data.(signInResp)
+ if !ok {
+ t.Fatalf("Invalid struct used: expected 'signInResp'")
+ }
+
+ state := signInResp.SignInParams.State
+
+ cookieParameter := &StateParameter{}
+ err = json.Unmarshal([]byte(csrfStore.ResponseCSRF), cookieParameter)
+ if err != nil {
+ t.Errorf("response csrf: %v", csrfStore.ResponseCSRF)
+ t.Fatalf("unexpected err during unmarshal: %v", err)
+ }
+
+ stateParameter := &StateParameter{}
+ err = json.Unmarshal([]byte(state), stateParameter)
+ if err != nil {
+ t.Errorf("state: %v", state)
+ t.Fatalf("unexpected err during unmarshal: %v", err)
+ }
+
+ if !reflect.DeepEqual(cookieParameter, stateParameter) {
+ t.Logf("cookie parameter: %#v", cookieParameter)
+ t.Logf(" state parameter: %#v", stateParameter)
+ t.Fatalf("expected structs to be equal")
+ }
+ },
+ }
+ // Note: In this case, proxy.OAuthStart does not cause a html page
+ // to be rendered due to us replacing proxy.templates above.
proxy.OAuthStart(rw, req, []string{})
res := rw.Result()
if res.StatusCode != tc.expectedStatusCode {
- t.Fatalf("unexpected status code response")
+ t.Logf("expected status code: %q", tc.expectedStatusCode)
+ t.Logf(" got: %q", res.StatusCode)
+ t.Fatalf("unexpected status code returned")
}
+ //TODO: what's the intended course of action here?
if tc.expectedStatusCode != http.StatusFound {
return
}
-
- location, err := res.Location()
- if err != nil {
- t.Fatalf("expected req to succeeded err:%v", err)
- }
-
- state := location.Query().Get("state")
-
- cookieParameter := &StateParameter{}
- err = json.Unmarshal([]byte(csrfStore.ResponseCSRF), cookieParameter)
- if err != nil {
- t.Errorf("response csrf: %v", csrfStore.ResponseCSRF)
- t.Fatalf("unexpected err during unmarshal: %v", err)
- }
-
- stateParameter := &StateParameter{}
- err = json.Unmarshal([]byte(state), stateParameter)
- if err != nil {
- t.Errorf("location: %v", location)
- t.Errorf("state: %v", state)
- t.Fatalf("unexpected err during unmarshal: %v", err)
- }
-
- if !reflect.DeepEqual(cookieParameter, stateParameter) {
- t.Logf("cookie parameter: %#v", cookieParameter)
- t.Logf(" state parameter: %#v", stateParameter)
- t.Fatalf("expected structs to be equal")
- }
})
}
}
@@ -1076,8 +1098,24 @@ func TestSecurityHeaders(t *testing.T) {
body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
testutil.Assert(t, err == nil, "could not read http response body: %v", err)
- if string(body) != tc.path {
- t.Errorf("expected body = %q, got %q", tc.path, string(body))
+
+ actualBody := string(body)
+ expectedBody := tc.path
+ // if the request is not authenticated (it has no proxy session),
+ // we expect a sign in page to be rendered.
+ if !tc.authenticated {
+ expectedBody = "Sign in with"
+ if !strings.Contains(actualBody, expectedBody) {
+ t.Logf("expected body to contain: %q", expectedBody)
+ t.Logf(" got: %q", actualBody)
+ t.Errorf("received invalid body")
+ }
+ } else {
+ if actualBody != expectedBody {
+ t.Logf("expected body: %q", expectedBody)
+ t.Logf(" got: %q", actualBody)
+ t.Errorf("received invalid body")
+ }
}
}
@@ -1154,16 +1192,15 @@ func TestHeaderOverrides(t *testing.T) {
func TestHTTPSRedirect(t *testing.T) {
testCases := []struct {
- name string
- url string
- host string
- cookieSecure bool
- authenticated bool
- requestHeaders map[string]string
- expectedCode int
- expectedLocation string // must match entire Location header
- expectedLocationHost string // just match hostname of Location header
- expectSTS bool // should we get a Strict-Transport-Security header?
+ name string
+ url string
+ host string
+ cookieSecure bool
+ authenticated bool
+ requestHeaders map[string]string
+ expectedCode int
+ expectedLocation string // must match entire Location header
+ expectSTS bool // should we get a Strict-Transport-Security header?
}{
{
name: "no https redirect with http and cookie_secure=false and authenticated=true",
@@ -1174,13 +1211,20 @@ func TestHTTPSRedirect(t *testing.T) {
expectSTS: false,
},
{
- name: "no https redirect with http cookie_secure=false and authenticated=false",
- url: "http://localhost/",
- cookieSecure: false,
- authenticated: false,
- expectedCode: http.StatusFound,
- expectedLocationHost: "localhost",
- expectSTS: false,
+ name: "no https redirect with http cookie_secure=false and authenticated=false",
+ url: "http://localhost/",
+ cookieSecure: false,
+ authenticated: false,
+ expectedCode: http.StatusFound,
+ expectSTS: false,
+ },
+ {
+ name: "no https redirect with https and cookie_secure=false and authenticated=false",
+ url: "https://localhost/",
+ cookieSecure: false,
+ authenticated: false,
+ expectedCode: http.StatusFound,
+ expectSTS: false,
},
{
name: "no https redirect with https and cookie_secure=false and authenticated=true",
@@ -1190,16 +1234,7 @@ func TestHTTPSRedirect(t *testing.T) {
expectedCode: http.StatusOK,
expectSTS: false,
},
- {
- name: "no https redirect with https and cookie_secure=false and authenticated=false",
- url: "https://localhost/",
- cookieSecure: false,
- authenticated: false,
- expectedCode: http.StatusFound,
- expectedLocationHost: "localhost",
- expectSTS: false,
- },
- {
+ { //TODO
name: "https redirect with cookie_secure=true and authenticated=false",
url: "http://localhost/",
cookieSecure: true,
@@ -1294,15 +1329,21 @@ func TestHTTPSRedirect(t *testing.T) {
}
location := rw.Header().Get("Location")
- locationURL, err := url.Parse(location)
- if err != nil {
- t.Errorf("error parsing location %q: %s", location, err)
- }
- if tc.expectedLocation != "" && location != tc.expectedLocation {
+ if location != tc.expectedLocation {
t.Errorf("expected Location=%q, got Location=%q", tc.expectedLocation, location)
}
- if tc.expectedLocationHost != "" && locationURL.Hostname() != tc.expectedLocationHost {
- t.Errorf("expected location host = %q, got %q", tc.expectedLocationHost, locationURL.Hostname())
+
+ // if we do not require https, and there is no session we expect a sign in page to be
+ // rendered
+ if !tc.cookieSecure && !tc.authenticated {
+ expectedBody := "Sign in with"
+ actualBody, err := ioutil.ReadAll(rw.Body)
+ testutil.Assert(t, err == nil, "could not read http response body: %v", err)
+ if !strings.Contains(string(actualBody), expectedBody) {
+ t.Logf("expected body to contain: %q", expectedBody)
+ t.Logf(" got: %q", string(actualBody))
+ t.Errorf("received invalid body")
+ }
}
stsKey := http.CanonicalHeaderKey("Strict-Transport-Security")
@@ -1321,3 +1362,127 @@ func TestHTTPSRedirect(t *testing.T) {
})
}
}
+
+func TestSignOutPage(t *testing.T) {
+ testCases := []struct {
+ Name string
+ ExpectedStatusCode int
+ cookieSecure bool
+ mockSessionStore *sessions.MockSessionStore
+ RevokeError error
+ expectedSignOutResp signOutResp
+ expectedLocation string
+ }{
+ {
+ Name: "successful rendered sign out html page",
+ mockSessionStore: &sessions.MockSessionStore{
+ Session: &sessions.SessionState{
+ Email: "test@example.com",
+ RefreshDeadline: time.Now().Add(time.Hour),
+ AccessToken: "accessToken",
+ RefreshToken: "refreshToken",
+ },
+ },
+ ExpectedStatusCode: http.StatusOK,
+ },
+ {
+ Name: "successful sign out response",
+ mockSessionStore: &sessions.MockSessionStore{
+ Session: &sessions.SessionState{
+ Email: "test@example.com",
+ RefreshDeadline: time.Now().Add(time.Hour),
+ AccessToken: "accessToken",
+ RefreshToken: "refreshToken",
+ },
+ },
+ ExpectedStatusCode: http.StatusOK,
+ expectedSignOutResp: signOutResp{
+ //TODO: standardise on ProviderSlug or ProviderName?
+ ProviderSlug: "",
+ Version: VERSION,
+ Action: "http://localhost/oauth/sign_out",
+ Destination: "example.com",
+ Email: "test@example.com",
+ SignOutParams: providers.SignOutParams{
+ RedirectURL: "https://example.com/",
+ },
+ },
+ },
+ { // TODO: Check the redirect URL is what we expect
+ Name: "redirect to sso_auth if no session exists",
+ mockSessionStore: &sessions.MockSessionStore{
+ LoadError: http.ErrNoCookie,
+ },
+ ExpectedStatusCode: http.StatusFound,
+ expectedLocation: "http://localhost/oauth/sign_out?redirect_uri=https%3A%2F%2Fexample.com%2F&sig=&ts=",
+ },
+ //{
+ // Name: "cookieSecure sets scheme to https, if no scheme included",
+ //},
+ //{
+ // Name: "session is cleared before rendering template/redirecting",
+ //},
+ }
+ for _, tc := range testCases {
+ t.Run(tc.Name, func(t *testing.T) {
+
+ // set up the provider
+ providerURL, _ := url.Parse("http://localhost/")
+ testProvider := providers.NewTestProvider(providerURL, "")
+
+ proxy, close := testNewOAuthProxy(t,
+ SetProvider(testProvider),
+ setSessionStore(tc.mockSessionStore),
+ setCookieSecure(tc.cookieSecure),
+ )
+ defer close()
+
+ // if this particular test case tests for a specific response struct,
+ // we replace the proxy template function to avoid having to parse
+ // the html response
+ if tc.expectedSignOutResp != (signOutResp{}) {
+ proxy.templates = &TemplateSpy{
+ executeTemplate: func(rw io.Writer, tpl string, data interface{}) {
+ signOutResp, ok := data.(signOutResp)
+ if !ok {
+ t.Fatalf("Invalid struct used: expected 'signOutResp'")
+ }
+
+ testutil.Equal(t, tc.expectedSignOutResp, signOutResp)
+ },
+ }
+ }
+
+ rw := httptest.NewRecorder()
+ req := httptest.NewRequest("GET", "https://example.com/sign_out", nil)
+
+ proxy.SignOutPage(rw, req)
+
+ testutil.Equal(t, tc.ExpectedStatusCode, rw.Code)
+
+ location := rw.Header().Get("Location")
+ if location != tc.expectedLocation {
+ t.Errorf("expected Location=%q, got Location=%q", tc.expectedLocation, location)
+ }
+
+ // if this particular test case doesn't test for a specific
+ // response struct, and is not a straight redirect,
+ // then make sure that the template is returned as expected.
+ if tc.expectedSignOutResp == (signOutResp{}) && tc.ExpectedStatusCode != 302 {
+ resp := rw.Result()
+ respBytes, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("unable to properly parse response: %v", err)
+ }
+
+ expectedBody := "Sign out of
example.com"
+ actualBody := string(respBytes)
+ if !strings.Contains(actualBody, expectedBody) {
+ t.Logf("expected body to contain: %q", expectedBody)
+ t.Logf(" got: %q", actualBody)
+ t.Errorf("received invalid body")
+ }
+ }
+ })
+ }
+}
diff --git a/internal/proxy/providers/providers.go b/internal/proxy/providers/providers.go
index 226dc16f..0140b425 100644
--- a/internal/proxy/providers/providers.go
+++ b/internal/proxy/providers/providers.go
@@ -14,8 +14,8 @@ type Provider interface {
ValidateGroup(string, []string, string) ([]string, bool, error)
UserGroups(string, []string, string) ([]string, error)
ValidateSessionState(*sessions.SessionState, []string) bool
- GetSignInURL(redirectURL *url.URL, finalRedirect string) *url.URL
- GetSignOutURL(redirectURL *url.URL) *url.URL
+ GetSignInURL(redirectURL *url.URL, finalRedirect string) (*url.URL, SignInParams)
+ GetSignOutURL(redirectURL *url.URL) (*url.URL, SignOutParams)
RefreshSession(*sessions.SessionState, []string) (bool, error)
}
diff --git a/internal/proxy/providers/singleflight_middleware.go b/internal/proxy/providers/singleflight_middleware.go
index 40e6fcf7..be0fea87 100644
--- a/internal/proxy/providers/singleflight_middleware.go
+++ b/internal/proxy/providers/singleflight_middleware.go
@@ -131,11 +131,11 @@ func (p *SingleFlightProvider) RefreshSession(s *sessions.SessionState, allowedG
}
// GetSignInURL calls the GetSignInURL for the provider, which will return the sign in url
-func (p *SingleFlightProvider) GetSignInURL(redirectURI *url.URL, finalRedirect string) *url.URL {
+func (p *SingleFlightProvider) GetSignInURL(redirectURI *url.URL, finalRedirect string) (*url.URL, SignInParams) {
return p.provider.GetSignInURL(redirectURI, finalRedirect)
}
// GetSignOutURL calls the GetSignOutURL for the provider, which will return the sign out url
-func (p *SingleFlightProvider) GetSignOutURL(redirectURI *url.URL) *url.URL {
+func (p *SingleFlightProvider) GetSignOutURL(redirectURI *url.URL) (*url.URL, SignOutParams) {
return p.provider.GetSignOutURL(redirectURI)
}
diff --git a/internal/proxy/providers/sso.go b/internal/proxy/providers/sso.go
index 1aca089c..6bc4086c 100644
--- a/internal/proxy/providers/sso.go
+++ b/internal/proxy/providers/sso.go
@@ -408,35 +408,60 @@ func (p *SSOProvider) ValidateSessionState(s *sessions.SessionState, allowedGrou
return true
}
+type SignInParams struct {
+ RedirectURL string
+ Scope string
+ ClientID string
+ ResponseType string
+ State string
+ TimeStamp string
+ Signature string
+}
+
// GetSignInURL with typical oauth parameters
-func (p *SSOProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL {
+func (p *SSOProvider) GetSignInURL(redirectURL *url.URL, state string) (*url.URL, SignInParams) {
+
a := *p.Data().SignInURL
now := time.Now()
rawRedirect := redirectURL.String()
- params, _ := url.ParseQuery(a.RawQuery)
- params.Set("redirect_uri", rawRedirect)
- params.Add("scope", p.Scope)
- params.Set("client_id", p.ClientID)
- params.Set("response_type", "code")
- params.Add("state", state)
- params.Set("ts", fmt.Sprint(now.Unix()))
- params.Set("sig", p.signRedirectURL(rawRedirect, now))
- a.RawQuery = params.Encode()
- return &a
+
+ //TODO: With this, we overwrite any previously set scope or state params
+ // because we are no longer appending with `.Add`. Need to check this out
+ signInParams := SignInParams{
+ RedirectURL: rawRedirect,
+ Scope: p.Scope,
+ ClientID: p.ClientID,
+ ResponseType: "code",
+ State: state,
+ TimeStamp: fmt.Sprint(now.Unix()),
+ Signature: p.signRedirectURL(rawRedirect, now),
+ }
+
+ return &a, signInParams
+}
+
+type SignOutParams struct {
+ RedirectURL string
+ TimeStamp string
+ Signature string
}
// GetSignOutURL creates and returns the sign out URL, given a redirectURL
-func (p *SSOProvider) GetSignOutURL(redirectURL *url.URL) *url.URL {
+func (p *SSOProvider) GetSignOutURL(redirectURL *url.URL) (*url.URL, SignOutParams) {
a := *p.Data().SignOutURL
now := time.Now()
rawRedirect := redirectURL.String()
- params, _ := url.ParseQuery(a.RawQuery)
- params.Add("redirect_uri", rawRedirect)
- params.Set("ts", fmt.Sprint(now.Unix()))
- params.Set("sig", p.signRedirectURL(rawRedirect, now))
- a.RawQuery = params.Encode()
- return &a
+
+ //TODO: With this, we overwrite any previously set redirect params
+ // because we are no longer appending with `.Add`. Need to check this out.
+ signOutParams := SignOutParams{
+ RedirectURL: rawRedirect,
+ TimeStamp: fmt.Sprint(now.Unix()),
+ Signature: p.signRedirectURL(rawRedirect, now),
+ }
+
+ return &a, signOutParams
}
// signRedirectURL signs the redirect url string, given a timestamp, and returns it
diff --git a/internal/proxy/providers/test_provider.go b/internal/proxy/providers/test_provider.go
index 35494786..ef1e28bc 100644
--- a/internal/proxy/providers/test_provider.go
+++ b/internal/proxy/providers/test_provider.go
@@ -72,15 +72,28 @@ func (tp *TestProvider) ValidateGroup(email string, groups []string, accessToken
}
// GetSignOutURL mocks GetSignOutURL function
-func (tp *TestProvider) GetSignOutURL(redirectURL *url.URL) *url.URL {
- return tp.Data().SignOutURL
+func (tp *TestProvider) GetSignOutURL(redirectURL *url.URL) (*url.URL, SignOutParams) {
+ a := *tp.Data().SignOutURL
+ rawRedirect := redirectURL.String()
+
+ // this returns less than required for an actual call,
+ // but is enough for us to test functionality
+ signOutParams := SignOutParams{
+ RedirectURL: rawRedirect,
+ }
+
+ return &a, signOutParams
}
// GetSignInURL mocks GetSignInURL
-func (tp *TestProvider) GetSignInURL(redirectURL *url.URL, state string) *url.URL {
+func (tp *TestProvider) GetSignInURL(redirectURL *url.URL, state string) (*url.URL, SignInParams) {
a := *tp.Data().SignInURL
- params, _ := url.ParseQuery(a.RawQuery)
- params.Add("state", state)
- a.RawQuery = params.Encode()
- return &a
+ rawRedirect := redirectURL.String()
+
+ signInParams := SignInParams{
+ RedirectURL: rawRedirect,
+ Scope: tp.Scope,
+ State: state,
+ }
+ return &a, signInParams
}
diff --git a/internal/proxy/static/sso.css b/internal/proxy/static/sso.css
new file mode 100644
index 00000000..a674b7dd
--- /dev/null
+++ b/internal/proxy/static/sso.css
@@ -0,0 +1,86 @@
+* {
+ margin: 0;
+ padding: 0;
+}
+body {
+ font-family: "Helvetica Neue",Helvetica,Arial,sans-serif;
+ font-size: 1em;
+ line-height: 1.42857143;
+ color: #333;
+ background: #f0f0f0;
+}
+
+p {
+ margin: 1.5em 0;
+}
+p:first-child {
+ margin-top: 0;
+}
+p:last-child {
+ margin-bottom: 0;
+}
+
+.container {
+ max-width: 40em;
+ display: block;
+ margin: 10% auto;
+ text-align: center;
+}
+
+.content, .message, button {
+ border: 1px solid rgba(0,0,0,.125);
+ border-bottom-width: 4px;
+ border-radius: 4px;
+}
+
+.content, .message {
+ background-color: #fff;
+ padding: 2rem;
+ margin: 1rem 0;
+}
+.error, .message {
+ border-bottom-color: #c00;
+}
+.message {
+ padding: 1.5rem 2rem 1.3rem;
+}
+
+header {
+ border-bottom: 1px solid rgba(0,0,0,.075);
+ margin: -2rem 0 2rem;
+ padding: 2rem 0 1.8rem;
+}
+header h1 {
+ font-size: 1.5em;
+ font-weight: normal;
+}
+.error header {
+ color: #c00;
+}
+.details {
+ font-size: .85rem;
+ color: #999;
+}
+
+button {
+ color: #fff;
+ background-color: #3B8686;
+ cursor: pointer;
+ font-size: 1.5rem;
+ font-weight: bold;
+ padding: 1rem 2.5rem;
+ text-shadow: 0 3px 1px rgba(0,0,0,.2);
+ outline: none;
+}
+button:active {
+ border-top-width: 4px;
+ border-bottom-width: 1px;
+ text-shadow: none;
+}
+
+footer {
+ font-size: 0.75em;
+ color: #999;
+ text-align: right;
+ margin: 1rem;
+}
diff --git a/internal/proxy/static_files.go b/internal/proxy/static_files.go
new file mode 100644
index 00000000..adbf39f6
--- /dev/null
+++ b/internal/proxy/static_files.go
@@ -0,0 +1,47 @@
+package proxy
+
+import (
+ "net/http"
+ "os"
+
+ "github.com/rakyll/statik/fs"
+
+ // Statik makes assets available via a blank import
+ _ "github.com/buzzfeed/sso/internal/auth/statik"
+)
+
+// noDirectoryFilesystem is used to prevent an http.FileServer from providing directory listings
+type noDirectoryFS struct {
+ fs http.FileSystem
+}
+
+func (fs noDirectoryFS) Open(name string) (http.File, error) {
+ f, err := fs.fs.Open(name)
+
+ if err != nil {
+ return nil, err
+ }
+
+ stat, err := f.Stat()
+ if err != nil {
+ return nil, err
+ }
+
+ // prevent directory listings
+ if stat.IsDir() {
+ return nil, os.ErrNotExist
+ }
+
+ return f, nil
+}
+
+//go:generate $GOPATH/bin/statik -f -src=./static
+
+func loadFSHandler() (http.Handler, error) {
+ statikFS, err := fs.New()
+ if err != nil {
+ return nil, err
+ }
+
+ return http.FileServer(noDirectoryFS{statikFS}), nil
+}
diff --git a/internal/proxy/static_files_test.go b/internal/proxy/static_files_test.go
new file mode 100644
index 00000000..9abd9a12
--- /dev/null
+++ b/internal/proxy/static_files_test.go
@@ -0,0 +1,111 @@
+package proxy
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+
+ "github.com/buzzfeed/sso/internal/pkg/aead"
+ "github.com/buzzfeed/sso/internal/pkg/sessions"
+ "github.com/buzzfeed/sso/internal/pkg/testutil"
+ "github.com/buzzfeed/sso/internal/proxy/providers"
+)
+
+func TestStaticFiles(t *testing.T) {
+ testCases := []struct {
+ name string
+ uri string
+ expectedStatus int
+ expectedContent string
+ }{
+ {
+ name: "static css ok",
+ uri: "https://localhost/static/sso.css",
+ expectedStatus: http.StatusOK,
+ expectedContent: "body {",
+ },
+ {
+ name: "nonexistent file not found",
+ uri: "https://localhost/static/missing.css",
+ expectedStatus: http.StatusNotFound,
+ },
+ {
+ name: "no directory listing",
+ uri: "https://localhost/static/",
+ expectedStatus: http.StatusNotFound,
+ },
+ {
+ // this will result in a 301 -> /config.yml
+ name: "no directory escape",
+ uri: "https://localhost/static/../config.yml",
+ expectedStatus: http.StatusMovedPermanently,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+
+ // create backend server
+ backend := httptest.NewServer(http.HandlerFunc(
+ func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ },
+ ))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatalf("unexpected err parsing backend url: %v", err)
+ }
+
+ // create base opts, provider, upstream config and other
+ // components required to create an OAuthProxy
+ opts := NewOptions()
+ providerURL, _ := url.Parse("http://localhost/")
+ provider := providers.NewTestProvider(providerURL, "")
+
+ upstreamConfig := &UpstreamConfig{
+ Route: &SimpleRoute{
+ ToURL: backendURL,
+ },
+ }
+
+ reverseProxy, err := NewUpstreamReverseProxy(upstreamConfig, nil)
+ if err != nil {
+ t.Fatalf("unexpected error creating upstream reverse proxy: %v", err)
+ }
+
+ optFuncs := []func(*OAuthProxy) error{
+ SetProvider(provider),
+ setSessionStore(&sessions.MockSessionStore{}),
+ SetUpstreamConfig(upstreamConfig),
+ SetProxyHandler(reverseProxy),
+ setCSRFStore(&sessions.MockCSRFStore{}),
+ setCookieCipher(&aead.MockCipher{}),
+ }
+
+ proxy, err := NewOAuthProxy(opts, optFuncs...)
+ testutil.Assert(t, err == nil, "could not create upstream reverse proxy: %v", err)
+
+ // make requests and check responses
+ rw := httptest.NewRecorder()
+ req := httptest.NewRequest("GET", tc.uri, nil)
+
+ proxy.Handler().ServeHTTP(rw, req)
+ if rw.Code != tc.expectedStatus {
+ t.Logf("expected response: %v", tc.expectedStatus)
+ t.Logf(" got code: %v", rw.Code)
+ t.Logf(" headers: %v", rw.HeaderMap)
+ t.Errorf("unexpected response returned")
+ }
+
+ if tc.expectedContent != "" && !strings.Contains(rw.Body.String(), tc.expectedContent) {
+ t.Logf("expected body to contain: %v", tc.expectedContent)
+ t.Logf(" got: %v", rw.Body.String())
+ t.Errorf("unexpected body returned")
+ }
+ })
+ }
+}
diff --git a/internal/proxy/statik/statik.go b/internal/proxy/statik/statik.go
new file mode 100644
index 00000000..526bf1e2
--- /dev/null
+++ b/internal/proxy/statik/statik.go
@@ -0,0 +1,13 @@
+// Code generated by statik. DO NOT EDIT.
+
+// Package statik contains static assets.
+package statik
+
+import (
+ "github.com/rakyll/statik/fs"
+)
+
+func init() {
+ data := "PK\x03\x04\x14\x00\x08\x00\x08\x00$y\xfaN\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07\x00 \x00sso.cssUT\x05\x00\x01\x95\x17;]tT\xc1n\x9c0\x10\xbd\xf3\x15\xa3D\x95\xdajA\xb0d\x93\x8d\xf7\xd4\x9ez\xea?\x18<\x80\x15\xe3A\xb6Iv[\xe5\xdf+\x1b\xc3\xc2j+N\xb6\xc7o\xde{\xf3\xccw\xf8\x9b\x00\xf4\xdc\xb4R3\xc8O \xc0\xc0\x85\x90\xba\x0d\xab\xcf\xa4\"q 5\x0di\x976\xbc\x97\xea\xc2\xe0\xe1\x17\xaawt\xb2\xe6\xf0\x1bG|\xd8-\xeb\xdd\x0f#\xb9\xdaY\xaemj\xd1\xc8\xe64\xdf\xb5\xf2\x0f2(\xb0\xf7;JjL;\x94m\xe7\x18\x14\xd9\xd3\xfexx)\x9eJ\x7fT\x93\"\xc3\xe0\xb1,\xc3\xb2\xe2\xf5[kh\xd4\x82\xc1c\x93\xfb\xcf\xf3J\x86\x0d\xf3\";`?1\x1eX#\x8dui\xddI%VE\xa9\xa3\x81\xcd%\x8a\xdf\xab\xa8\xc89\xeacQ\x92\xd5\xa4\x1d\x97\x1aM\xac9\xa7\x1fR\xb8\x8e\xc1S>\xa9\x10\xd2\x0e\x8a_\x18T\x8a\xea\xb7\xd3\x9aO\xfe\x05\xf8\xe8\xc8\xef9<\xbb\x94+\xd9j\x065j\x87\xe6\n\x8f\xda\xed \xeb\xd1Z\xde\xe2\x0e\xaa\xd19\xd2\xa1]EF\xa0aP\x0cg\xb0\xa4\xa4\x00\xd3V\xfck\xbe\xf3_V\xec\x0f\xdfNKU$\xbe\xd0\x1b\xce\xab3\xc3\x85\x1cm\xdc\xbd\xdbw\xea\xb7\xf8\x9c\xce\x13h\x9af\x13\x88\xbd\x99d/\"\xcd\xecy\x86\xc6\x90\xb9A\xbce7\xc3\xd6\xf9tg[\xbbt)\xb2\x83\xc7\xf5\xcd\xa0\xc8\xca\xd0\xf33I:\xe4\"Nb\x03\xfb?\x87\xf2\x97\xc9\xa1\x99l\x1a\x00\xf3E\xc4F\x15\xe4Pd\xc7\xd8*v\xea\x8ak\xeecv}\xc8\x96<\x7f\xc4\xf4j2=WW\x17`\xc5\xf4\x9a\xe6Y\xb4@\xc7\xa5\xb2\xf1t\x05\x9e\x1d\x0f\x91\xd9\xf5\xd2\xeb\xebk\xd0\xbe\x8a\xc5\xcdl\xeeL\xad\xfcy|>>\x87\x974\x1a\xeb\xb7\x06\x92S\xean\xd5\x98;r*Rb\xe3O\x98\xf2~)\x0ei\xb6\x1d\x17\xf4\xc1 \x87r8\x87\x01\xac\xad\xdf\x07\xe3it\xfe\x95{\x834\x86_IP\xc1x\xed\xe4;\xae\x07\xe9h\xb8\x1f\xddm\xac\x8b\xe9lC`\xc6N\x1a\"\x17M_i\xcc\xb3\x978\xb2\x8d\xa5\xdb7i\xbc\xf0\xdb\\{\xd0\x7f\x01\x00\x00\xff\xffPK\x07\x08\x0b~@~\x1a\x02\x00\x00\x1d\x05\x00\x00PK\x01\x02\x14\x03\x14\x00\x08\x00\x08\x00$y\xfaN\x0b~@~\x1a\x02\x00\x00\x1d\x05\x00\x00\x07\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\xa4\x81\x00\x00\x00\x00sso.cssUT\x05\x00\x01\x95\x17;]PK\x05\x06\x00\x00\x00\x00\x01\x00\x01\x00>\x00\x00\x00X\x02\x00\x00\x00\x00"
+ fs.Register(data)
+}