Skip to content

Commit b754a3d

Browse files
committed
Use the custom ServerRequestCache that the user configures
on for the default authentication entry point and authentication success handler Fixes gh-7721 #7721 Set RequestCache on the Oauth2LoginSpec default authentication success handler import static ReflectionTestUtils.getField Feedback incorporated per #7734 (review)
1 parent 0d24e2b commit b754a3d

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@
7676
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
7777
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
7878
import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
79+
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
7980
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
8081
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
8182
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
83+
import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
8284
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
8385
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
8486
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
@@ -984,7 +986,7 @@ public class OAuth2LoginSpec {
984986

985987
private ServerWebExchangeMatcher authenticationMatcher;
986988

987-
private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
989+
private ServerAuthenticationSuccessHandler authenticationSuccessHandler;
988990

989991
private ServerAuthenticationFailureHandler authenticationFailureHandler;
990992

@@ -1175,24 +1177,37 @@ protected void configure(ServerHttpSecurity http) {
11751177
authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
11761178
authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
11771179

1178-
authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
1180+
authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http));
11791181
authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
11801182
authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
11811183

11821184
MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher(
11831185
MediaType.TEXT_HTML);
11841186
htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
11851187
Map<String, String> urlToText = http.oauth2Login.getLinks();
1188+
String authenticationEntryPointRedirectPath;
11861189
if (urlToText.size() == 1) {
1187-
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next())));
1190+
authenticationEntryPointRedirectPath = urlToText.keySet().iterator().next();
11881191
} else {
1189-
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login")));
1192+
authenticationEntryPointRedirectPath = "/login";
11901193
}
1194+
RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint(authenticationEntryPointRedirectPath);
1195+
entryPoint.setRequestCache(http.requestCache.requestCache);
1196+
http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, entryPoint));
11911197

11921198
http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
11931199
http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
11941200
}
11951201

1202+
private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) {
1203+
if (this.authenticationSuccessHandler == null) {
1204+
RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler();
1205+
handler.setRequestCache(http.requestCache.requestCache);
1206+
this.authenticationSuccessHandler = handler;
1207+
}
1208+
return this.authenticationSuccessHandler;
1209+
}
1210+
11961211
private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
11971212
if (this.authenticationFailureHandler == null) {
11981213
this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");

config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
import static org.mockito.BDDMockito.given;
2121
import static org.mockito.ArgumentMatchers.any;
2222
import static org.mockito.Mockito.mock;
23+
import static org.mockito.Mockito.spy;
2324
import static org.mockito.Mockito.verify;
2425
import static org.mockito.Mockito.verifyZeroInteractions;
2526
import static org.mockito.Mockito.when;
2627
import static org.springframework.security.config.Customizer.withDefaults;
28+
import static org.springframework.test.util.ReflectionTestUtils.getField;
2729

2830
import java.util.Arrays;
2931
import java.util.List;
@@ -35,16 +37,20 @@
3537
import org.junit.Before;
3638
import org.junit.Test;
3739
import org.junit.runner.RunWith;
40+
import org.mockito.ArgumentCaptor;
3841
import org.mockito.Mock;
3942
import org.mockito.junit.MockitoJUnitRunner;
4043

4144
import org.springframework.security.core.Authentication;
4245
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
4346
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
47+
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
4448
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
4549
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
4650
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
4751
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
52+
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
53+
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
4854
import reactor.core.publisher.Mono;
4955
import reactor.test.publisher.TestPublisher;
5056

@@ -64,7 +70,6 @@
6470
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
6571
import org.springframework.security.web.server.csrf.CsrfWebFilter;
6672
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
67-
import org.springframework.test.util.ReflectionTestUtils;
6873
import org.springframework.test.web.reactive.server.EntityExchangeResult;
6974
import org.springframework.test.web.reactive.server.FluxExchangeResult;
7075
import org.springframework.test.web.reactive.server.WebTestClient;
@@ -200,7 +205,7 @@ public void csrfServerLogoutHandlerNotAppliedIfCsrfIsntEnabled() {
200205
.isNotPresent();
201206

202207
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
203-
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
208+
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
204209

205210
assertThat(logoutHandler)
206211
.get()
@@ -213,17 +218,17 @@ public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() {
213218

214219
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
215220
.get()
216-
.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
221+
.extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository"))
217222
.isEqualTo(this.csrfTokenRepository);
218223

219224
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
220-
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
225+
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
221226

222227
assertThat(logoutHandler)
223228
.get()
224229
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
225230
.extracting(delegatingLogoutHandler ->
226-
((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
231+
((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
227232
.map(ServerLogoutHandler::getClass)
228233
.collect(Collectors.toList()))
229234
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
@@ -479,6 +484,33 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
479484
verify(customServerCsrfTokenRepository).loadToken(any());
480485
}
481486

487+
@Test
488+
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
489+
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
490+
ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
491+
492+
SecurityWebFilterChain securityFilterChain = this.http
493+
.oauth2Login()
494+
.clientRegistrationRepository(clientRegistrationRepository)
495+
.and()
496+
.authorizeExchange().anyExchange().authenticated()
497+
.and()
498+
.requestCache(c -> c.requestCache(requestCache))
499+
.build();
500+
501+
WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build();
502+
client.get().uri("/test").exchange();
503+
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
504+
verify(requestCache).saveRequest(captor.capture());
505+
assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test");
506+
507+
508+
OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
509+
getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get();
510+
Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler");
511+
assertThat(getField(handler, "requestCache")).isSameAs(requestCache);
512+
}
513+
482514
@Test
483515
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
484516
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
@@ -503,7 +535,7 @@ public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
503535

504536
private boolean isX509Filter(WebFilter filter) {
505537
try {
506-
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
538+
Object converter = getField(filter, "authenticationConverter");
507539
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
508540
} catch (IllegalArgumentException e) {
509541
// field doesn't exist

0 commit comments

Comments
 (0)