1
1
/*
2
- * Copyright 2002-2023 the original author or authors.
2
+ * Copyright 2002-2025 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
16
16
17
17
package org .springframework .security .oauth2 .jwt ;
18
18
19
+ import com .nimbusds .jose .KeySourceException ;
20
+ import com .nimbusds .jose .jwk .JWK ;
21
+ import com .nimbusds .jose .jwk .JWKMatcher ;
22
+ import com .nimbusds .jose .jwk .JWKSelector ;
23
+ import com .nimbusds .jose .jwk .source .JWKSetParseException ;
24
+ import com .nimbusds .jose .jwk .source .JWKSetRetrievalException ;
19
25
import java .io .IOException ;
20
26
import java .net .MalformedURLException ;
21
27
import java .net .URL ;
26
32
import java .util .Collections ;
27
33
import java .util .HashSet ;
28
34
import java .util .LinkedHashMap ;
35
+ import java .util .List ;
29
36
import java .util .Map ;
30
37
import java .util .Set ;
38
+ import java .util .concurrent .locks .ReentrantLock ;
31
39
import java .util .function .Consumer ;
32
40
import java .util .function .Function ;
33
41
34
42
import javax .crypto .SecretKey ;
35
43
36
44
import com .nimbusds .jose .JOSEException ;
37
45
import com .nimbusds .jose .JWSAlgorithm ;
38
- import com .nimbusds .jose .RemoteKeySourceException ;
39
46
import com .nimbusds .jose .jwk .JWKSet ;
40
- import com .nimbusds .jose .jwk .source .JWKSetCache ;
41
47
import com .nimbusds .jose .jwk .source .JWKSource ;
42
- import com .nimbusds .jose .jwk .source .RemoteJWKSet ;
43
48
import com .nimbusds .jose .proc .JWSKeySelector ;
44
49
import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
45
50
import com .nimbusds .jose .proc .SecurityContext ;
46
51
import com .nimbusds .jose .proc .SingleKeyJWSKeySelector ;
47
- import com .nimbusds .jose .util .Resource ;
48
- import com .nimbusds .jose .util .ResourceRetriever ;
49
52
import com .nimbusds .jwt .JWT ;
50
53
import com .nimbusds .jwt .JWTClaimsSet ;
51
54
import com .nimbusds .jwt .JWTParser ;
57
60
import org .apache .commons .logging .LogFactory ;
58
61
59
62
import org .springframework .cache .Cache ;
63
+ import org .springframework .cache .support .NoOpCache ;
60
64
import org .springframework .core .convert .converter .Converter ;
61
65
import org .springframework .http .HttpHeaders ;
62
66
import org .springframework .http .HttpMethod ;
80
84
* @author Josh Cummings
81
85
* @author Joe Grandja
82
86
* @author Mykyta Bezverkhyi
87
+ * @author Daeho Kwon
83
88
* @since 5.2
84
89
*/
85
90
public final class NimbusJwtDecoder implements JwtDecoder {
@@ -165,7 +170,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
165
170
.build ();
166
171
// @formatter:on
167
172
}
168
- catch (RemoteKeySourceException ex ) {
173
+ catch (KeySourceException ex ) {
169
174
this .logger .trace ("Failed to retrieve JWK set" , ex );
170
175
if (ex .getCause () instanceof ParseException ) {
171
176
throw new JwtException (String .format (DECODING_ERROR_MESSAGE_TEMPLATE , "Malformed Jwk set" ), ex );
@@ -273,7 +278,7 @@ public static final class JwkSetUriJwtDecoderBuilder {
273
278
274
279
private RestOperations restOperations = new RestTemplate ();
275
280
276
- private Cache cache ;
281
+ private Cache cache = new NoOpCache ( "default" ) ;
277
282
278
283
private Consumer <ConfigurableJWTProcessor <SecurityContext >> jwtProcessorCustomizer ;
279
284
@@ -376,18 +381,13 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
376
381
return new JWSVerificationKeySelector <>(jwsAlgorithms , jwkSource );
377
382
}
378
383
379
- JWKSource <SecurityContext > jwkSource (ResourceRetriever jwkSetRetriever , String jwkSetUri ) {
380
- if (this .cache == null ) {
381
- return new RemoteJWKSet <>(toURL (jwkSetUri ), jwkSetRetriever );
382
- }
383
- JWKSetCache jwkSetCache = new SpringJWKSetCache (jwkSetUri , this .cache );
384
- return new RemoteJWKSet <>(toURL (jwkSetUri ), jwkSetRetriever , jwkSetCache );
384
+ JWKSource <SecurityContext > jwkSource () {
385
+ String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
386
+ return new SpringJWKSource <>(this .restOperations , this .cache , toURL (jwkSetUri ), jwkSetUri );
385
387
}
386
388
387
389
JWTProcessor <SecurityContext > processor () {
388
- ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever (this .restOperations );
389
- String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
390
- JWKSource <SecurityContext > jwkSource = jwkSource (jwkSetRetriever , jwkSetUri );
390
+ JWKSource <SecurityContext > jwkSource = jwkSource ();
391
391
ConfigurableJWTProcessor <SecurityContext > jwtProcessor = new DefaultJWTProcessor <>();
392
392
jwtProcessor .setJWSKeySelector (jwsKeySelector (jwkSource ));
393
393
// Spring Security validates the claim set independent from Nimbus
@@ -414,84 +414,130 @@ private static URL toURL(String url) {
414
414
}
415
415
}
416
416
417
- private static final class SpringJWKSetCache implements JWKSetCache {
417
+ private static final class SpringJWKSource < C extends SecurityContext > implements JWKSource < C > {
418
418
419
- private final String jwkSetUri ;
419
+ private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
420
+
421
+ private final ReentrantLock reentrantLock = new ReentrantLock ();
422
+
423
+ private final RestOperations restOperations ;
420
424
421
425
private final Cache cache ;
422
426
423
- private JWKSet jwkSet ;
427
+ private final URL url ;
424
428
425
- SpringJWKSetCache (String jwkSetUri , Cache cache ) {
426
- this .jwkSetUri = jwkSetUri ;
429
+ private final String jwkSetUri ;
430
+
431
+ private SpringJWKSource (RestOperations restOperations , Cache cache , URL url , String jwkSetUri ) {
432
+ Assert .notNull (restOperations , "restOperations cannot be null" );
433
+ this .restOperations = restOperations ;
427
434
this .cache = cache ;
428
- this .updateJwkSetFromCache ();
435
+ this .url = url ;
436
+ this .jwkSetUri = jwkSetUri ;
429
437
}
430
438
431
- private void updateJwkSetFromCache () {
439
+
440
+ @ Override
441
+ public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
432
442
String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
443
+ JWKSet jwkSet = null ;
433
444
if (cachedJwkSet != null ) {
434
- try {
435
- this .jwkSet = JWKSet .parse (cachedJwkSet );
436
- }
437
- catch (ParseException ignored ) {
438
- // Ignore invalid cache value
445
+ jwkSet = parse (cachedJwkSet );
446
+ }
447
+ if (jwkSet == null ) {
448
+ if (reentrantLock .tryLock ()) {
449
+ try {
450
+ String cachedJwkSetAfterLock = this .cache .get (this .jwkSetUri , String .class );
451
+ if (cachedJwkSetAfterLock != null ) {
452
+ jwkSet = parse (cachedJwkSetAfterLock );
453
+ }
454
+ if (jwkSet == null ) {
455
+ try {
456
+ jwkSet = fetchJWKSet ();
457
+ } catch (IOException e ) {
458
+ throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
459
+ }
460
+ }
461
+ } finally {
462
+ reentrantLock .unlock ();
463
+ }
439
464
}
440
465
}
441
- }
442
-
443
- // Note: Only called from inside a synchronized block in RemoteJWKSet.
444
- @ Override
445
- public void put (JWKSet jwkSet ) {
446
- this .jwkSet = jwkSet ;
447
- this .cache .put (this .jwkSetUri , jwkSet .toString (false ));
448
- }
449
-
450
- @ Override
451
- public JWKSet get () {
452
- return (!requiresRefresh ()) ? this .jwkSet : null ;
453
-
454
- }
455
-
456
- @ Override
457
- public boolean requiresRefresh () {
458
- return this .cache .get (this .jwkSetUri ) == null ;
459
- }
460
-
461
- }
462
-
463
- private static class RestOperationsResourceRetriever implements ResourceRetriever {
464
-
465
- private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
466
-
467
- private final RestOperations restOperations ;
466
+ List <JWK > matches = jwkSelector .select (jwkSet );
467
+ if (!matches .isEmpty ()) {
468
+ return matches ;
469
+ }
470
+ String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
471
+ if (soughtKeyID == null ) {
472
+ return Collections .emptyList ();
473
+ }
474
+ if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
475
+ return Collections .emptyList ();
476
+ }
468
477
469
- RestOperationsResourceRetriever (RestOperations restOperations ) {
470
- Assert .notNull (restOperations , "restOperations cannot be null" );
471
- this .restOperations = restOperations ;
478
+ if (reentrantLock .tryLock ()) {
479
+ try {
480
+ String jwkSetUri = this .cache .get (this .jwkSetUri , String .class );
481
+ JWKSet cacheJwkSet = parse (jwkSetUri );
482
+ if (jwkSetUri != null && cacheJwkSet .toString ().equals (jwkSet .toString ())) {
483
+ try {
484
+ jwkSet = fetchJWKSet ();
485
+ } catch (IOException e ) {
486
+ throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
487
+ }
488
+ } else if (jwkSetUri != null ) {
489
+ jwkSet = parse (jwkSetUri );
490
+ }
491
+ } finally {
492
+ reentrantLock .unlock ();
493
+ }
494
+ }
495
+ if (jwkSet == null ) {
496
+ return Collections .emptyList ();
497
+ }
498
+ return jwkSelector .select (jwkSet );
472
499
}
473
500
474
- @ Override
475
- public Resource retrieveResource (URL url ) throws IOException {
501
+ private JWKSet fetchJWKSet () throws IOException , KeySourceException {
476
502
HttpHeaders headers = new HttpHeaders ();
477
503
headers .setAccept (Arrays .asList (MediaType .APPLICATION_JSON , APPLICATION_JWK_SET_JSON ));
478
- ResponseEntity <String > response = getResponse (url , headers );
504
+ ResponseEntity <String > response = getResponse (headers );
479
505
if (response .getStatusCode ().value () != 200 ) {
480
506
throw new IOException (response .toString ());
481
507
}
482
- return new Resource (response .getBody (), "UTF-8" );
508
+ try {
509
+ String jwkSet = response .getBody ();
510
+ this .cache .put (this .jwkSetUri , jwkSet );
511
+ return JWKSet .parse (jwkSet );
512
+ } catch (ParseException e ) {
513
+ throw new JWKSetParseException ("Unable to parse JWK set" , e );
514
+ }
483
515
}
484
516
485
- private ResponseEntity <String > getResponse (URL url , HttpHeaders headers ) throws IOException {
517
+ private ResponseEntity <String > getResponse (HttpHeaders headers ) throws IOException {
486
518
try {
487
- RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , url .toURI ());
519
+ RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , this . url .toURI ());
488
520
return this .restOperations .exchange (request , String .class );
489
- }
490
- catch (Exception ex ) {
521
+ } catch (Exception ex ) {
491
522
throw new IOException (ex );
492
523
}
493
524
}
494
525
526
+ private JWKSet parse (String cachedJwkSet ) {
527
+ JWKSet jwkSet = null ;
528
+ try {
529
+ jwkSet = JWKSet .parse (cachedJwkSet );
530
+ } catch (ParseException ignored ) {
531
+ // Ignore invalid cache value
532
+ }
533
+ return jwkSet ;
534
+ }
535
+
536
+ private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
537
+ Set <String > keyIDs = jwkMatcher .getKeyIDs ();
538
+ return (keyIDs == null || keyIDs .isEmpty ()) ?
539
+ null : keyIDs .stream ().filter (id -> id != null ).findFirst ().orElse (null );
540
+ }
495
541
}
496
542
497
543
}
0 commit comments