40
40
import com .nimbusds .jwt .proc .JWTProcessor ;
41
41
import reactor .core .publisher .Mono ;
42
42
43
+ import org .springframework .security .oauth2 .core .OAuth2TokenValidator ;
44
+ import org .springframework .security .oauth2 .core .OAuth2TokenValidatorResult ;
43
45
import org .springframework .security .oauth2 .jose .jws .JwsAlgorithms ;
44
46
import org .springframework .util .Assert ;
45
47
@@ -67,6 +69,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
67
69
68
70
private final JWKSelectorFactory jwkSelectorFactory ;
69
71
72
+ private OAuth2TokenValidator <Jwt > jwtValidator = JwtValidators .createDefault ();
73
+
70
74
public NimbusReactiveJwtDecoder (RSAPublicKey publicKey ) {
71
75
JWSAlgorithm algorithm = JWSAlgorithm .parse (JwsAlgorithms .RS256 );
72
76
@@ -77,6 +81,7 @@ public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
77
81
new JWSVerificationKeySelector <>(algorithm , jwkSource );
78
82
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor <>();
79
83
jwtProcessor .setJWSKeySelector (jwsKeySelector );
84
+ jwtProcessor .setJWTClaimsSetVerifier ((claims , context ) -> {});
80
85
81
86
this .jwtProcessor = jwtProcessor ;
82
87
this .reactiveJwkSource = new ReactiveJWKSourceAdapter (jwkSource );
@@ -98,6 +103,7 @@ public NimbusReactiveJwtDecoder(String jwkSetUrl) {
98
103
99
104
DefaultJWTProcessor <JWKContext > jwtProcessor = new DefaultJWTProcessor <>();
100
105
jwtProcessor .setJWSKeySelector (jwsKeySelector );
106
+ jwtProcessor .setJWTClaimsSetVerifier ((claims , context ) -> {});
101
107
this .jwtProcessor = jwtProcessor ;
102
108
103
109
this .reactiveJwkSource = new ReactiveRemoteJWKSource (jwkSetUrl );
@@ -106,6 +112,16 @@ public NimbusReactiveJwtDecoder(String jwkSetUrl) {
106
112
107
113
}
108
114
115
+ /**
116
+ * Use the provided {@link OAuth2TokenValidator} to validate incoming {@link Jwt}s.
117
+ *
118
+ * @param jwtValidator the {@link OAuth2TokenValidator} to use
119
+ */
120
+ public void setJwtValidator (OAuth2TokenValidator <Jwt > jwtValidator ) {
121
+ Assert .notNull (jwtValidator , "jwtValidator cannot be null" );
122
+ this .jwtValidator = jwtValidator ;
123
+ }
124
+
109
125
@ Override
110
126
public Mono <Jwt > decode (String token ) throws JwtException {
111
127
JWT jwt = parse (token );
@@ -131,7 +147,8 @@ private Mono<Jwt> decode(SignedJWT parsedToken) {
131
147
.onErrorMap (e -> new IllegalStateException ("Could not obtain the keys" , e ))
132
148
.map (jwkList -> createClaimsSet (parsedToken , jwkList ))
133
149
.map (set -> createJwt (parsedToken , set ))
134
- .onErrorMap (e -> !(e instanceof IllegalStateException ), e -> new JwtException ("An error occurred while attempting to decode the Jwt: " , e ));
150
+ .map (this ::validateJwt )
151
+ .onErrorMap (e -> !(e instanceof IllegalStateException ) && !(e instanceof JwtException ), e -> new JwtException ("An error occurred while attempting to decode the Jwt: " , e ));
135
152
} catch (RuntimeException ex ) {
136
153
throw new JwtException ("An error occurred while attempting to decode the Jwt: " + ex .getMessage (), ex );
137
154
}
@@ -164,6 +181,17 @@ private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
164
181
return new Jwt (parsedJwt .getParsedString (), issuedAt , expiresAt , headers , jwtClaimsSet .getClaims ());
165
182
}
166
183
184
+ private Jwt validateJwt (Jwt jwt ) {
185
+ OAuth2TokenValidatorResult result = this .jwtValidator .validate (jwt );
186
+
187
+ if ( result .hasErrors () ) {
188
+ String message = result .getErrors ().iterator ().next ().getDescription ();
189
+ throw new JwtValidationException (message , result .getErrors ());
190
+ }
191
+
192
+ return jwt ;
193
+ }
194
+
167
195
private static RSAKey rsaKey (RSAPublicKey publicKey ) {
168
196
return new RSAKey .Builder (publicKey )
169
197
.build ();
0 commit comments