20
20
import static org .mockito .BDDMockito .given ;
21
21
import static org .mockito .ArgumentMatchers .any ;
22
22
import static org .mockito .Mockito .mock ;
23
+ import static org .mockito .Mockito .spy ;
23
24
import static org .mockito .Mockito .verify ;
24
25
import static org .mockito .Mockito .verifyZeroInteractions ;
25
26
import static org .mockito .Mockito .when ;
26
27
import static org .springframework .security .config .Customizer .withDefaults ;
28
+ import static org .springframework .test .util .ReflectionTestUtils .getField ;
27
29
28
30
import java .util .Arrays ;
29
31
import java .util .List ;
35
37
import org .junit .Before ;
36
38
import org .junit .Test ;
37
39
import org .junit .runner .RunWith ;
40
+ import org .mockito .ArgumentCaptor ;
38
41
import org .mockito .Mock ;
39
42
import org .mockito .junit .MockitoJUnitRunner ;
40
43
41
44
import org .springframework .security .core .Authentication ;
42
45
import org .springframework .security .oauth2 .client .registration .ReactiveClientRegistrationRepository ;
43
46
import org .springframework .security .oauth2 .client .web .server .ServerAuthorizationRequestRepository ;
47
+ import org .springframework .security .oauth2 .client .web .server .authentication .OAuth2LoginAuthenticationWebFilter ;
44
48
import org .springframework .security .oauth2 .core .endpoint .OAuth2AuthorizationRequest ;
45
49
import org .springframework .security .oauth2 .core .endpoint .TestOAuth2AuthorizationRequests ;
46
50
import org .springframework .security .web .authentication .preauth .x509 .X509PrincipalExtractor ;
47
51
import org .springframework .security .web .server .authentication .ServerX509AuthenticationConverter ;
52
+ import org .springframework .security .web .server .savedrequest .ServerRequestCache ;
53
+ import org .springframework .security .web .server .savedrequest .WebSessionServerRequestCache ;
48
54
import reactor .core .publisher .Mono ;
49
55
import reactor .test .publisher .TestPublisher ;
50
56
64
70
import org .springframework .security .web .server .csrf .CsrfServerLogoutHandler ;
65
71
import org .springframework .security .web .server .csrf .CsrfWebFilter ;
66
72
import org .springframework .security .web .server .csrf .ServerCsrfTokenRepository ;
67
- import org .springframework .test .util .ReflectionTestUtils ;
68
73
import org .springframework .test .web .reactive .server .EntityExchangeResult ;
69
74
import org .springframework .test .web .reactive .server .FluxExchangeResult ;
70
75
import org .springframework .test .web .reactive .server .WebTestClient ;
@@ -200,7 +205,7 @@ public void csrfServerLogoutHandlerNotAppliedIfCsrfIsntEnabled() {
200
205
.isNotPresent ();
201
206
202
207
Optional <ServerLogoutHandler > logoutHandler = getWebFilter (securityWebFilterChain , LogoutWebFilter .class )
203
- .map (logoutWebFilter -> (ServerLogoutHandler ) ReflectionTestUtils . getField (logoutWebFilter , LogoutWebFilter .class , "logoutHandler" ));
208
+ .map (logoutWebFilter -> (ServerLogoutHandler ) getField (logoutWebFilter , LogoutWebFilter .class , "logoutHandler" ));
204
209
205
210
assertThat (logoutHandler )
206
211
.get ()
@@ -213,17 +218,17 @@ public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() {
213
218
214
219
assertThat (getWebFilter (securityWebFilterChain , CsrfWebFilter .class ))
215
220
.get ()
216
- .extracting (csrfWebFilter -> ReflectionTestUtils . getField (csrfWebFilter , "csrfTokenRepository" ))
221
+ .extracting (csrfWebFilter -> getField (csrfWebFilter , "csrfTokenRepository" ))
217
222
.isEqualTo (this .csrfTokenRepository );
218
223
219
224
Optional <ServerLogoutHandler > logoutHandler = getWebFilter (securityWebFilterChain , LogoutWebFilter .class )
220
- .map (logoutWebFilter -> (ServerLogoutHandler ) ReflectionTestUtils . getField (logoutWebFilter , LogoutWebFilter .class , "logoutHandler" ));
225
+ .map (logoutWebFilter -> (ServerLogoutHandler ) getField (logoutWebFilter , LogoutWebFilter .class , "logoutHandler" ));
221
226
222
227
assertThat (logoutHandler )
223
228
.get ()
224
229
.isExactlyInstanceOf (DelegatingServerLogoutHandler .class )
225
230
.extracting (delegatingLogoutHandler ->
226
- ((List <ServerLogoutHandler >) ReflectionTestUtils . getField (delegatingLogoutHandler , DelegatingServerLogoutHandler .class , "delegates" )).stream ()
231
+ ((List <ServerLogoutHandler >) getField (delegatingLogoutHandler , DelegatingServerLogoutHandler .class , "delegates" )).stream ()
227
232
.map (ServerLogoutHandler ::getClass )
228
233
.collect (Collectors .toList ()))
229
234
.isEqualTo (Arrays .asList (SecurityContextServerLogoutHandler .class , CsrfServerLogoutHandler .class ));
@@ -479,6 +484,33 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
479
484
verify (customServerCsrfTokenRepository ).loadToken (any ());
480
485
}
481
486
487
+ @ Test
488
+ public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler () {
489
+ ServerRequestCache requestCache = spy (new WebSessionServerRequestCache ());
490
+ ReactiveClientRegistrationRepository clientRegistrationRepository = mock (ReactiveClientRegistrationRepository .class );
491
+
492
+ SecurityWebFilterChain securityFilterChain = this .http
493
+ .oauth2Login ()
494
+ .clientRegistrationRepository (clientRegistrationRepository )
495
+ .and ()
496
+ .authorizeExchange ().anyExchange ().authenticated ()
497
+ .and ()
498
+ .requestCache (c -> c .requestCache (requestCache ))
499
+ .build ();
500
+
501
+ WebTestClient client = WebTestClientBuilder .bindToWebFilters (securityFilterChain ).build ();
502
+ client .get ().uri ("/test" ).exchange ();
503
+ ArgumentCaptor <ServerWebExchange > captor = ArgumentCaptor .forClass (ServerWebExchange .class );
504
+ verify (requestCache ).saveRequest (captor .capture ());
505
+ assertThat (captor .getValue ().getRequest ().getURI ().toString ()).isEqualTo ("/test" );
506
+
507
+
508
+ OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
509
+ getWebFilter (securityFilterChain , OAuth2LoginAuthenticationWebFilter .class ).get ();
510
+ Object handler = getField (authenticationWebFilter , "authenticationSuccessHandler" );
511
+ assertThat (getField (handler , "requestCache" )).isSameAs (requestCache );
512
+ }
513
+
482
514
@ Test
483
515
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login () {
484
516
ServerAuthorizationRequestRepository <OAuth2AuthorizationRequest > authorizationRequestRepository = mock (ServerAuthorizationRequestRepository .class );
@@ -503,7 +535,7 @@ public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
503
535
504
536
private boolean isX509Filter (WebFilter filter ) {
505
537
try {
506
- Object converter = ReflectionTestUtils . getField (filter , "authenticationConverter" );
538
+ Object converter = getField (filter , "authenticationConverter" );
507
539
return converter .getClass ().isAssignableFrom (ServerX509AuthenticationConverter .class );
508
540
} catch (IllegalArgumentException e ) {
509
541
// field doesn't exist
0 commit comments