diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java index 73f33203c18..321a492f0aa 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,9 @@ import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; +import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; +import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -317,15 +320,16 @@ private void setSharedObject(B http, Class clazz, C object) { private final class AuthenticationRequestEndpointConfig { private String filterProcessingUrl = "/saml2/authenticate/{registrationId}"; + private AuthenticationRequestEndpointConfig() { } private Filter build(B http) { Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http); + Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http); return postProcess(new Saml2WebSsoAuthenticationRequestFilter( - Saml2LoginConfigurer.this.relyingPartyRegistrationRepository, - authenticationRequestResolver)); + contextResolver, authenticationRequestResolver)); } private Saml2AuthenticationRequestFactory getResolver(B http) { @@ -335,6 +339,16 @@ private Saml2AuthenticationRequestFactory getResolver(B http) { } return resolver; } + + private Saml2AuthenticationRequestContextResolver getContextResolver(B http) { + Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class); + if (resolver == null) { + return new DefaultSaml2AuthenticationRequestContextResolver( + new DefaultRelyingPartyRegistrationResolver( + Saml2LoginConfigurer.this.relyingPartyRegistrationRepository)); + } + return resolver; + } } } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java index 440a10f9331..21845bcec3d 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java @@ -65,10 +65,8 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter; -import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.context.HttpRequestResponseHolder; @@ -87,6 +85,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.springframework.security.config.Customizer.withDefaults; import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; @@ -161,11 +160,11 @@ public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() t Saml2AuthenticationRequestContext context = authenticationRequestContext().build(); Saml2AuthenticationRequestContextResolver resolver = CustomAuthenticationRequestContextResolver.resolver; - when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class))) + when(resolver.resolve(any(HttpServletRequest.class))) .thenReturn(context); this.mvc.perform(get("/saml2/authenticate/registration-id")) .andExpect(status().isFound()); - verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)); + verify(resolver).resolve(any(HttpServletRequest.class)); } @Test @@ -276,22 +275,11 @@ static class CustomAuthenticationRequestContextResolver extends WebSecurityConfi @Override protected void configure(HttpSecurity http) throws Exception { - ObjectPostProcessor processor - = new ObjectPostProcessor() { - @Override - public O postProcess(O filter) { - filter.setAuthenticationRequestContextResolver(resolver); - return filter; - } - }; - http .authorizeRequests(authz -> authz .anyRequest().authenticated() ) - .saml2Login(saml2 -> saml2 - .addObjectPostProcessor(processor) - ); + .saml2Login(withDefaults()); } @Bean diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java index f0270de2139..3fa3e9522ca 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java @@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver; import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; @@ -69,9 +70,8 @@ */ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter { - private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver; private Saml2AuthenticationRequestFactory authenticationRequestFactory; - private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(); private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}"); @@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter */ @Deprecated public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { - this(relyingPartyRegistrationRepository, + this(new DefaultSaml2AuthenticationRequestContextResolver( + new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory()); } /** * Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters * - * @param relyingPartyRegistrationRepository a repository for relying party configurations + * @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext} * @since 5.4 */ - public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + public Saml2WebSsoAuthenticationRequestFilter( + Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver, Saml2AuthenticationRequestFactory authenticationRequestFactory) { - Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); + + Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null"); Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null"); - this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + this.authenticationRequestContextResolver = authenticationRequestContextResolver; this.authenticationRequestFactory = authenticationRequestFactory; } @@ -123,17 +126,6 @@ public void setRedirectMatcher(RequestMatcher redirectMatcher) { this.redirectMatcher = redirectMatcher; } - /** - * Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext} - * - * @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use - * @since 5.4 - */ - public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) { - Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null"); - this.authenticationRequestContextResolver = authenticationRequestContextResolver; - } - /** * {@inheritDoc} */ @@ -147,14 +139,12 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse return; } - String registrationId = matcher.getVariables().get("registrationId"); - RelyingPartyRegistration relyingParty = - this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); - if (relyingParty == null) { + Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request); + if (context == null) { response.sendError(HttpServletResponse.SC_UNAUTHORIZED); return; } - Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty); + RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration(); if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) { sendRedirect(response, context); } else { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java new file mode 100644 index 00000000000..ebb27f1d268 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; +import javax.servlet.http.HttpServletRequest; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration; +import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; +import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; + +/** + * A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the + * registration id from the request, querying a {@link RelyingPartyRegistrationRepository}, + * and resolving any template values. + * + * @since 5.4 + * @author Josh Cummings + */ +public final class DefaultRelyingPartyRegistrationResolver + implements Converter { + + private static final char PATH_DELIMITER = '/'; + + private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private final Converter registrationIdResolver = new RegistrationIdResolver(); + + public DefaultRelyingPartyRegistrationResolver + (RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) { + + this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; + } + + @Override + public RelyingPartyRegistration convert(HttpServletRequest request) { + String registrationId = this.registrationIdResolver.convert(request); + if (registrationId == null) { + return null; + } + RelyingPartyRegistration relyingPartyRegistration = + this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId); + if (relyingPartyRegistration == null) { + return null; + } + + String applicationUri = getApplicationUri(request); + Function templateResolver = templateResolver(applicationUri, relyingPartyRegistration); + String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId()); + String assertionConsumerServiceLocation = templateResolver.apply( + relyingPartyRegistration.getAssertionConsumerServiceLocation()); + return withRelyingPartyRegistration(relyingPartyRegistration) + .entityId(relyingPartyEntityId) + .assertionConsumerServiceLocation(assertionConsumerServiceLocation) + .build(); + } + + private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { + return template -> resolveUrlTemplate(template, applicationUri, relyingParty); + } + + private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { + String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); + String registrationId = relyingParty.getRegistrationId(); + Map uriVariables = new HashMap<>(); + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) + .replaceQuery(null) + .fragment(null) + .build(); + String scheme = uriComponents.getScheme(); + uriVariables.put("baseScheme", scheme == null ? "" : scheme); + String host = uriComponents.getHost(); + uriVariables.put("baseHost", host == null ? "" : host); + // following logic is based on HierarchicalUriComponents#toUriString() + int port = uriComponents.getPort(); + uriVariables.put("basePort", port == -1 ? "" : ":" + port); + String path = uriComponents.getPath(); + if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) { + path = PATH_DELIMITER + path; + } + uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("baseUrl", uriComponents.toUriString()); + uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); + uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); + + return UriComponentsBuilder.fromUriString(template) + .buildAndExpand(uriVariables) + .toUriString(); + } + + private static String getApplicationUri(HttpServletRequest request) { + UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)) + .replacePath(request.getContextPath()) + .replaceQuery(null) + .fragment(null) + .build(); + return uriComponents.toUriString(); + } + + private static class RegistrationIdResolver implements Converter { + private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}"); + + @Override + public String convert(HttpServletRequest request) { + RequestMatcher.MatchResult result = this.requestMatcher.matcher(request); + return result.getVariables().get("registrationId"); + } + } +} diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java index 7910b74bb92..b9d15b7860e 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java @@ -16,45 +16,45 @@ package org.springframework.security.saml2.provider.service.web; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; import javax.servlet.http.HttpServletRequest; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; -import org.springframework.web.util.UriComponents; -import org.springframework.web.util.UriComponentsBuilder; - -import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl; -import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl; /** * The default implementation for {@link Saml2AuthenticationRequestContextResolver} * which uses the current request and given relying party to formulate a {@link Saml2AuthenticationRequestContext} * * @author Shazin Sadakath + * @author Josh Cummings * @since 5.4 */ public final class DefaultSaml2AuthenticationRequestContextResolver implements Saml2AuthenticationRequestContextResolver { private final Log logger = LogFactory.getLog(getClass()); - private static final char PATH_DELIMITER = '/'; + private final Converter relyingPartyRegistrationResolver; + + public DefaultSaml2AuthenticationRequestContextResolver + (Converter relyingPartyRegistrationResolver) { + this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver; + } /** * {@inheritDoc} */ @Override - public Saml2AuthenticationRequestContext resolve(HttpServletRequest request, - RelyingPartyRegistration relyingParty) { + public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); - Assert.notNull(relyingParty, "relyingParty cannot be null"); + RelyingPartyRegistration relyingParty = this.relyingPartyRegistrationResolver.convert(request); + if (relyingParty == null) { + return null; + } if (this.logger.isDebugEnabled()) { this.logger.debug("Creating SAML 2.0 Authentication Request for Asserting Party [" + relyingParty.getRegistrationId() + "]"); @@ -65,59 +65,11 @@ public Saml2AuthenticationRequestContext resolve(HttpServletRequest request, private Saml2AuthenticationRequestContext createRedirectAuthenticationRequestContext( HttpServletRequest request, RelyingPartyRegistration relyingParty) { - String applicationUri = getApplicationUri(request); - Function resolver = templateResolver(applicationUri, relyingParty); - String localSpEntityId = resolver.apply(relyingParty.getEntityId()); - String assertionConsumerServiceUrl = resolver.apply(relyingParty.getAssertionConsumerServiceLocation()); return Saml2AuthenticationRequestContext.builder() - .issuer(localSpEntityId) + .issuer(relyingParty.getEntityId()) .relyingPartyRegistration(relyingParty) - .assertionConsumerServiceUrl(assertionConsumerServiceUrl) + .assertionConsumerServiceUrl(relyingParty.getAssertionConsumerServiceLocation()) .relayState(request.getParameter("RelayState")) .build(); } - - private Function templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) { - return template -> resolveUrlTemplate(template, applicationUri, relyingParty); - } - - private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) { - String entityId = relyingParty.getAssertingPartyDetails().getEntityId(); - String registrationId = relyingParty.getRegistrationId(); - Map uriVariables = new HashMap<>(); - UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl) - .replaceQuery(null) - .fragment(null) - .build(); - String scheme = uriComponents.getScheme(); - uriVariables.put("baseScheme", scheme == null ? "" : scheme); - String host = uriComponents.getHost(); - uriVariables.put("baseHost", host == null ? "" : host); - // following logic is based on HierarchicalUriComponents#toUriString() - int port = uriComponents.getPort(); - uriVariables.put("basePort", port == -1 ? "" : ":" + port); - String path = uriComponents.getPath(); - if (StringUtils.hasLength(path)) { - if (path.charAt(0) != PATH_DELIMITER) { - path = PATH_DELIMITER + path; - } - } - uriVariables.put("basePath", path == null ? "" : path); - uriVariables.put("baseUrl", uriComponents.toUriString()); - uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : ""); - uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : ""); - - return UriComponentsBuilder.fromUriString(template) - .buildAndExpand(uriVariables) - .toUriString(); - } - - private static String getApplicationUri(HttpServletRequest request) { - UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request)) - .replacePath(request.getContextPath()) - .replaceQuery(null) - .fragment(null) - .build(); - return uriComponents.toUriString(); - } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java index 1c86ec239e4..db24c8ff903 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationRequestContextResolver.java @@ -16,16 +16,16 @@ package org.springframework.security.saml2.provider.service.web; -import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; - import javax.servlet.http.HttpServletRequest; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext; + /** * This {@code Saml2AuthenticationRequestContextResolver} formulates a * SAML 2.0 AuthnRequest (line 1968) * * @author Shazin Sadakath + * @author Josh Cummings * @since 5.4 */ public interface Saml2AuthenticationRequestContextResolver { @@ -35,9 +35,7 @@ public interface Saml2AuthenticationRequestContextResolver { * * * @param request the current request - * @param relyingParty the relying party responsible for saml2 sso authentication - * @return the created {@link Saml2AuthenticationRequestContext} for request/relying party combination + * @return the created {@link Saml2AuthenticationRequestContext} for the request */ - Saml2AuthenticationRequestContext resolve(HttpServletRequest request, - RelyingPartyRegistration relyingParty); + Saml2AuthenticationRequestContext resolve(HttpServletRequest request); } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index f9613e080d8..1ea4d636c9a 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -30,6 +30,7 @@ import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver; import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriUtils; @@ -41,6 +42,7 @@ import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential; +import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; public class Saml2WebSsoAuthenticationRequestFilterTests { @@ -49,6 +51,8 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { private Saml2WebSsoAuthenticationRequestFilter filter; private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class); + private Saml2AuthenticationRequestContextResolver resolver = + mock(Saml2AuthenticationRequestContextResolver.class); private MockHttpServletRequest request; private MockHttpServletResponse response; private MockFilterChain filterChain; @@ -188,12 +192,14 @@ public void doFilterWhenCustomAuthenticationRequestFactoryThenUses() throws Exce when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri"); when(authenticationRequest.getRelayState()).thenReturn("relay"); when(authenticationRequest.getSamlRequest()).thenReturn("saml"); - when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty); + when(this.resolver.resolve(this.request)).thenReturn(authenticationRequestContext() + .relyingPartyRegistration(relyingParty) + .build()); when(this.factory.createPostAuthenticationRequest(any())) .thenReturn(authenticationRequest); Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter - (this.repository, this.factory); + (this.resolver, this.factory); filter.doFilterInternal(this.request, this.response, this.filterChain); assertThat(this.response.getContentAsString()) .contains("
") diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java new file mode 100644 index 00000000000..6b282638d56 --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.web; + +import org.junit.Test; + +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration; + +/** + * Tests for {@link DefaultRelyingPartyRegistrationResolver} + */ +public class DefaultRelyingPartyRegistrationResolverTests { + private final RelyingPartyRegistration registration = relyingPartyRegistration().build(); + private final RelyingPartyRegistrationRepository repository = + new InMemoryRelyingPartyRegistrationRepository(this.registration); + private final DefaultRelyingPartyRegistrationResolver resolver = + new DefaultRelyingPartyRegistrationResolver(this.repository); + + @Test + public void resolveWhenRequestContainsRegistrationIdThenResolves() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/some/path/" + this.registration.getRegistrationId()); + RelyingPartyRegistration registration = this.resolver.convert(request); + assertThat(registration).isNotNull(); + assertThat(registration.getRegistrationId()) + .isEqualTo(this.registration.getRegistrationId()); + assertThat(registration.getEntityId()) + .isEqualTo("http://localhost/saml2/service-provider-metadata/" + this.registration.getRegistrationId()); + assertThat(registration.getAssertionConsumerServiceLocation()) + .isEqualTo("http://localhost/login/saml2/sso/" + this.registration.getRegistrationId()); + } + + @Test + public void resolveWhenRequestContainsInvalidRegistrationIdThenNull() { + MockHttpServletRequest request = new MockHttpServletRequest(); + request.setPathInfo("/some/path/not-" + this.registration.getRegistrationId()); + RelyingPartyRegistration registration = this.resolver.convert(request); + assertThat(registration).isNull(); + } + + @Test + public void resolveWhenRequestIsMissingRegistrationIdThenNull() { + MockHttpServletRequest request = new MockHttpServletRequest(); + RelyingPartyRegistration registration = this.resolver.convert(request); + assertThat(registration).isNull(); + } +} diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java index 182b7009653..80f2cd6afc9 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java @@ -44,11 +44,13 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests { private MockHttpServletRequest request; private RelyingPartyRegistration.Builder relyingPartyBuilder; private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver - = new DefaultSaml2AuthenticationRequestContextResolver(); + = new DefaultSaml2AuthenticationRequestContextResolver( + new DefaultRelyingPartyRegistrationResolver(id -> relyingPartyBuilder.build())); @Before public void setup() { this.request = new MockHttpServletRequest(); + this.request.setPathInfo("/saml2/authenticate/registration-id"); this.relyingPartyBuilder = RelyingPartyRegistration .withRegistrationId(REGISTRATION_ID) .localEntityIdTemplate(RELYING_PARTY_ENTITY_ID) @@ -61,52 +63,43 @@ public void setup() { @Test public void resolveWhenRequestAndRelyingPartyNotNullThenCreateSaml2AuthenticationRequestContext() { this.request.addParameter("RelayState", "relay-state"); - RelyingPartyRegistration relyingParty = this.relyingPartyBuilder.build(); Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request, relyingParty); + this.authenticationRequestContextResolver.resolve(this.request); assertThat(context).isNotNull(); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo(RELYING_PARTY_SSO_URL); assertThat(context.getRelayState()).isEqualTo("relay-state"); assertThat(context.getDestination()).isEqualTo(ASSERTING_PARTY_SSO_URL); assertThat(context.getIssuer()).isEqualTo(RELYING_PARTY_ENTITY_ID); - assertThat(context.getRelyingPartyRegistration()).isSameAs(relyingParty); + assertThat(context.getRelyingPartyRegistration().getRegistrationId()) + .isSameAs(this.relyingPartyBuilder.build().getRegistrationId()); } @Test public void resolveWhenAssertionConsumerServiceUrlTemplateContainsRegistrationIdThenResolves() { - RelyingPartyRegistration relyingParty = this.relyingPartyBuilder - .assertionConsumerServiceUrlTemplate("/saml2/authenticate/{registrationId}") - .build(); + this.relyingPartyBuilder + .assertionConsumerServiceLocation("/saml2/authenticate/{registrationId}"); Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request, relyingParty); + this.authenticationRequestContextResolver.resolve(this.request); assertThat(context.getAssertionConsumerServiceUrl()).isEqualTo("/saml2/authenticate/registration-id"); } @Test public void resolveWhenAssertionConsumerServiceUrlTemplateContainsBaseUrlThenResolves() { - RelyingPartyRegistration relyingParty = this.relyingPartyBuilder - .assertionConsumerServiceUrlTemplate("{baseUrl}/saml2/authenticate/{registrationId}") - .build(); + this.relyingPartyBuilder + .assertionConsumerServiceLocation("{baseUrl}/saml2/authenticate/{registrationId}"); Saml2AuthenticationRequestContext context = - this.authenticationRequestContextResolver.resolve(this.request, relyingParty); + this.authenticationRequestContextResolver.resolve(this.request); assertThat(context.getAssertionConsumerServiceUrl()) .isEqualTo("http://localhost/saml2/authenticate/registration-id"); } - @Test - public void resolveWhenRequestNullThenException() { - assertThatCode(() -> - this.authenticationRequestContextResolver.resolve(this.request, null)) - .isInstanceOf(IllegalArgumentException.class); - } - @Test public void resolveWhenRelyingPartyNullThenException() { assertThatCode(() -> - this.authenticationRequestContextResolver.resolve(null, this.relyingPartyBuilder.build())) + this.authenticationRequestContextResolver.resolve(null)) .isInstanceOf(IllegalArgumentException.class); } }