Skip to content

Commit 6e45e65

Browse files
YoungKi Hongjzheaux
authored andcommitted
Update to return List of StatusCodes and add Saml2Error to result object and other formatting
1 parent 76331a5 commit 6e45e65

File tree

2 files changed

+60
-33
lines changed

2 files changed

+60
-33
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@
2020
import java.nio.charset.StandardCharsets;
2121
import java.time.Duration;
2222
import java.util.ArrayList;
23+
import java.util.Arrays;
2324
import java.util.Collection;
2425
import java.util.Collections;
2526
import java.util.HashMap;
27+
import java.util.HashSet;
2628
import java.util.LinkedHashMap;
2729
import java.util.List;
2830
import java.util.Map;
2931
import java.util.Set;
30-
import java.util.HashSet;
31-
import java.util.Arrays;
32-
import java.util.Optional;
3332
import java.util.function.Consumer;
3433

3534
import javax.annotation.Nonnull;
@@ -98,8 +97,6 @@
9897
import org.springframework.util.MultiValueMap;
9998
import org.springframework.util.StringUtils;
10099

101-
import static org.opensaml.saml.saml2.core.StatusCode.*;
102-
103100
/**
104101
* Implementation of {@link AuthenticationProvider} for SAML authentications when
105102
* receiving a {@code Response} object containing an {@code Assertion}. This
@@ -174,7 +171,8 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
174171

175172
private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createDefaultResponseAuthenticationConverter();
176173

177-
private static final Set<String> includeChildStatusCodes = new HashSet<>(Arrays.asList(REQUESTER, RESPONDER, VERSION_MISMATCH));
174+
private static final Set<String> includeChildStatusCodes = new HashSet<>(
175+
Arrays.asList(StatusCode.REQUESTER, StatusCode.RESPONDER, StatusCode.VERSION_MISMATCH));
178176

179177
/**
180178
* Creates an {@link OpenSaml4AuthenticationProvider}
@@ -379,11 +377,13 @@ public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefau
379377
Response response = responseToken.getResponse();
380378
Saml2AuthenticationToken token = responseToken.getToken();
381379
Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
382-
String statusCode = getStatusCode(response);
383-
if (!StatusCode.SUCCESS.equals(statusCode)) {
384-
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
385-
response.getID());
386-
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
380+
List<String> statusCodes = getStatusCodes(response);
381+
if (!isSuccess(statusCodes)) {
382+
for (String statusCode : statusCodes) {
383+
String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
384+
response.getID());
385+
result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
386+
}
387387
}
388388

389389
String inResponseTo = response.getInResponseTo();
@@ -412,24 +412,37 @@ public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefau
412412
};
413413
}
414414

415-
private static String getStatusCode(Response response) {
415+
private static List<String> getStatusCodes(Response response) {
416416
if (response.getStatus() == null) {
417-
return StatusCode.SUCCESS;
417+
return Arrays.asList(StatusCode.SUCCESS);
418418
}
419419
if (response.getStatus().getStatusCode() == null) {
420-
return StatusCode.SUCCESS;
420+
return Arrays.asList(StatusCode.SUCCESS);
421421
}
422422

423423
StatusCode parentStatusCode = response.getStatus().getStatusCode();
424424
String parentStatusCodeValue = parentStatusCode.getValue();
425425
if (includeChildStatusCodes.contains(parentStatusCodeValue)) {
426-
return Optional.ofNullable(parentStatusCode.getStatusCode())
427-
.map(StatusCode::getValue)
428-
.map(childStatusCodeValue -> parentStatusCodeValue + childStatusCodeValue)
429-
.orElse(parentStatusCodeValue);
426+
StatusCode statusCode = parentStatusCode.getStatusCode();
427+
if (statusCode != null) {
428+
String childStatusCodeValue = statusCode.getValue();
429+
if (childStatusCodeValue != null) {
430+
return Arrays.asList(parentStatusCodeValue, childStatusCodeValue);
431+
}
432+
}
433+
return Arrays.asList(parentStatusCodeValue);
434+
}
435+
436+
return Arrays.asList(parentStatusCodeValue);
437+
}
438+
439+
private static boolean isSuccess(List<String> statusCodes) {
440+
if (statusCodes.size() != 1) {
441+
return false;
430442
}
431443

432-
return parentStatusCodeValue;
444+
String statusCode = statusCodes.get(0);
445+
return StatusCode.SUCCESS.equals(statusCode);
433446
}
434447

435448
private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -86,8 +86,6 @@
8686
import static org.assertj.core.api.Assertions.assertThat;
8787
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
8888
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
89-
import static org.junit.Assert.assertFalse;
90-
import static org.junit.Assert.assertTrue;
9189
import static org.mockito.ArgumentMatchers.any;
9290
import static org.mockito.BDDMockito.given;
9391
import static org.mockito.Mockito.atLeastOnce;
@@ -736,15 +734,16 @@ public void authenticateWhenCustomResponseValidatorThenUses() {
736734
}
737735

738736
@Test
739-
public void setsOnlyParentStatusCodeOnResultDescription() {
737+
public void authenticateWhenResponseStatusIsNotSuccessThenOnlyReturnParentStatusCodes() {
740738
ResponseToken mockResponseToken = mock(ResponseToken.class);
741739
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
742740
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
743741

744742
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
745743
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
746744

747-
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
745+
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
746+
RelyingPartyRegistration.AssertingPartyDetails.class);
748747
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
749748

750749
Status parentStatus = new StatusBuilder().buildObject();
@@ -763,24 +762,30 @@ public void setsOnlyParentStatusCodeOnResultDescription() {
763762

764763
given(mockResponseToken.getResponse()).willReturn(mockResponse);
765764

766-
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
765+
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
766+
.createDefaultResponseValidator();
767767
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
768768

769-
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue());
770-
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
771-
assertFalse(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(childStatusCode.getValue())));
769+
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response",
770+
parentStatusCode.getValue());
771+
assertThat(
772+
result.getErrors().stream().anyMatch((error) -> error.getDescription().contains(expectedErrorMessage)));
773+
assertThat(result.getErrors()
774+
.stream()
775+
.noneMatch((error) -> error.getDescription().contains(childStatusCode.getValue())));
772776
}
773777

774778
@Test
775-
public void setsParentAndChildStatusCodeOnResultDescription() {
779+
public void authenticateWhenResponseStatusIsNotSuccessThenReturnParentAndChildStatusCode() {
776780
ResponseToken mockResponseToken = mock(ResponseToken.class);
777781
Saml2AuthenticationToken mockSamlToken = mock(Saml2AuthenticationToken.class);
778782
given(mockResponseToken.getToken()).willReturn(mockSamlToken);
779783

780784
RelyingPartyRegistration mockRelyingPartyRegistration = mock(RelyingPartyRegistration.class);
781785
given(mockSamlToken.getRelyingPartyRegistration()).willReturn(mockRelyingPartyRegistration);
782786

783-
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(RelyingPartyRegistration.AssertingPartyDetails.class);
787+
RelyingPartyRegistration.AssertingPartyDetails mockAssertingPartyDetails = mock(
788+
RelyingPartyRegistration.AssertingPartyDetails.class);
784789
given(mockRelyingPartyRegistration.getAssertingPartyDetails()).willReturn(mockAssertingPartyDetails);
785790

786791
Status parentStatus = new StatusBuilder().buildObject();
@@ -799,11 +804,20 @@ public void setsParentAndChildStatusCodeOnResultDescription() {
799804

800805
given(mockResponseToken.getResponse()).willReturn(mockResponse);
801806

802-
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider.createDefaultResponseValidator();
807+
Converter<ResponseToken, Saml2ResponseValidatorResult> validator = OpenSaml4AuthenticationProvider
808+
.createDefaultResponseValidator();
803809
Saml2ResponseValidatorResult result = validator.convert(mockResponseToken);
804810

805-
String expectedErrorMessage = String.format("Invalid status [%s] for SAML response", parentStatusCode.getValue() + childStatusCode.getValue());
806-
assertTrue(result.getErrors().stream().anyMatch(error -> error.getDescription().contains(expectedErrorMessage)));
811+
String expectedParentErrorMessage = String.format("Invalid status [%s] for SAML response",
812+
parentStatusCode.getValue());
813+
String expectedChildErrorMessage = String.format("Invalid status [%s] for SAML response",
814+
childStatusCode.getValue());
815+
assertThat(result.getErrors()
816+
.stream()
817+
.anyMatch((error) -> error.getDescription().contains(expectedParentErrorMessage)));
818+
assertThat(result.getErrors()
819+
.stream()
820+
.anyMatch((error) -> error.getDescription().contains(expectedChildErrorMessage)));
807821
}
808822

809823
@Test

0 commit comments

Comments
 (0)