|
39 | 39 | import org.mockito.junit.MockitoJUnitRunner;
|
40 | 40 |
|
41 | 41 | import org.springframework.security.core.Authentication;
|
| 42 | +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; |
| 43 | +import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; |
| 44 | +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; |
| 45 | +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; |
42 | 46 | import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
|
43 | 47 | import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
|
44 | 48 | import reactor.core.publisher.Mono;
|
@@ -475,6 +479,28 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
|
475 | 479 | verify(customServerCsrfTokenRepository).loadToken(any());
|
476 | 480 | }
|
477 | 481 |
|
| 482 | + @Test |
| 483 | + public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() { |
| 484 | + ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); |
| 485 | + ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); |
| 486 | + |
| 487 | + OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request().build(); |
| 488 | + |
| 489 | + when(authorizationRequestRepository.removeAuthorizationRequest(any())).thenReturn(Mono.just(authorizationRequest)); |
| 490 | + |
| 491 | + SecurityWebFilterChain securityFilterChain = this.http |
| 492 | + .oauth2Login() |
| 493 | + .clientRegistrationRepository(clientRegistrationRepository) |
| 494 | + .authorizationRequestRepository(authorizationRequestRepository) |
| 495 | + .and() |
| 496 | + .build(); |
| 497 | + |
| 498 | + WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); |
| 499 | + client.get().uri("/login/oauth2/code/registration-id").exchange(); |
| 500 | + |
| 501 | + verify(authorizationRequestRepository).removeAuthorizationRequest(any()); |
| 502 | + } |
| 503 | + |
478 | 504 | private boolean isX509Filter(WebFilter filter) {
|
479 | 505 | try {
|
480 | 506 | Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
|
|
0 commit comments