diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java index df1a365566d..2180562190a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java @@ -133,7 +133,7 @@ private Mono authorizedClientNotLoaded(String clientRegi }); } - private Mono clientCredentials( + Mono clientCredentials( ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) { OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index ef92036690d..548c9180147 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -84,8 +84,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private final OAuth2AuthorizedClientResolver authorizedClientResolver; public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository)); + } + + ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) { this.authorizedClientRepository = authorizedClientRepository; - this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository); + this.authorizedClientResolver = authorizedClientResolver; } /** @@ -245,13 +249,30 @@ private Mono createRequest(ClientRequest } private Mono refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { - if (shouldRefresh(authorizedClient)) { + ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); + if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) { + return createRequest(request) + .flatMap(r -> authorizeWithClientCredentials(clientRegistration, r)); + } else if (shouldRefresh(authorizedClient)) { return createRequest(request) .flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r)); } return Mono.just(authorizedClient); } + private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) { + return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()); + } + + private Mono authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) { + Authentication authentication = request.getAuthentication(); + ServerWebExchange exchange = request.getExchange(); + + return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange). + flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange) + .thenReturn(result)); + } + private Mono refreshAuthorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) { ServerWebExchange exchange = r.getExchange(); @@ -280,6 +301,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) { if (refreshToken == null) { return false; } + return hasTokenExpired(authorizedClient); + } + + private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { Instant now = this.clock.instant(); Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 6914363aa52..40b244fecfc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -332,12 +332,16 @@ private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, if (clientRegistration == null) { throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); } - if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + if (isClientCredentialsGrantType(clientRegistration)) { return getAuthorizedClient(clientRegistration, attrs); } throw new ClientAuthorizationRequiredException(clientRegistrationId); } + private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) { + return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()); + } + private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration, Map attrs) { @@ -366,7 +370,11 @@ private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegi } private Mono authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { - if (shouldRefresh(authorizedClient)) { + ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); + if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) { + //Client credentials grant do not have refresh tokens but can expire so we need to get another one + return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, request.attributes())); + } else if (shouldRefresh(authorizedClient)) { return refreshAuthorizedClient(request, next, authorizedClient); } return Mono.just(authorizedClient); @@ -407,6 +415,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) { if (refreshToken == null) { return false; } + return hasTokenExpired(authorizedClient); + } + + private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { Instant now = this.clock.instant(); Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 97d17d0fb14..1f20a75cbb8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -42,6 +42,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; @@ -67,6 +68,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -86,6 +88,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private ReactiveClientRegistrationRepository clientRegistrationRepository; + @Mock + private OAuth2AuthorizedClientResolver oAuth2AuthorizedClientResolver; + @Mock private ServerWebExchange serverWebExchange; @@ -144,6 +149,88 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); } + @Test + public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); + ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); + String clientRegistrationId = registration.getClientId(); + + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver); + + OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "new-token", + Instant.now(), + Instant.now().plus(Duration.ofDays(1))); + OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration, + "principalName", newAccessToken, null); + Request r = new Request(clientRegistrationId, authentication, null); + when(this.oAuth2AuthorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient)); + when(this.oAuth2AuthorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r)); + + when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); + + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + + OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + issuedAt, + accessTokenExpiresAt); + + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, + "principalName", accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); + verify(this.oAuth2AuthorizedClientResolver).clientCredentials(any(), any(), any()); + verify(this.oAuth2AuthorizedClientResolver).createDefaultedRequest(any(), any(), any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + + @Test + public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { + TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); + ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); + + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, + "principalName", this.accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + this.function.filter(request, this.exchange) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + + verify(this.oAuth2AuthorizedClientResolver, never()).clientCredentials(any(), any(), any()); + verify(this.oAuth2AuthorizedClientResolver, never()).createDefaultedRequest(any(), any(), any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + @Test public void filterWhenRefreshRequiredThenRefresh() { when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 1431864f857..e6dfa7a1e99 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -78,6 +78,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; @@ -423,6 +424,80 @@ public void filterWhenRefreshRequiredThenRefresh() { assertThat(getBody(request1)).isEmpty(); } + @Test + public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { + this.registration = TestClientRegistrations.clientCredentials().build(); + + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); + this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); + + verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + + @Test + public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { + this.registration = TestClientRegistrations.clientCredentials().build(); + + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses + .accessTokenResponse().build(); + when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( + accessTokenResponse); + + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); + + this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), + this.accessToken.getTokenValue(), + issuedAt, + accessTokenExpiresAt); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, + this.authorizedClientRepository); + this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, null); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(authentication(this.authentication)) + .build(); + + this.function.filter(request, this.exchange).block(); + + verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); + + verify(clientCredentialsTokenResponseClient).getTokenResponse(any()); + + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + + ClientRequest request1 = requests.get(0); + assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token"); + assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request1.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request1)).isEmpty(); + } + @Test public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() { OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")