1717
1818import static org .junit .jupiter .api .Assertions .assertEquals ;
1919import static org .junit .jupiter .api .Assertions .assertInstanceOf ;
20+ import static org .junit .jupiter .api .Assertions .assertNotNull ;
2021import static org .junit .jupiter .api .Assertions .assertThrows ;
2122import static org .junit .jupiter .api .Assertions .assertTrue ;
2223import static org .mockito .ArgumentMatchers .any ;
2526import static org .mockito .Mockito .times ;
2627import static org .mockito .Mockito .verify ;
2728import static org .mockito .Mockito .when ;
29+ import static software .amazon .awssdk .services .signin .auth .internal .DpopTestUtils .VALID_TEST_PEM ;
30+ import static software .amazon .awssdk .services .signin .auth .internal .DpopTestUtils .getJwtPayloadFromEncodedDpopHeader ;
31+ import static software .amazon .awssdk .services .signin .auth .internal .DpopTestUtils .verifySignature ;
2832
33+ import java .io .ByteArrayInputStream ;
34+ import java .net .URI ;
35+ import java .nio .charset .StandardCharsets ;
2936import java .nio .file .Path ;
3037import java .time .Instant ;
38+ import java .util .ArrayList ;
39+ import java .util .List ;
40+ import java .util .Map ;
3141import org .junit .jupiter .api .BeforeEach ;
3242import org .junit .jupiter .api .Test ;
3343import org .junit .jupiter .api .io .TempDir ;
3444import org .mockito .ArgumentCaptor ;
3545import software .amazon .awssdk .auth .credentials .AwsCredentials ;
3646import software .amazon .awssdk .auth .credentials .AwsSessionCredentials ;
47+ import software .amazon .awssdk .auth .signer .AwsSignerExecutionAttribute ;
48+ import software .amazon .awssdk .core .SdkRequest ;
3749import software .amazon .awssdk .core .exception .SdkClientException ;
50+ import software .amazon .awssdk .core .interceptor .Context ;
51+ import software .amazon .awssdk .core .interceptor .ExecutionAttributes ;
52+ import software .amazon .awssdk .core .interceptor .ExecutionInterceptor ;
3853import software .amazon .awssdk .core .useragent .BusinessMetricFeatureId ;
54+ import software .amazon .awssdk .http .AbortableInputStream ;
55+ import software .amazon .awssdk .http .HttpExecuteResponse ;
56+ import software .amazon .awssdk .http .SdkHttpRequest ;
57+ import software .amazon .awssdk .http .SdkHttpResponse ;
58+ import software .amazon .awssdk .regions .Region ;
3959import software .amazon .awssdk .services .signin .SigninClient ;
4060import software .amazon .awssdk .services .signin .internal .AccessTokenManager ;
4161import software .amazon .awssdk .services .signin .internal .LoginAccessToken ;
4262import software .amazon .awssdk .services .signin .internal .OnDiskTokenManager ;
4363import software .amazon .awssdk .services .signin .model .CreateOAuth2TokenRequest ;
4464import software .amazon .awssdk .services .signin .model .CreateOAuth2TokenResponse ;
4565import software .amazon .awssdk .services .signin .model .SigninException ;
66+ import software .amazon .awssdk .testutils .service .http .MockAsyncHttpClient ;
67+ import software .amazon .awssdk .testutils .service .http .MockSyncHttpClient ;
4668
4769public class LoginCredentialsProviderTest {
4870 private static final String LOGIN_SESSION_ID = "loginSessionId" ;
4971
5072 private AccessTokenManager tokenManager ;
5173 private SigninClient signinClient ;
74+ private MockSyncHttpClient mockHttpClient ;
75+ private CaptureRequestInterceptor captureRequestInterceptor ;
5276 private LoginCredentialsProvider loginCredentialsProvider ;
5377
5478 @ TempDir
5579 Path tempDir ;
5680
5781 @ BeforeEach
5882 public void setup () {
59- signinClient = mock (SigninClient .class );
83+ mockHttpClient = new MockSyncHttpClient ();
84+ captureRequestInterceptor = new CaptureRequestInterceptor ();
85+ signinClient = SigninClient
86+ .builder ()
87+ .region (Region .US_EAST_1 )
88+ .endpointOverride (URI .create ("https://custom-signin-endpoint.com" ))
89+ .httpClient (mockHttpClient )
90+ .overrideConfiguration (c -> c .addExecutionInterceptor (captureRequestInterceptor ))
91+ .build ();
92+
6093 tokenManager = OnDiskTokenManager .create (tempDir , LOGIN_SESSION_ID );
6194
6295 loginCredentialsProvider = LoginCredentialsProvider
@@ -89,7 +122,7 @@ public void resolveCredentials_whenCredentialsFresh_usesFromDisk() {
89122
90123 AwsCredentials resolveCredentials = loginCredentialsProvider .resolveCredentials ();
91124
92- verify ( signinClient , never ()). createOAuth2Token ( any ( CreateOAuth2TokenRequest . class ));
125+ assertEquals ( 0 , mockHttpClient . getRequests (). size ( ));
93126
94127 assertEquals (creds .accessKeyId (), resolveCredentials .accessKeyId ());
95128 assertEquals (creds .secretAccessKey (), resolveCredentials .secretAccessKey ());
@@ -104,19 +137,17 @@ public void resolveCredentials_whenCredentialsNearExpiration_refreshesAndUpdates
104137 AwsSessionCredentials creds = buildCredentials (Instant .now ().plusSeconds (10 ));
105138 LoginAccessToken token = buildAccessToken (creds );
106139 tokenManager .storeToken (token );
107- when (signinClient .createOAuth2Token (any (CreateOAuth2TokenRequest .class ))).thenReturn (
108- buildSuccessfulRefreshResponse ()
109- );
110- AwsCredentials resolvedCredentials = loginCredentialsProvider .resolveCredentials ();
140+ stubSuccessfulRefreshResponse ();
111141
112- ArgumentCaptor <CreateOAuth2TokenRequest > captor =
113- ArgumentCaptor .forClass (CreateOAuth2TokenRequest .class );
142+ AwsCredentials resolvedCredentials = loginCredentialsProvider .resolveCredentials ();
114143
115144 // verify the service was called with correct arguments
116- verify (signinClient , times (1 )).createOAuth2Token (captor .capture ());
117- assertEquals (token .getClientId (), captor .getValue ().tokenInput ().clientId ());
118- assertEquals (token .getRefreshToken (), captor .getValue ().tokenInput ().refreshToken ());
119- assertEquals ("refresh_token" , captor .getValue ().tokenInput ().grantType ());
145+ assertEquals (1 , captureRequestInterceptor .requests .size ());
146+ assertInstanceOf (CreateOAuth2TokenRequest .class , captureRequestInterceptor .requests .get (0 ));
147+ CreateOAuth2TokenRequest request = (CreateOAuth2TokenRequest ) captureRequestInterceptor .requests .get (0 );
148+ assertEquals (token .getClientId (), request .tokenInput ().clientId ());
149+ assertEquals (token .getRefreshToken (), request .tokenInput ().refreshToken ());
150+ assertEquals ("refresh_token" , request .tokenInput ().grantType ());
120151 // TODO: Assert validity of DPoP header once implemented
121152
122153 // verify that returned credentials are updated
@@ -126,28 +157,34 @@ public void resolveCredentials_whenCredentialsNearExpiration_refreshesAndUpdates
126157 verifyTokenCacheUpdated ();
127158 }
128159
129-
130-
131160 @ Test
132- public void resolveCredentials_whenCredentialsExpired_refreshesAndUpdatesCache () {
161+ public void resolveCredentials_whenCredentialsExpired_refreshesAndUpdatesCache () throws Exception {
133162 // within the stale time
134163 AwsSessionCredentials creds = buildCredentials (Instant .now ().minusSeconds (600 ));
135164 LoginAccessToken token = buildAccessToken (creds );
136165 tokenManager .storeToken (token );
137- when (signinClient .createOAuth2Token (any (CreateOAuth2TokenRequest .class ))).thenReturn (
138- buildSuccessfulRefreshResponse ()
139- );
140- AwsCredentials resolvedCredentials = loginCredentialsProvider .resolveCredentials ();
166+ stubSuccessfulRefreshResponse ();
141167
142- ArgumentCaptor <CreateOAuth2TokenRequest > captor =
143- ArgumentCaptor .forClass (CreateOAuth2TokenRequest .class );
168+ AwsCredentials resolvedCredentials = loginCredentialsProvider .resolveCredentials ();
144169
145170 // verify the service was called with correct arguments
146- verify (signinClient , times (1 )).createOAuth2Token (captor .capture ());
147- assertEquals (token .getClientId (), captor .getValue ().tokenInput ().clientId ());
148- assertEquals (token .getRefreshToken (), captor .getValue ().tokenInput ().refreshToken ());
149- assertEquals ("refresh_token" , captor .getValue ().tokenInput ().grantType ());
150- // TODO: Assert validity of DPoP header once implemented
171+ assertEquals (1 , captureRequestInterceptor .requests .size ());
172+ assertInstanceOf (CreateOAuth2TokenRequest .class , captureRequestInterceptor .requests .get (0 ));
173+ CreateOAuth2TokenRequest request = (CreateOAuth2TokenRequest ) captureRequestInterceptor .requests .get (0 );
174+ assertEquals (token .getClientId (), request .tokenInput ().clientId ());
175+ assertEquals (token .getRefreshToken (), request .tokenInput ().refreshToken ());
176+ assertEquals ("refresh_token" , request .tokenInput ().grantType ());
177+
178+ // verify the request is correctly signed with DPoP header
179+ List <String > dpopHeader = captureRequestInterceptor .httpRequests .get (0 ).headers ().get ("DPoP" );
180+ assertNotNull (dpopHeader );
181+ assertEquals (1 , dpopHeader .size ());
182+ assertTrue (verifySignature (dpopHeader .get (0 )));
183+
184+ Map <String , Object > payload = getJwtPayloadFromEncodedDpopHeader (dpopHeader .get (0 ));
185+ assertEquals ("POST" , payload .get ("htm" ));
186+ assertEquals ("https://custom-signin-endpoint.com/v1/token" , payload .get ("htu" ));
187+
151188
152189 // verify that returned credentials are updated
153190 verifyResolvedCredentialsAreUpdated (resolvedCredentials );
@@ -162,7 +199,12 @@ public void resolveCredentials_whenCredentialsExpired_serviceCallFails_raisesExc
162199 AwsSessionCredentials creds = buildCredentials (Instant .now ().minusSeconds (60 ));
163200 LoginAccessToken token = buildAccessToken (creds );
164201 tokenManager .storeToken (token );
165- when (signinClient .createOAuth2Token (any (CreateOAuth2TokenRequest .class ))).thenThrow (SigninException .class );
202+ mockHttpClient .stubNextResponse (
203+ HttpExecuteResponse
204+ .builder ()
205+ .response (SdkHttpResponse .builder ().statusCode (500 ).build ())
206+ .build ()
207+ );
166208 assertThrows (SdkClientException .class , () -> loginCredentialsProvider .resolveCredentials ());
167209 }
168210
@@ -188,22 +230,23 @@ private void verifyTokenCacheUpdated() {
188230 assertEquals ("new-refresh-token" , updatedToken .getRefreshToken ());
189231 }
190232
191- private static CreateOAuth2TokenResponse buildSuccessfulRefreshResponse () {
192- return CreateOAuth2TokenResponse
193- .builder ()
194- .tokenOutput (
195- t ->
196- t
197- .expiresIn (600 )
198- .refreshToken ("new-refresh-token" )
199- .accessToken (
200- c ->
201- c
202- .accessKeyId ("new-akid" )
203- .secretAccessKey ("new-skid" )
204- .sessionToken ("new-session-token" ))
205- )
206- .build ();
233+ private void stubSuccessfulRefreshResponse () {
234+ String jsonBody =
235+ "{\" accessToken\" :"
236+ + "{\" accessKeyId\" :\" new-akid\" ,"
237+ + "\" secretAccessKey\" :\" new-skid\" ,"
238+ + "\" sessionToken\" :\" new-session-token\" },"
239+ + "\" tokenType\" :\" aws_sigv4\" ,"
240+ + "\" expiresIn\" :600,"
241+ + "\" refreshToken\" :\" new-refresh-token\" }" ;
242+
243+ mockHttpClient .stubNextResponse (
244+ HttpExecuteResponse
245+ .builder ()
246+ .response (SdkHttpResponse .builder ().statusCode (200 ).build ())
247+ .responseBody (AbortableInputStream .create (new ByteArrayInputStream (jsonBody .getBytes (StandardCharsets .UTF_8 ))))
248+ .build ()
249+ );
207250 }
208251
209252 private AwsSessionCredentials buildCredentials (Instant expirationTime ) {
@@ -220,12 +263,22 @@ private LoginAccessToken buildAccessToken(AwsSessionCredentials credentials) {
220263 return LoginAccessToken .builder ()
221264 .accessToken (credentials )
222265 .clientId ("client-123" )
223- .dpopKey ("dpop-key" )
266+ .dpopKey (VALID_TEST_PEM )
224267 .refreshToken ("refresh-token" )
225268 .tokenType ("aws_sigv4" )
226269 .identityToken ("id-token" )
227270 .build ();
228271 }
229272
273+ private static class CaptureRequestInterceptor implements ExecutionInterceptor {
230274
275+ private List <SdkHttpRequest > httpRequests = new ArrayList <>();
276+ private List <SdkRequest > requests = new ArrayList <>();
277+
278+ @ Override
279+ public void beforeTransmission (Context .BeforeTransmission context , ExecutionAttributes executionAttributes ) {
280+ this .httpRequests .add (context .httpRequest ());
281+ this .requests .add (context .request ());
282+ }
283+ }
231284}
0 commit comments