From f4a5d6ec33709518801cbc87c828096ea641cd7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Veli=20D=C3=B6ngelci?= <4553816+dongelci@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:23:30 +0200 Subject: [PATCH] Fix caching error state in ReactiveRemoteJWKSource Closes gh-13757 --- .../oauth2/jwt/ReactiveRemoteJWKSource.java | 30 ++++++++++++------- .../jwt/ReactiveRemoteJWKSourceTests.java | 23 ++++++++++++++ 2 files changed, 43 insertions(+), 10 deletions(-) diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java index 24bc7321137..848094274c1 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSource.java @@ -43,18 +43,23 @@ class ReactiveRemoteJWKSource implements ReactiveJWKSource { */ private final AtomicReference> cachedJWKSet = new AtomicReference<>(Mono.empty()); + /** + * cached url for jwk set. + */ + private final AtomicReference cachedJwkSetUrl = new AtomicReference<>(); + private WebClient webClient = WebClient.create(); - private final Mono jwkSetURL; + private Mono jwkSetURLProvider; ReactiveRemoteJWKSource(String jwkSetURL) { Assert.hasText(jwkSetURL, "jwkSetURL cannot be empty"); - this.jwkSetURL = Mono.just(jwkSetURL); + this.cachedJwkSetUrl.set(jwkSetURL); } - ReactiveRemoteJWKSource(Mono jwkSetURL) { - Assert.notNull(jwkSetURL, "jwkSetURL cannot be null"); - this.jwkSetURL = jwkSetURL.cache(); + ReactiveRemoteJWKSource(Mono jwkSetURLProvider) { + Assert.notNull(jwkSetURLProvider, "jwkSetURLProvider cannot be null"); + this.jwkSetURLProvider = jwkSetURLProvider; } @Override @@ -100,13 +105,18 @@ private Mono> get(JWKSelector jwkSelector, JWKSet jwkSet) { */ private Mono getJWKSet() { // @formatter:off - return this.jwkSetURL.flatMap((jwkSetURL) -> this.webClient.get() - .uri(jwkSetURL) - .retrieve() - .bodyToMono(String.class)) + return Mono.justOrEmpty(this.cachedJwkSetUrl.get()) + .switchIfEmpty(Mono.defer(() -> this.jwkSetURLProvider + .doOnNext(this.cachedJwkSetUrl::set)) + ) + .flatMap((jwkSetURL) -> this.webClient.get() + .uri(jwkSetURL) + .retrieve() + .bodyToMono(String.class) + ) .map(this::parse) .doOnNext((jwkSet) -> this.cachedJWKSet - .set(Mono.just(jwkSet)) + .set(Mono.just(jwkSet)) ) .cache(); // @formatter:on diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java index ddcc1c913f3..2f0881fc228 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/ReactiveRemoteJWKSourceTests.java @@ -18,6 +18,7 @@ import java.util.Collections; import java.util.List; +import java.util.function.Supplier; import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWKMatcher; @@ -31,10 +32,15 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; /** * @author Rob Winch @@ -52,6 +58,9 @@ public class ReactiveRemoteJWKSourceTests { private MockWebServer server; + @Mock + private Supplier mockStringSupplier; + // @formatter:off private String keys = "{\n" + " \"keys\": [\n" @@ -156,4 +165,18 @@ public void getWhenNoMatchAndKeyIdMatchThenEmpty() { assertThat(this.source.get(this.selector).block()).isEmpty(); } + @Test + public void getShouldRecoverAndReturnKeysAfterErrorCase() { + given(this.matcher.matches(any())).willReturn(true); + this.source = new ReactiveRemoteJWKSource(Mono.fromSupplier(mockStringSupplier)); + doThrow(WebClientResponseException.ServiceUnavailable.class).when(this.mockStringSupplier).get(); + // first case: id provider has error state + assertThatThrownBy(() -> this.source.get(this.selector).block()) + .isExactlyInstanceOf(WebClientResponseException.ServiceUnavailable.class); + // second case: id provider is healthy again + doReturn(this.server.url("/").toString()).when(this.mockStringSupplier).get(); + var actual = this.source.get(this.selector).block(); + assertThat(actual).isNotEmpty(); + } + }