diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java index 83ad3cace45..a1318a7eb9b 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,12 +20,15 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Consumer; import javax.annotation.Nonnull; @@ -168,6 +171,9 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv private Converter responseAuthenticationConverter = createDefaultResponseAuthenticationConverter(); + private static final Set includeChildStatusCodes = new HashSet<>( + Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH)); + /** * Creates an {@link OpenSaml4AuthenticationProvider} */ @@ -371,11 +377,13 @@ public static Converter createDefau Response response = responseToken.getResponse(); Saml2AuthenticationToken token = responseToken.getToken(); Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success(); - String statusCode = getStatusCode(response); - if (!StatusCode.SUCCESS.equals(statusCode)) { - String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, - response.getID()); - result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); + List statusCodes = getStatusCodes(response); + if (!isSuccess(statusCodes)) { + for (String statusCode : statusCodes) { + String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode, + response.getID()); + result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message)); + } } String inResponseTo = response.getInResponseTo(); @@ -404,6 +412,39 @@ public static Converter createDefau }; } + private static List getStatusCodes(Response response) { + if (response.getStatus() == null) { + return Arrays.asList(StatusCode.SUCCESS); + } + if (response.getStatus().getStatusCode() == null) { + return Arrays.asList(StatusCode.SUCCESS); + } + + StatusCode parentStatusCode = response.getStatus().getStatusCode(); + String parentStatusCodeValue = parentStatusCode.getValue(); + if (includeChildStatusCodes.contains(parentStatusCodeValue)) { + StatusCode statusCode = parentStatusCode.getStatusCode(); + if (statusCode != null) { + String childStatusCodeValue = statusCode.getValue(); + if (childStatusCodeValue != null) { + return Arrays.asList(parentStatusCodeValue, childStatusCodeValue); + } + } + return Arrays.asList(parentStatusCodeValue); + } + + return Arrays.asList(parentStatusCodeValue); + } + + private static boolean isSuccess(List statusCodes) { + if (statusCodes.size() != 1) { + return false; + } + + String statusCode = statusCodes.get(0); + return StatusCode.SUCCESS.equals(statusCode); + } + private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest, String inResponseTo) { if (!StringUtils.hasText(inResponseTo)) { @@ -614,16 +655,6 @@ private Consumer createDefaultResponseElementsDecrypter() { }; } - private static String getStatusCode(Response response) { - if (response.getStatus() == null) { - return StatusCode.SUCCESS; - } - if (response.getStatus().getStatusCode() == null) { - return StatusCode.SUCCESS; - } - return response.getStatus().getStatusCode().getValue(); - } - private Converter createDefaultAssertionSignatureValidator() { return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> { RelyingPartyRegistration registration = assertionToken.getToken().getRelyingPartyRegistration(); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java index eccc42693f8..8432b5760a2 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -51,9 +51,11 @@ import org.opensaml.saml.saml2.core.EncryptedAssertion; import org.opensaml.saml.saml2.core.EncryptedAttribute; import org.opensaml.saml.saml2.core.EncryptedID; +import org.opensaml.saml.saml2.core.Issuer; import org.opensaml.saml.saml2.core.NameID; import org.opensaml.saml.saml2.core.OneTimeUse; import org.opensaml.saml.saml2.core.Response; +import org.opensaml.saml.saml2.core.Status; import org.opensaml.saml.saml2.core.StatusCode; import org.opensaml.saml.saml2.core.SubjectConfirmation; import org.opensaml.saml.saml2.core.SubjectConfirmationData; @@ -61,6 +63,8 @@ import org.opensaml.saml.saml2.core.impl.EncryptedAssertionBuilder; import org.opensaml.saml.saml2.core.impl.EncryptedIDBuilder; import org.opensaml.saml.saml2.core.impl.NameIDBuilder; +import org.opensaml.saml.saml2.core.impl.StatusBuilder; +import org.opensaml.saml.saml2.core.impl.StatusCodeBuilder; import org.opensaml.xmlsec.encryption.impl.EncryptedDataBuilder; import org.opensaml.xmlsec.signature.support.SignatureConstants; import org.w3c.dom.Element; @@ -729,6 +733,93 @@ public void authenticateWhenCustomResponseValidatorThenUses() { verify(validator).convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class)); } + @Test + public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() { + ResponseToken mockResponseToken = mock(ResponseToken.class); + Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class); + given(mockResponseToken.getToken()).willReturn(mockSamlToken); + + RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class); + given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration); + + RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock( + RelyingPartyRegistration.AssertingPartyDetails.class); + given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails); + + Status parentStatus = new StatusBuilder().buildObject(); + StatusCode parentStatusCode = new StatusCodeBuilder().buildObject(); + parentStatusCode.setValue(StatusCode.AUTHN_FAILED); + StatusCode childStatusCode = new StatusCodeBuilder().buildObject(); + childStatusCode.setValue(StatusCode.NO_PASSIVE); + parentStatusCode.setStatusCode(childStatusCode); + parentStatus.setStatusCode(parentStatusCode); + + Response mockResponse = mock(Response.class); + given(mockResponse.getStatus()).willReturn(parentStatus); + Issuer mockIssuer = mock(Issuer.class); + given(mockIssuer.getValue()).willReturn("mockedIssuer"); + given(mockResponse.getIssuer()).willReturn(mockIssuer); + + given(mockResponseToken.getResponse()).willReturn(mockResponse); + + Converter validator = OpenSaml4AuthenticationProvider + .createDefaultResponseValidator(); + Saml2ResponseValidatorResult result = validator.convert(mockResponseToken); + + String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", + parentStatusCode.getValue()); + assertThat( + result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage))); + assertThat(result.getErrors() + .stream() + .noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue()))); + } + + @Test + public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() { + ResponseToken mockResponseToken = mock(ResponseToken.class); + Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class); + given(mockResponseToken.getToken()).willReturn(mockSamlToken); + + RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class); + given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration); + + RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock( + RelyingPartyRegistration.AssertingPartyDetails.class); + given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails); + + Status parentStatus = new StatusBuilder().buildObject(); + StatusCode parentStatusCode = new StatusCodeBuilder().buildObject(); + parentStatusCode.setValue(StatusCode.REQUESTER); + StatusCode childStatusCode = new StatusCodeBuilder().buildObject(); + childStatusCode.setValue(StatusCode.NO_PASSIVE); + parentStatusCode.setStatusCode(childStatusCode); + parentStatus.setStatusCode(parentStatusCode); + + Response mockResponse = mock(Response.class); + given(mockResponse.getStatus()).willReturn(parentStatus); + Issuer mockIssuer = mock(Issuer.class); + given(mockIssuer.getValue()).willReturn("mockedIssuer"); + given(mockResponse.getIssuer()).willReturn(mockIssuer); + + given(mockResponseToken.getResponse()).willReturn(mockResponse); + + Converter validator = OpenSaml4AuthenticationProvider + .createDefaultResponseValidator(); + Saml2ResponseValidatorResult result = validator.convert(mockResponseToken); + + String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response", + parentStatusCode.getValue()); + String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response", + childStatusCode.getValue()); + assertThat(result.getErrors() + .stream() + .anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage))); + assertThat(result.getErrors() + .stream() + .anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage))); + } + @Test public void authenticateWhenAssertionIssuerNotValidThenFailsWithInvalidIssuer() { OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();