diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java index f603c5b177c..2c90c7891df 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import jakarta.servlet.http.HttpServletRequest; import org.opensaml.saml.saml2.core.LogoutRequest; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.Authentication; import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; @@ -34,6 +35,7 @@ * OpenSAML 4 * * @author Josh Cummings + * @author Gerhard Haege * @since 5.6 */ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestResolver { @@ -83,6 +85,16 @@ public void setClock(Clock clock) { this.clock = clock; } + /** + * Use this {@link Converter} to compute the RelayState + * @param relayStateResolver the {@link Converter} to use + * @since 6.1 + */ + public void setRelayStateResolver(Converter relayStateResolver) { + Assert.notNull(relayStateResolver, "relayStateResolver cannot be null"); + this.logoutRequestResolver.setRelayStateResolver(relayStateResolver); + } + public static final class LogoutRequestParameters { private final HttpServletRequest request; diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolver.java index dff48e4a0d3..6a5c774447a 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,7 @@ import org.opensaml.saml.saml2.core.impl.SessionIndexBuilder; import org.w3c.dom.Element; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.Authentication; import org.springframework.security.saml2.Saml2Exception; import org.springframework.security.saml2.core.OpenSamlInitializationService; @@ -74,6 +75,8 @@ final class OpenSamlLogoutRequestResolver { private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; + private Converter relayStateResolver = (request) -> UUID.randomUUID().toString(); + /** * Construct a {@link OpenSamlLogoutRequestResolver} */ @@ -95,6 +98,10 @@ final class OpenSamlLogoutRequestResolver { Assert.notNull(this.sessionIndexBuilder, "sessionIndexBuilder must be configured in OpenSAML"); } + void setRelayStateResolver(Converter relayStateResolver) { + this.relayStateResolver = relayStateResolver; + } + /** * Prepare to create, sign, and serialize a SAML 2.0 Logout Request. * @@ -140,7 +147,7 @@ Saml2LogoutRequest resolve(HttpServletRequest request, Authentication authentica if (logoutRequest.getID() == null) { logoutRequest.setID("LR" + UUID.randomUUID()); } - String relayState = UUID.randomUUID().toString(); + String relayState = this.relayStateResolver.convert(request); Saml2LogoutRequest.Builder result = Saml2LogoutRequest.withRelyingPartyRegistration(registration) .id(logoutRequest.getID()); if (registration.getAssertingPartyDetails().getSingleLogoutServiceBinding() == Saml2MessageBinding.POST) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolverTests.java index 6485068a699..833ae7aa910 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,10 @@ package org.springframework.security.saml2.provider.service.web.authentication.logout; import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.core.convert.converter.Converter; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; @@ -32,35 +34,61 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link OpenSaml4LogoutRequestResolver} */ public class OpenSaml4LogoutRequestResolverTests { - RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = mock(RelyingPartyRegistrationResolver.class); + RelyingPartyRegistration registration; + + RelyingPartyRegistrationResolver registrationResolver; + + OpenSaml4LogoutRequestResolver logoutRequestResolver; + + @BeforeEach + public void setup() { + this.registration = TestRelyingPartyRegistrations.full().build(); + this.registrationResolver = mock(RelyingPartyRegistrationResolver.class); + this.logoutRequestResolver = new OpenSaml4LogoutRequestResolver(this.registrationResolver); + } @Test public void resolveWhenCustomParametersConsumerThenUses() { - OpenSaml4LogoutRequestResolver logoutRequestResolver = new OpenSaml4LogoutRequestResolver( - this.relyingPartyRegistrationResolver); - logoutRequestResolver.setParametersConsumer((parameters) -> parameters.getLogoutRequest().setID("myid")); - HttpServletRequest request = new MockHttpServletRequest(); - RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration() - .assertingPartyDetails((party) -> party.singleLogoutServiceLocation("https://ap.example.com/logout")) - .build(); - Authentication authentication = new TestingAuthenticationToken("user", "password"); - given(this.relyingPartyRegistrationResolver.resolve(any(), any())).willReturn(registration); - Saml2LogoutRequest logoutRequest = logoutRequestResolver.resolve(request, authentication); + this.logoutRequestResolver.setParametersConsumer((parameters) -> parameters.getLogoutRequest().setID("myid")); + given(this.registrationResolver.resolve(any(), any())).willReturn(this.registration); + + Saml2LogoutRequest logoutRequest = this.logoutRequestResolver.resolve(givenRequest(), givenAuthentication()); + assertThat(logoutRequest.getId()).isEqualTo("myid"); } @Test public void setParametersConsumerWhenNullThenIllegalArgument() { - OpenSaml4LogoutRequestResolver logoutRequestResolver = new OpenSaml4LogoutRequestResolver( - this.relyingPartyRegistrationResolver); assertThatExceptionOfType(IllegalArgumentException.class) - .isThrownBy(() -> logoutRequestResolver.setParametersConsumer(null)); + .isThrownBy(() -> this.logoutRequestResolver.setParametersConsumer(null)); + } + + @Test + public void resolveWhenCustomRelayStateThenUses() { + given(this.registrationResolver.resolve(any(), any())).willReturn(this.registration); + Converter relayState = mock(Converter.class); + given(relayState.convert(any())).willReturn("any-state"); + this.logoutRequestResolver.setRelayStateResolver(relayState); + + Saml2LogoutRequest logoutRequest = this.logoutRequestResolver.resolve(givenRequest(), givenAuthentication()); + + assertThat(logoutRequest.getRelayState()).isEqualTo("any-state"); + verify(relayState).convert(any()); + } + + private static Authentication givenAuthentication() { + return new TestingAuthenticationToken("user", "password"); + } + + private MockHttpServletRequest givenRequest() { + return new MockHttpServletRequest(); } }