diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java
index d8852992f4d..a3fdf95f8ba 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidator.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,7 +27,10 @@
import java.net.URL;
import java.time.Instant;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
/**
* An {@link OAuth2TokenValidator} responsible for
@@ -41,7 +44,6 @@
* @see ID Token Validation
*/
public final class OidcIdTokenValidator implements OAuth2TokenValidator {
- private static final OAuth2Error INVALID_ID_TOKEN_ERROR = new OAuth2Error("invalid_id_token");
private final ClientRegistration clientRegistration;
public OidcIdTokenValidator(ClientRegistration clientRegistration) {
@@ -53,27 +55,10 @@ public OidcIdTokenValidator(ClientRegistration clientRegistration) {
public OAuth2TokenValidatorResult validate(Jwt idToken) {
// 3.1.3.7 ID Token Validation
// http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
+ Map invalidClaims = validateRequiredClaims(idToken);
- // Validate REQUIRED Claims
- URL issuer = idToken.getIssuer();
- if (issuer == null) {
- return invalidIdToken();
- }
- String subject = idToken.getSubject();
- if (subject == null) {
- return invalidIdToken();
- }
- List audience = idToken.getAudience();
- if (CollectionUtils.isEmpty(audience)) {
- return invalidIdToken();
- }
- Instant expiresAt = idToken.getExpiresAt();
- if (expiresAt == null) {
- return invalidIdToken();
- }
- Instant issuedAt = idToken.getIssuedAt();
- if (issuedAt == null) {
- return invalidIdToken();
+ if (!invalidClaims.isEmpty()){
+ return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims));
}
// 2. The Issuer Identifier for the OpenID Provider (which is typically obtained during Discovery)
@@ -85,21 +70,21 @@ public OAuth2TokenValidatorResult validate(Jwt idToken) {
// The aud (audience) Claim MAY contain an array with more than one element.
// The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience,
// or if it contains additional audiences not trusted by the Client.
- if (!audience.contains(this.clientRegistration.getClientId())) {
- return invalidIdToken();
+ if (!idToken.getAudience().contains(this.clientRegistration.getClientId())) {
+ invalidClaims.put(IdTokenClaimNames.AUD, idToken.getAudience());
}
// 4. If the ID Token contains multiple audiences,
// the Client SHOULD verify that an azp Claim is present.
String authorizedParty = idToken.getClaimAsString(IdTokenClaimNames.AZP);
- if (audience.size() > 1 && authorizedParty == null) {
- return invalidIdToken();
+ if (idToken.getAudience().size() > 1 && authorizedParty == null) {
+ invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty);
}
// 5. If an azp (authorized party) Claim is present,
// the Client SHOULD verify that its client_id is the Claim Value.
if (authorizedParty != null && !authorizedParty.equals(this.clientRegistration.getClientId())) {
- return invalidIdToken();
+ invalidClaims.put(IdTokenClaimNames.AZP, authorizedParty);
}
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client
@@ -108,16 +93,16 @@ public OAuth2TokenValidatorResult validate(Jwt idToken) {
// 9. The current time MUST be before the time represented by the exp Claim.
Instant now = Instant.now();
- if (!now.isBefore(expiresAt)) {
- return invalidIdToken();
+ if (!now.isBefore(idToken.getExpiresAt())) {
+ invalidClaims.put(IdTokenClaimNames.EXP, idToken.getExpiresAt());
}
// 10. The iat Claim can be used to reject tokens that were issued too far away from the current time,
// limiting the amount of time that nonces need to be stored to prevent attacks.
// The acceptable range is Client specific.
Instant maxIssuedAt = now.plusSeconds(30);
- if (issuedAt.isAfter(maxIssuedAt)) {
- return invalidIdToken();
+ if (idToken.getIssuedAt().isAfter(maxIssuedAt)) {
+ invalidClaims.put(IdTokenClaimNames.IAT, idToken.getIssuedAt());
}
// 11. If a nonce value was sent in the Authentication Request,
@@ -127,10 +112,45 @@ public OAuth2TokenValidatorResult validate(Jwt idToken) {
// The precise method for detecting replay attacks is Client specific.
// TODO Depends on gh-4442
+ if (!invalidClaims.isEmpty()) {
+ return OAuth2TokenValidatorResult.failure(invalidIdToken(invalidClaims));
+ }
+
return OAuth2TokenValidatorResult.success();
}
- private static OAuth2TokenValidatorResult invalidIdToken() {
- return OAuth2TokenValidatorResult.failure(INVALID_ID_TOKEN_ERROR);
+ private static OAuth2Error invalidIdToken(Map invalidClaims) {
+ String claimsDetail = invalidClaims.entrySet().stream()
+ .map(it -> it.getKey()+ "("+it.getValue()+")")
+ .collect(Collectors.joining(", "));
+
+ return new OAuth2Error("invalid_id_token", "The ID Token contains invalid claims: "+claimsDetail, null);
+ }
+
+ private static Map validateRequiredClaims(Jwt idToken){
+ Map requiredClaims = new HashMap<>();
+
+ URL issuer = idToken.getIssuer();
+ if (issuer == null) {
+ requiredClaims.put(IdTokenClaimNames.ISS, issuer);
+ }
+ String subject = idToken.getSubject();
+ if (subject == null) {
+ requiredClaims.put(IdTokenClaimNames.SUB, subject);
+ }
+ List audience = idToken.getAudience();
+ if (CollectionUtils.isEmpty(audience)) {
+ requiredClaims.put(IdTokenClaimNames.AUD, audience);
+ }
+ Instant expiresAt = idToken.getExpiresAt();
+ if (expiresAt == null) {
+ requiredClaims.put(IdTokenClaimNames.EXP, expiresAt);
+ }
+ Instant issuedAt = idToken.getIssuedAt();
+ if (issuedAt == null) {
+ requiredClaims.put(IdTokenClaimNames.IAT, issuedAt);
+ }
+
+ return requiredClaims;
}
}
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java
index 088bfbe900f..4522db84efe 100644
--- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenValidatorTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -64,8 +64,9 @@ public void validateIdTokenWhenIssuerNullThenHasErrors() {
this.claims.remove(IdTokenClaimNames.ISS);
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.ISS));
+
}
@Test
@@ -73,8 +74,8 @@ public void validateIdTokenWhenSubNullThenHasErrors() {
this.claims.remove(IdTokenClaimNames.SUB);
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.SUB));
}
@Test
@@ -82,8 +83,8 @@ public void validateIdTokenWhenAudNullThenHasErrors() {
this.claims.remove(IdTokenClaimNames.AUD);
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD));
}
@Test
@@ -91,8 +92,8 @@ public void validateIdTokenWhenIssuedAtNullThenHasErrors() {
this.issuedAt = null;
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT));
}
@Test
@@ -100,8 +101,8 @@ public void validateIdTokenWhenExpiresAtNullThenHasErrors() {
this.expiresAt = null;
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
}
@Test
@@ -109,8 +110,8 @@ public void validateIdTokenWhenAudMultipleAndAzpNullThenHasErrors() {
this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other"));
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP));
}
@Test
@@ -118,8 +119,8 @@ public void validateIdTokenWhenAzpNotClientIdThenHasErrors() {
this.claims.put(IdTokenClaimNames.AZP, "other");
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP));
}
@Test
@@ -135,8 +136,8 @@ public void validateIdTokenWhenMultipleAudAzpNotClientIdThenHasErrors() {
this.claims.put(IdTokenClaimNames.AZP, "other-client");
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.AZP));
}
@Test
@@ -144,8 +145,8 @@ public void validateIdTokenWhenAudNotClientIdThenHasErrors() {
this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client"));
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD));
}
@Test
@@ -154,8 +155,8 @@ public void validateIdTokenWhenExpiredThenHasErrors() {
this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1));
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
}
@Test
@@ -164,8 +165,8 @@ public void validateIdTokenWhenIssuedAtWayInFutureThenHasErrors() {
this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(1));
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT));
}
@Test
@@ -174,8 +175,34 @@ public void validateIdTokenWhenExpiresAtBeforeNowThenHasErrors() {
this.expiresAt = Instant.from(this.issuedAt).plusSeconds(5);
assertThat(this.validateIdToken())
.hasSize(1)
- .extracting(OAuth2Error::getErrorCode)
- .contains("invalid_id_token");
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
+ }
+
+ @Test
+ public void validateIdTokenWhenMissingClaimsThenHasErrors() {
+ this.claims.remove(IdTokenClaimNames.SUB);
+ this.claims.remove(IdTokenClaimNames.AUD);
+ this.issuedAt = null;
+ this.expiresAt = null;
+ assertThat(this.validateIdToken())
+ .hasSize(1)
+ .extracting(OAuth2Error::getDescription)
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.SUB))
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.AUD))
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.IAT))
+ .allMatch(msg -> msg.contains(IdTokenClaimNames.EXP));
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void validateIdTokenWhenNoClaimsThenHasErrors() {
+ this.claims.remove(IdTokenClaimNames.ISS);
+ this.claims.remove(IdTokenClaimNames.SUB);
+ this.claims.remove(IdTokenClaimNames.AUD);
+ this.issuedAt = null;
+ this.expiresAt = null;
+ assertThat(this.validateIdToken())
+ .hasSize(1);
}
private Collection validateIdToken() {