Skip to content

Commit dcc88ec

Browse files
kwondh5217jzheaux
authored andcommitted
Remove Deprecated Usages of RemoteJWKSet
Closes spring-projectsgh-16251 Signed-off-by: Daeho Kwon <[email protected]>
1 parent 083b74c commit dcc88ec

File tree

3 files changed

+116
-69
lines changed

3 files changed

+116
-69
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

Lines changed: 113 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,12 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19+
import com.nimbusds.jose.KeySourceException;
20+
import com.nimbusds.jose.jwk.JWK;
21+
import com.nimbusds.jose.jwk.JWKMatcher;
22+
import com.nimbusds.jose.jwk.JWKSelector;
23+
import com.nimbusds.jose.jwk.source.JWKSetParseException;
24+
import com.nimbusds.jose.jwk.source.JWKSetRetrievalException;
1925
import java.io.IOException;
2026
import java.net.MalformedURLException;
2127
import java.net.URL;
@@ -26,26 +32,23 @@
2632
import java.util.Collections;
2733
import java.util.HashSet;
2834
import java.util.LinkedHashMap;
35+
import java.util.List;
2936
import java.util.Map;
3037
import java.util.Set;
38+
import java.util.concurrent.locks.ReentrantLock;
3139
import java.util.function.Consumer;
3240
import java.util.function.Function;
3341

3442
import javax.crypto.SecretKey;
3543

3644
import com.nimbusds.jose.JOSEException;
3745
import com.nimbusds.jose.JWSAlgorithm;
38-
import com.nimbusds.jose.RemoteKeySourceException;
3946
import com.nimbusds.jose.jwk.JWKSet;
40-
import com.nimbusds.jose.jwk.source.JWKSetCache;
4147
import com.nimbusds.jose.jwk.source.JWKSource;
42-
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
4348
import com.nimbusds.jose.proc.JWSKeySelector;
4449
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
4550
import com.nimbusds.jose.proc.SecurityContext;
4651
import com.nimbusds.jose.proc.SingleKeyJWSKeySelector;
47-
import com.nimbusds.jose.util.Resource;
48-
import com.nimbusds.jose.util.ResourceRetriever;
4952
import com.nimbusds.jwt.JWT;
5053
import com.nimbusds.jwt.JWTClaimsSet;
5154
import com.nimbusds.jwt.JWTParser;
@@ -57,6 +60,7 @@
5760
import org.apache.commons.logging.LogFactory;
5861

5962
import org.springframework.cache.Cache;
63+
import org.springframework.cache.support.NoOpCache;
6064
import org.springframework.core.convert.converter.Converter;
6165
import org.springframework.http.HttpHeaders;
6266
import org.springframework.http.HttpMethod;
@@ -80,6 +84,7 @@
8084
* @author Josh Cummings
8185
* @author Joe Grandja
8286
* @author Mykyta Bezverkhyi
87+
* @author Daeho Kwon
8388
* @since 5.2
8489
*/
8590
public final class NimbusJwtDecoder implements JwtDecoder {
@@ -165,7 +170,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
165170
.build();
166171
// @formatter:on
167172
}
168-
catch (RemoteKeySourceException ex) {
173+
catch (KeySourceException ex) {
169174
this.logger.trace("Failed to retrieve JWK set", ex);
170175
if (ex.getCause() instanceof ParseException) {
171176
throw new JwtException(String.format(DECODING_ERROR_MESSAGE_TEMPLATE, "Malformed Jwk set"), ex);
@@ -273,7 +278,7 @@ public static final class JwkSetUriJwtDecoderBuilder {
273278

274279
private RestOperations restOperations = new RestTemplate();
275280

276-
private Cache cache;
281+
private Cache cache = new NoOpCache("default");
277282

278283
private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer;
279284

@@ -376,18 +381,13 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
376381
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
377382
}
378383

379-
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) {
380-
if (this.cache == null) {
381-
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever);
382-
}
383-
JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache);
384-
return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache);
384+
JWKSource<SecurityContext> jwkSource() {
385+
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
386+
return new SpringJWKSource<>(this.restOperations, this.cache, toURL(jwkSetUri), jwkSetUri);
385387
}
386388

387389
JWTProcessor<SecurityContext> processor() {
388-
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
389-
String jwkSetUri = this.jwkSetUri.apply(this.restOperations);
390-
JWKSource<SecurityContext> jwkSource = jwkSource(jwkSetRetriever, jwkSetUri);
390+
JWKSource<SecurityContext> jwkSource = jwkSource();
391391
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
392392
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
393393
// Spring Security validates the claim set independent from Nimbus
@@ -414,84 +414,130 @@ private static URL toURL(String url) {
414414
}
415415
}
416416

