Skip to content

Commit 450a20a

Browse files
Warren Baileyjgrandja
Warren Bailey
authored andcommitted
When expired retrieve new Client Credentials token.
Once client credentials access token has expired retrieve a new token from the OAuth2 authorization server. These tokens can't be refreshed because they do not have a refresh token associated with. This is standard behaviour for Oauth 2 client credentails Fixes gh-5893
1 parent f3f84e1 commit 450a20a

File tree

5 files changed

+208
-5
lines changed

5 files changed

+208
-5
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegi
133133
});
134134
}
135135

136-
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
136+
Mono<OAuth2AuthorizedClient> clientCredentials(
137137
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
138138
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
139139
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
8585
private final OAuth2AuthorizedClientResolver authorizedClientResolver;
8686

8787
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
88+
this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository));
89+
}
90+
91+
ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) {
8892
this.authorizedClientRepository = authorizedClientRepository;
89-
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
93+
this.authorizedClientResolver = authorizedClientResolver;
9094
}
9195

9296
/**
@@ -246,13 +250,30 @@ private Mono<OAuth2AuthorizedClientResolver.Request> createRequest(ClientRequest
246250
}
247251

248252
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
249-
if (shouldRefresh(authorizedClient)) {
253+
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
254+
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
255+
return createRequest(request)
256+
.flatMap(r -> authorizeWithClientCredentials(clientRegistration, r));
257+
} else if (shouldRefresh(authorizedClient)) {
250258
return createRequest(request)
251259
.flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
252260
}
253261
return Mono.just(authorizedClient);
254262
}
255263

264+
private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
265+
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
266+
}
267+
268+
private Mono<OAuth2AuthorizedClient> authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) {
269+
Authentication authentication = request.getAuthentication();
270+
ServerWebExchange exchange = request.getExchange();
271+
272+
return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange).
273+
flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
274+
.thenReturn(result));
275+
}
276+
256277
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
257278
OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
258279
ServerWebExchange exchange = r.getExchange();
@@ -285,6 +306,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
285306
if (refreshToken == null) {
286307
return false;
287308
}
309+
return hasTokenExpired(authorizedClient);
310+
}
311+
312+
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
288313
Instant now = this.clock.instant();
289314
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
290315
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,10 @@ private Mono<OAuth2AuthorizedClient> authorizeClient(ClientRequest request) {
412412
throw new ClientAuthorizationRequiredException(clientRegistrationId);
413413
}
414414

415+
private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
416+
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
417+
}
418+
415419
private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration,
416420
Map<String, Object> attrs) {
417421

@@ -439,7 +443,11 @@ private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegi
439443
}
440444

441445
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
442-
if (shouldRefresh(authorizedClient)) {
446+
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
447+
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
448+
//Client credentials grant do not have refresh tokens but can expire so we need to get another one
449+
return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, request.attributes()));
450+
} else if (shouldRefresh(authorizedClient)) {
443451
return refreshAuthorizedClient(request, next, authorizedClient);
444452
}
445453
return Mono.just(authorizedClient);
@@ -484,6 +492,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
484492
if (refreshToken == null) {
485493
return false;
486494
}
495+
return hasTokenExpired(authorizedClient);
496+
}
497+
498+
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
487499
Instant now = this.clock.instant();
488500
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
489501
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.security.oauth2.client.registration.ClientRegistration;
4545
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
4646
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
47+
import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request;
4748
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
4849
import org.springframework.security.oauth2.core.OAuth2AccessToken;
4950
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@@ -69,6 +70,7 @@
6970
import static org.assertj.core.api.Assertions.assertThat;
7071
import static org.mockito.ArgumentMatchers.any;
7172
import static org.mockito.ArgumentMatchers.eq;
73+
import static org.mockito.Mockito.never;
7274
import static org.mockito.Mockito.verify;
7375
import static org.mockito.Mockito.verifyZeroInteractions;
7476
import static org.mockito.Mockito.when;
@@ -88,6 +90,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
8890
@Mock
8991
private ReactiveClientRegistrationRepository clientRegistrationRepository;
9092

93+
@Mock
94+
private OAuth2AuthorizedClientResolver oAuth2AuthorizedClientResolver;
95+
9196
@Mock
9297
private ServerWebExchange serverWebExchange;
9398

@@ -149,6 +154,88 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
149154
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
150155
}
151156

157+
@Test
158+
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
159+
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
160+
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
161+
String clientRegistrationId = registration.getClientId();
162+
163+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);
164+
165+
OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
166+
"new-token",
167+
Instant.now(),
168+
Instant.now().plus(Duration.ofDays(1)));
169+
OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration,
170+
"principalName", newAccessToken, null);
171+
Request r = new Request(clientRegistrationId, authentication, null);
172+
when(this.oAuth2AuthorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient));
173+
when(this.oAuth2AuthorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r));
174+
175+
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
176+
177+
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
178+
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
179+
180+
OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
181+
this.accessToken.getTokenValue(),
182+
issuedAt,
183+
accessTokenExpiresAt);
184+
185+
186+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
187+
"principalName", accessToken, null);
188+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
189+
.attributes(oauth2AuthorizedClient(authorizedClient))
190+
.build();
191+
192+
193+
this.function.filter(request, this.exchange)
194+
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
195+
.block();
196+
197+
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
198+
verify(this.oAuth2AuthorizedClientResolver).clientCredentials(any(), any(), any());
199+
verify(this.oAuth2AuthorizedClientResolver).createDefaultedRequest(any(), any(), any());
200+
201+
List<ClientRequest> requests = this.exchange.getRequests();
202+
assertThat(requests).hasSize(1);
203+
ClientRequest request1 = requests.get(0);
204+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
205+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
206+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
207+
assertThat(getBody(request1)).isEmpty();
208+
}
209+
210+
@Test
211+
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
212+
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
213+
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
214+
215+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);
216+
217+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
218+
"principalName", this.accessToken, null);
219+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
220+
.attributes(oauth2AuthorizedClient(authorizedClient))
221+
.build();
222+
223+
this.function.filter(request, this.exchange)
224+
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
225+
.block();
226+
227+
verify(this.oAuth2AuthorizedClientResolver, never()).clientCredentials(any(), any(), any());
228+
verify(this.oAuth2AuthorizedClientResolver, never()).createDefaultedRequest(any(), any(), any());
229+
230+
List<ClientRequest> requests = this.exchange.getRequests();
231+
assertThat(requests).hasSize(1);
232+
ClientRequest request1 = requests.get(0);
233+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
234+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
235+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
236+
assertThat(getBody(request1)).isEmpty();
237+
}
238+
152239
@Test
153240
public void filterWhenRefreshRequiredThenRefresh() {
154241
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.springframework.security.oauth2.core.OAuth2AccessToken;
5656
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
5757
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
58+
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
5859
import org.springframework.security.oauth2.core.user.OAuth2User;
5960
import org.springframework.web.context.request.RequestContextHolder;
6061
import org.springframework.web.context.request.ServletRequestAttributes;
@@ -80,7 +81,11 @@
8081
import static org.assertj.core.api.Assertions.assertThatCode;
8182
import static org.mockito.ArgumentMatchers.any;
8283
import static org.mockito.ArgumentMatchers.eq;
83-
import static org.mockito.Mockito.*;
84+
import static org.mockito.Mockito.mock;
85+
import static org.mockito.Mockito.never;
86+
import static org.mockito.Mockito.verify;
87+
import static org.mockito.Mockito.verifyZeroInteractions;
88+
import static org.mockito.Mockito.when;
8489
import static org.springframework.http.HttpMethod.GET;
8590
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
8691

@@ -433,6 +438,80 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh
433438
assertThat(getBody(request1)).isEmpty();
434439
}
435440

441+
@Test
442+
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
443+
this.registration = TestClientRegistrations.clientCredentials().build();
444+
445+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
446+
this.authorizedClientRepository);
447+
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
448+
449+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
450+
"principalName", this.accessToken, null);
451+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
452+
.attributes(oauth2AuthorizedClient(authorizedClient))
453+
.attributes(authentication(this.authentication))
454+
.build();
455+
456+
this.function.filter(request, this.exchange).block();
457+
458+
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
459+
460+
verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any());
461+
462+
List<ClientRequest> requests = this.exchange.getRequests();
463+
assertThat(requests).hasSize(1);
464+
465+
ClientRequest request1 = requests.get(0);
466+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
467+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
468+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
469+
assertThat(getBody(request1)).isEmpty();
470+
}
471+
472+
@Test
473+
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
474+
this.registration = TestClientRegistrations.clientCredentials().build();
475+
476+
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
477+
.accessTokenResponse().build();
478+
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
479+
accessTokenResponse);
480+
481+
Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
482+
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));
483+
484+
this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
485+
this.accessToken.getTokenValue(),
486+
issuedAt,
487+
accessTokenExpiresAt);
488+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
489+
this.authorizedClientRepository);
490+
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
491+
492+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
493+
"principalName", this.accessToken, null);
494+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
495+
.attributes(oauth2AuthorizedClient(authorizedClient))
496+
.attributes(authentication(this.authentication))
497+
.build();
498+
499+
this.function.filter(request, this.exchange).block();
500+
501+
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any());
502+
503+
verify(clientCredentialsTokenResponseClient).getTokenResponse(any());
504+
505+
List<ClientRequest> requests = this.exchange.getRequests();
506+
assertThat(requests).hasSize(1);
507+
508+
ClientRequest request1 = requests.get(0);
509+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token");
510+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
511+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
512+
assertThat(getBody(request1)).isEmpty();
513+
}
514+
436515
@Test
437516
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
438517
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")

0 commit comments

Comments
 (0)