Skip to content

Commit 01443e3

Browse files
jzheauxrwinch
authored andcommitted
Reactive Jwt Validation
This allows a user to customize the Jwt validation steps that NimbusReactiveJwtDecoder will take for each Jwt. Fixes: gh-5650
1 parent 5365258 commit 01443e3

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
import com.nimbusds.jwt.proc.JWTProcessor;
4141
import reactor.core.publisher.Mono;
4242

43+
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
44+
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
4345
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
4446
import org.springframework.util.Assert;
4547

@@ -67,6 +69,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {
6769

6870
private final JWKSelectorFactory jwkSelectorFactory;
6971

72+
private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();
73+
7074
public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
7175
JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);
7276

@@ -77,6 +81,7 @@ public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
7781
new JWSVerificationKeySelector<>(algorithm, jwkSource);
7882
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
7983
jwtProcessor.setJWSKeySelector(jwsKeySelector);
84+
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
8085

8186
this.jwtProcessor = jwtProcessor;
8287
this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource);
@@ -98,6 +103,7 @@ public NimbusReactiveJwtDecoder(String jwkSetUrl) {
98103

99104
DefaultJWTProcessor<JWKContext> jwtProcessor = new DefaultJWTProcessor<>();
100105
jwtProcessor.setJWSKeySelector(jwsKeySelector);
106+
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
101107
this.jwtProcessor = jwtProcessor;
102108

103109
this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl);
@@ -106,6 +112,16 @@ public NimbusReactiveJwtDecoder(String jwkSetUrl) {
106112

107113
}
108114

115+
/**
116+
* Use the provided {@link OAuth2TokenValidator} to validate incoming {@link Jwt}s.
117+
*
118+
* @param jwtValidator the {@link OAuth2TokenValidator} to use
119+
*/
120+
public void setJwtValidator(OAuth2TokenValidator<Jwt> jwtValidator) {
121+
Assert.notNull(jwtValidator, "jwtValidator cannot be null");
122+
this.jwtValidator = jwtValidator;
123+
}
124+
109125
@Override
110126
public Mono<Jwt> decode(String token) throws JwtException {
111127
JWT jwt = parse(token);
@@ -131,7 +147,8 @@ private Mono<Jwt> decode(SignedJWT parsedToken) {
131147
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
132148
.map(jwkList -> createClaimsSet(parsedToken, jwkList))
133149
.map(set -> createJwt(parsedToken, set))
134-
.onErrorMap(e -> !(e instanceof IllegalStateException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
150+
.map(this::validateJwt)
151+
.onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
135152
} catch (RuntimeException ex) {
136153
throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
137154
}
@@ -164,6 +181,17 @@ private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
164181
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
165182
}
166183

184+
private Jwt validateJwt(Jwt jwt) {
185+
OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt);
186+
187+
if ( result.hasErrors() ) {
188+
String message = result.getErrors().iterator().next().getDescription();
189+
throw new JwtValidationException(message, result.getErrors());
190+
}
191+
192+
return jwt;
193+
}
194+
167195
private static RSAKey rsaKey(RSAPublicKey publicKey) {
168196
return new RSAKey.Builder(publicKey)
169197
.build();

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,28 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19-
import okhttp3.mockwebserver.MockResponse;
20-
import okhttp3.mockwebserver.MockWebServer;
21-
import org.junit.After;
22-
import org.junit.Before;
23-
import org.junit.Test;
24-
2519
import java.net.UnknownHostException;
2620
import java.security.KeyFactory;
2721
import java.security.interfaces.RSAPublicKey;
2822
import java.security.spec.X509EncodedKeySpec;
2923
import java.util.Base64;
3024
import java.util.Date;
3125

26+
import okhttp3.mockwebserver.MockResponse;
27+
import okhttp3.mockwebserver.MockWebServer;
28+
import org.junit.After;
29+
import org.junit.Before;
30+
import org.junit.Test;
31+
32+
import org.springframework.security.oauth2.core.OAuth2Error;
33+
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
34+
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
35+
3236
import static org.assertj.core.api.Assertions.assertThat;
3337
import static org.assertj.core.api.Assertions.assertThatCode;
38+
import static org.mockito.ArgumentMatchers.any;
39+
import static org.mockito.Mockito.mock;
40+
import static org.mockito.Mockito.when;
3441

3542
/**
3643
* @author Rob Winch
@@ -114,7 +121,7 @@ public void decodeWhenIssuedAtThenSuccess() {
114121
@Test
115122
public void decodeWhenExpiredThenFail() {
116123
assertThatCode(() -> this.decoder.decode(this.expired).block())
117-
.isInstanceOf(JwtException.class);
124+
.isInstanceOf(JwtValidationException.class);
118125
}
119126

120127
@Test
@@ -155,4 +162,24 @@ public void decodeWhenUnsignedTokenThenMessageDoesNotMentionClass() {
155162
.isInstanceOf(JwtException.class)
156163
.hasMessage("Unsupported algorithm of none");
157164
}
165+
166+
@Test
167+
public void decodeWhenUsingCustomValidatorThenValidatorIsInvoked() {
168+
OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class);
169+
this.decoder.setJwtValidator(jwtValidator);
170+
171+
OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri");
172+
OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(error);
173+
when(jwtValidator.validate(any(Jwt.class))).thenReturn(result);
174+
175+
assertThatCode(() -> this.decoder.decode(messageReadToken).block())
176+
.isInstanceOf(JwtException.class)
177+
.hasMessageContaining("mock-description");
178+
}
179+
180+
@Test
181+
public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() {
182+
assertThatCode(() -> this.decoder.setJwtValidator(null))
183+
.isInstanceOf(IllegalArgumentException.class);
184+
}
158185
}

0 commit comments

Comments
 (0)