Skip to content

Reactive Jwt Validation #5718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import com.nimbusds.jwt.proc.JWTProcessor;
import reactor.core.publisher.Mono;

import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -67,6 +69,8 @@ public final class NimbusReactiveJwtDecoder implements ReactiveJwtDecoder {

private final JWKSelectorFactory jwkSelectorFactory;

private OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefault();

public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
JWSAlgorithm algorithm = JWSAlgorithm.parse(JwsAlgorithms.RS256);

Expand All @@ -77,6 +81,7 @@ public NimbusReactiveJwtDecoder(RSAPublicKey publicKey) {
new JWSVerificationKeySelector<>(algorithm, jwkSource);
DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector);
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});

this.jwtProcessor = jwtProcessor;
this.reactiveJwkSource = new ReactiveJWKSourceAdapter(jwkSource);
Expand All @@ -98,6 +103,7 @@ public NimbusReactiveJwtDecoder(String jwkSetUrl) {

DefaultJWTProcessor<JWKContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector);
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
this.jwtProcessor = jwtProcessor;

this.reactiveJwkSource = new ReactiveRemoteJWKSource(jwkSetUrl);
Expand All @@ -106,6 +112,16 @@ public NimbusReactiveJwtDecoder(String jwkSetUrl) {

}

/**
* Use the provided {@link OAuth2TokenValidator} to validate incoming {@link Jwt}s.
*
* @param jwtValidator the {@link OAuth2TokenValidator} to use
*/
public void setJwtValidator(OAuth2TokenValidator<Jwt> jwtValidator) {
Assert.notNull(jwtValidator, "jwtValidator cannot be null");
this.jwtValidator = jwtValidator;
}

@Override
public Mono<Jwt> decode(String token) throws JwtException {
JWT jwt = parse(token);
Expand All @@ -131,7 +147,8 @@ private Mono<Jwt> decode(SignedJWT parsedToken) {
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
.map(jwkList -> createClaimsSet(parsedToken, jwkList))
.map(set -> createJwt(parsedToken, set))
.onErrorMap(e -> !(e instanceof IllegalStateException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
.map(this::validateJwt)
.onErrorMap(e -> !(e instanceof IllegalStateException) && !(e instanceof JwtException), e -> new JwtException("An error occurred while attempting to decode the Jwt: ", e));
} catch (RuntimeException ex) {
throw new JwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex);
}
Expand Down Expand Up @@ -164,6 +181,17 @@ private Jwt createJwt(JWT parsedJwt, JWTClaimsSet jwtClaimsSet) {
return new Jwt(parsedJwt.getParsedString(), issuedAt, expiresAt, headers, jwtClaimsSet.getClaims());
}

private Jwt validateJwt(Jwt jwt) {
OAuth2TokenValidatorResult result = this.jwtValidator.validate(jwt);

if ( result.hasErrors() ) {
String message = result.getErrors().iterator().next().getDescription();
throw new JwtValidationException(message, result.getErrors());
}

return jwt;
}

private static RSAKey rsaKey(RSAPublicKey publicKey) {
return new RSAKey.Builder(publicKey)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,28 @@

package org.springframework.security.oauth2.jwt;

import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.net.UnknownHostException;
import java.security.KeyFactory;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.X509EncodedKeySpec;
import java.util.Base64;
import java.util.Date;

import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* @author Rob Winch
Expand Down Expand Up @@ -114,7 +121,7 @@ public void decodeWhenIssuedAtThenSuccess() {
@Test
public void decodeWhenExpiredThenFail() {
assertThatCode(() -> this.decoder.decode(this.expired).block())
.isInstanceOf(JwtException.class);
.isInstanceOf(JwtValidationException.class);
}

@Test
Expand Down Expand Up @@ -155,4 +162,24 @@ public void decodeWhenUnsignedTokenThenMessageDoesNotMentionClass() {
.isInstanceOf(JwtException.class)
.hasMessage("Unsupported algorithm of none");
}

@Test
public void decodeWhenUsingCustomValidatorThenValidatorIsInvoked() {
OAuth2TokenValidator jwtValidator = mock(OAuth2TokenValidator.class);
this.decoder.setJwtValidator(jwtValidator);

OAuth2Error error = new OAuth2Error("mock-error", "mock-description", "mock-uri");
OAuth2TokenValidatorResult result = OAuth2TokenValidatorResult.failure(error);
when(jwtValidator.validate(any(Jwt.class))).thenReturn(result);

assertThatCode(() -> this.decoder.decode(messageReadToken).block())
.isInstanceOf(JwtException.class)
.hasMessageContaining("mock-description");
}

@Test
public void setJwtValidatorWhenGivenNullThrowsIllegalArgumentException() {
assertThatCode(() -> this.decoder.setJwtValidator(null))
.isInstanceOf(IllegalArgumentException.class);
}
}