diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java index d8852992f4d..a3fdf95f8ba 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,10 @@ import java.net.URL; import java.time.Instant; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; /** * An {@link OAuth2TokenValidator} responsible for @@ -41,7 +44,6 @@ * @see ID Token Validation */ public final class OidcIdTokenValidator implements OAuth2TokenValidator { - private static final OAuth2Error INVALID_ID_TOKEN_ERROR = new OAuth2Error("invalid_id_token"); private final ClientRegistration clientRegistration; public OidcIdTokenValidator(ClientRegistration clientRegistration) { @@ -53,27 +55,10 @@ public OidcIdTokenValidator(ClientRegistration clientRegistration) { public OAuth2TokenValidatorResult validate(Jwt idToken) { // 3.1.3.7 ID Token Validation // http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation + Map invalidClaims = validateRequiredClaims(idToken); - // Validate REQUIRED Claims - URL issuer = idToken.getIssuer(); - if (issuer == null) { - return invalidIdToken(); - } - String subject = idToken.getSubject(); - if (subject == null) { - return invalidIdToken(); - } - List audience = idToken.getAudience(); - if (CollectionUtils.isEmpty(audience)) { - return invalidIdToken(); - } - Instant expiresAt = idToken.getExpiresAt(); - if (expiresAt == null) { - return invalidIdToken(); - } - Instant issuedAt = idToken.getIssuedAt(); - if (issuedAt == null) { - return invalidIdToken(); + if (!invalidClaims.isEmpty()){ + return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims)); } // 2. The Issuer Identifier for the OpenID Provider (which is typically obtained during Discovery) @@ -85,21 +70,21 @@ public OAuth2TokenValidatorResult validate(Jwt idToken) { // The aud (audience) Claim MAY contain an array with more than one element. // The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience, // or if it contains additional audiences not trusted by the Client. - if (!audience.contains(this.clientRegistration.getClientId())) { - return invalidIdToken(); + if (!idToken.getAudience().contains(this.clientRegistration.getClientId())) { + invalidClaims.put(IdTokenClaimNames.AUD, idToken.getAudience()); } // 4. If the ID Token contains multiple audiences, // the Client SHOULD verify that an azp Claim is present. String authorizedParty = idToken.getClaimAsString(IdTokenClaimNames.AZP); - if (audience.size() > 1 && authorizedParty == null) { - return invalidIdToken(); + if (idToken.getAudience().size() > 1 && authorizedParty == null) { + invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty); } // 5. If an azp (authorized party) Claim is present, // the Client SHOULD verify that its client_id is the Claim Value. if (authorizedParty != null && !authorizedParty.equals(this.clientRegistration.getClientId())) { - return invalidIdToken(); + invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty); } // 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client @@ -108,16 +93,16 @@ public OAuth2TokenValidatorResult validate(Jwt idToken) { // 9. The current time MUST be before the time represented by the exp Claim. Instant now = Instant.now(); - if (!now.isBefore(expiresAt)) { - return invalidIdToken(); + if (!now.isBefore(idToken.getExpiresAt())) { + invalidClaims.put(IdTokenClaimNames.EXP, idToken.getExpiresAt()); } // 10. The iat Claim can be used to reject tokens that were issued too far away from the current time, // limiting the amount of time that nonces need to be stored to prevent attacks. // The acceptable range is Client specific. Instant maxIssuedAt = now.plusSeconds(30); - if (issuedAt.isAfter(maxIssuedAt)) { - return invalidIdToken(); + if (idToken.getIssuedAt().isAfter(maxIssuedAt)) { + invalidClaims.put(IdTokenClaimNames.IAT, idToken.getIssuedAt()); } // 11. If a nonce value was sent in the Authentication Request, @@ -127,10 +112,45 @@ public OAuth2TokenValidatorResult validate(Jwt idToken) { // The precise method for detecting replay attacks is Client specific. // TODO Depends on gh-4442 + if (!invalidClaims.isEmpty()) { + return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims)); + } + return OAuth2TokenValidatorResult.success(); } - private static OAuth2TokenValidatorResult invalidIdToken() { - return OAuth2TokenValidatorResult.failure(INVALID_ID_TOKEN_ERROR); + private static OAuth2Error invalidIdToken(Map invalidClaims) { + String claimsDetail = invalidClaims.entrySet().stream() + .map(it -> it.getKey()+ "("+it.getValue()+")") + .collect(Collectors.joining(", ")); + + return new OAuth2Error("invalid_id_token", "The ID Token contains invalid claims: "+claimsDetail, null); + } + + private static Map validateRequiredClaims(Jwt idToken){ + Map requiredClaims = new HashMap<>(); + + URL issuer = idToken.getIssuer(); + if (issuer == null) { + requiredClaims.put(IdTokenClaimNames.ISS, issuer); + } + String subject = idToken.getSubject(); + if (subject == null) { + requiredClaims.put(IdTokenClaimNames.SUB, subject); + } + List audience = idToken.getAudience(); + if (CollectionUtils.isEmpty(audience)) { + requiredClaims.put(IdTokenClaimNames.AUD, audience); + } + Instant expiresAt = idToken.getExpiresAt(); + if (expiresAt == null) { + requiredClaims.put(IdTokenClaimNames.EXP, expiresAt); + } + Instant issuedAt = idToken.getIssuedAt(); + if (issuedAt == null) { + requiredClaims.put(IdTokenClaimNames.IAT, issuedAt); + } + + return requiredClaims; } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java index 088bfbe900f..4522db84efe 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,8 +64,9 @@ public void validateIdTokenWhenIssuerNullThenHasErrors() { this.claims.remove(IdTokenClaimNames.ISS); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.ISS)); + } @Test @@ -73,8 +74,8 @@ public void validateIdTokenWhenSubNullThenHasErrors() { this.claims.remove(IdTokenClaimNames.SUB); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.SUB)); } @Test @@ -82,8 +83,8 @@ public void validateIdTokenWhenAudNullThenHasErrors() { this.claims.remove(IdTokenClaimNames.AUD); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD)); } @Test @@ -91,8 +92,8 @@ public void validateIdTokenWhenIssuedAtNullThenHasErrors() { this.issuedAt = null; assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT)); } @Test @@ -100,8 +101,8 @@ public void validateIdTokenWhenExpiresAtNullThenHasErrors() { this.expiresAt = null; assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); } @Test @@ -109,8 +110,8 @@ public void validateIdTokenWhenAudMultipleAndAzpNullThenHasErrors() { this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other")); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP)); } @Test @@ -118,8 +119,8 @@ public void validateIdTokenWhenAzpNotClientIdThenHasErrors() { this.claims.put(IdTokenClaimNames.AZP, "other"); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP)); } @Test @@ -135,8 +136,8 @@ public void validateIdTokenWhenMultipleAudAzpNotClientIdThenHasErrors() { this.claims.put(IdTokenClaimNames.AZP, "other-client"); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP)); } @Test @@ -144,8 +145,8 @@ public void validateIdTokenWhenAudNotClientIdThenHasErrors() { this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client")); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD)); } @Test @@ -154,8 +155,8 @@ public void validateIdTokenWhenExpiredThenHasErrors() { this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1)); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); } @Test @@ -164,8 +165,8 @@ public void validateIdTokenWhenIssuedAtWayInFutureThenHasErrors() { this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1)); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT)); } @Test @@ -174,8 +175,34 @@ public void validateIdTokenWhenExpiresAtBeforeNowThenHasErrors() { this.expiresAt = Instant.from(this.issuedAt).plusSeconds(5); assertThat(this.validateIdToken()) .hasSize(1) - .extracting(OAuth2Error::getErrorCode) - .contains("invalid_id_token"); + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); + } + + @Test + public void validateIdTokenWhenMissingClaimsThenHasErrors() { + this.claims.remove(IdTokenClaimNames.SUB); + this.claims.remove(IdTokenClaimNames.AUD); + this.issuedAt = null; + this.expiresAt = null; + assertThat(this.validateIdToken()) + .hasSize(1) + .extracting(OAuth2Error::getDescription) + .allMatch(msg -> msg.contains(IdTokenClaimNames.SUB)) + .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD)) + .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT)) + .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP)); + } + + @Test(expected = IllegalArgumentException.class) + public void validateIdTokenWhenNoClaimsThenHasErrors() { + this.claims.remove(IdTokenClaimNames.ISS); + this.claims.remove(IdTokenClaimNames.SUB); + this.claims.remove(IdTokenClaimNames.AUD); + this.issuedAt = null; + this.expiresAt = null; + assertThat(this.validateIdToken()) + .hasSize(1); } private Collection validateIdToken() {