417-
private static final class SpringJWKSetCache implements JWKSetCache {
417+
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSource<C> {
418418

419-
private final String jwkSetUri;
419+
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
420+
421+
private final ReentrantLock reentrantLock = new ReentrantLock();
422+
423+
private final RestOperations restOperations;
420424

421425
private final Cache cache;
422426

423-
private JWKSet jwkSet;
427+
private final URL url;
424428

425-
SpringJWKSetCache(String jwkSetUri, Cache cache) {
426-
this.jwkSetUri = jwkSetUri;
429+
private final String jwkSetUri;
430+
431+
private SpringJWKSource(RestOperations restOperations, Cache cache, URL url, String jwkSetUri) {
432+
Assert.notNull(restOperations, "restOperations cannot be null");
433+
this.restOperations = restOperations;
427434
this.cache = cache;
428-
this.updateJwkSetFromCache();
435+
this.url = url;
436+
this.jwkSetUri = jwkSetUri;
429437
}
430438

431-
private void updateJwkSetFromCache() {
439+
440+
@Override
441+
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) throws KeySourceException {
432442
String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class);
443+
JWKSet jwkSet = null;
433444
if (cachedJwkSet != null) {
434-
try {
435-
this.jwkSet = JWKSet.parse(cachedJwkSet);
436-
}
437-
catch (ParseException ignored) {
438-
// Ignore invalid cache value
445+
jwkSet = parse(cachedJwkSet);
446+
}
447+
if (jwkSet == null) {
448+
if(reentrantLock.tryLock()) {
449+
try {
450+
String cachedJwkSetAfterLock = this.cache.get(this.jwkSetUri, String.class);
451+
if (cachedJwkSetAfterLock != null) {
452+
jwkSet = parse(cachedJwkSetAfterLock);
453+
}
454+
if(jwkSet == null) {
455+
try {
456+
jwkSet = fetchJWKSet();
457+
} catch (IOException e) {
458+
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
459+
}
460+
}
461+
} finally {
462+
reentrantLock.unlock();
463+
}
439464
}
440465
}
441-
}
442-
443-
// Note: Only called from inside a synchronized block in RemoteJWKSet.
444-
@Override
445-
public void put(JWKSet jwkSet) {
446-
this.jwkSet = jwkSet;
447-
this.cache.put(this.jwkSetUri, jwkSet.toString(false));
448-
}
449-
450-
@Override
451-
public JWKSet get() {
452-
return (!requiresRefresh()) ? this.jwkSet : null;
453-
454-
}
455-
456-
@Override
457-
public boolean requiresRefresh() {
458-
return this.cache.get(this.jwkSetUri) == null;
459-
}
460-
461-
}
462-
463-
private static class RestOperationsResourceRetriever implements ResourceRetriever {
464-
465-
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
466-
467-
private final RestOperations restOperations;
466+
List<JWK> matches = jwkSelector.select(jwkSet);
467+
if(!matches.isEmpty()) {
468+
return matches;
469+
}
470+
String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
471+
if (soughtKeyID == null) {
472+
return Collections.emptyList();
473+
}
474+
if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
475+
return Collections.emptyList();
476+
}
468477

469-
RestOperationsResourceRetriever(RestOperations restOperations) {
470-
Assert.notNull(restOperations, "restOperations cannot be null");
471-
this.restOperations = restOperations;
478+
if(reentrantLock.tryLock()) {
479+
try {
480+
String jwkSetUri = this.cache.get(this.jwkSetUri, String.class);
481+
JWKSet cacheJwkSet = parse(jwkSetUri);
482+
if(jwkSetUri != null && cacheJwkSet.toString().equals(jwkSet.toString())) {
483+
try {
484+
jwkSet = fetchJWKSet();
485+
} catch (IOException e) {
486+
throw new JWKSetRetrievalException("Couldn't retrieve JWK set from URL: " + e.getMessage(), e);
487+
}
488+
} else if (jwkSetUri != null) {
489+
jwkSet = parse(jwkSetUri);
490+
}
491+
} finally {
492+
reentrantLock.unlock();
493+
}
494+
}
495+
if(jwkSet == null) {
496+
return Collections.emptyList();
497+
}
498+
return jwkSelector.select(jwkSet);
472499
}
473500

474-
@Override
475-
public Resource retrieveResource(URL url) throws IOException {
501+
private JWKSet fetchJWKSet() throws IOException, KeySourceException {
476502
HttpHeaders headers = new HttpHeaders();
477503
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
478-
ResponseEntity<String> response = getResponse(url, headers);
504+
ResponseEntity<String> response = getResponse(headers);
479505
if (response.getStatusCode().value() != 200) {
480506
throw new IOException(response.toString());
481507
}
482-
return new Resource(response.getBody(), "UTF-8");
508+
try {
509+
String jwkSet = response.getBody();
510+
this.cache.put(this.jwkSetUri, jwkSet);
511+
return JWKSet.parse(jwkSet);
512+
} catch (ParseException e) {
513+
throw new JWKSetParseException("Unable to parse JWK set", e);
514+
}
483515
}
484516

485-
private ResponseEntity<String> getResponse(URL url, HttpHeaders headers) throws IOException {
517+
private ResponseEntity<String> getResponse(HttpHeaders headers) throws IOException {
486518
try {
487-
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI());
519+
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, this.url.toURI());
488520
return this.restOperations.exchange(request, String.class);
489-
}
490-
catch (Exception ex) {
521+
} catch (Exception ex) {
491522
throw new IOException(ex);
492523
}
493524
}
494525

526+
private JWKSet parse(String cachedJwkSet) {
527+
JWKSet jwkSet = null;
528+
try {
529+
jwkSet = JWKSet.parse(cachedJwkSet);
530+
} catch (ParseException ignored) {
531+
// Ignore invalid cache value
532+
}
533+
return jwkSet;
534+
}
535+
536+
private String getFirstSpecifiedKeyID(JWKMatcher jwkMatcher) {
537+
Set<String> keyIDs = jwkMatcher.getKeyIDs();
538+
return (keyIDs == null || keyIDs.isEmpty()) ?
539+
null : keyIDs.stream().filter(id -> id != null).findFirst().orElse(null);
540+
}
495541
}
496542

497543
}

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -308,6 +308,7 @@ private void prepareConfigurationResponse() {
308308
private void prepareConfigurationResponse(String body) {
309309
this.server.enqueue(response(body));
310310
this.server.enqueue(response(JWK_SET));
311+
this.server.enqueue(response(JWK_SET)); // default NoOpCache
311312
}
312313

313314
private void prepareConfigurationResponseOidc() {

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.

0 commit comments

Comments
 (0)