Skip to content

Add DefaultRelyingPartyRegistrationResolver #8899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,9 @@
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
Expand Down Expand Up @@ -317,15 +320,16 @@ private <C> void setSharedObject(B http, Class<C> clazz, C object) {

private final class AuthenticationRequestEndpointConfig {
private String filterProcessingUrl = "/saml2/authenticate/{registrationId}";

private AuthenticationRequestEndpointConfig() {
}

private Filter build(B http) {
Saml2AuthenticationRequestFactory authenticationRequestResolver = getResolver(http);
Saml2AuthenticationRequestContextResolver contextResolver = getContextResolver(http);

return postProcess(new Saml2WebSsoAuthenticationRequestFilter(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository,
authenticationRequestResolver));
contextResolver, authenticationRequestResolver));
}

private Saml2AuthenticationRequestFactory getResolver(B http) {
Expand All @@ -335,6 +339,16 @@ private Saml2AuthenticationRequestFactory getResolver(B http) {
}
return resolver;
}

private Saml2AuthenticationRequestContextResolver getContextResolver(B http) {
Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http, Saml2AuthenticationRequestContextResolver.class);
if (resolver == null) {
return new DefaultSaml2AuthenticationRequestContextResolver(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to have a single return statement.
new DefaultSaml2AuthenticationRequestContextResolver can be moved into getDefaultContextResolver method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm of a different opinion on this one. I think returning early is more readable.

However, since that's probably in the eye of the beholder, what I'd propose for now is aligning with how other similar methods in the class, like getResolver. We can consider changing both methods at another time.

new DefaultRelyingPartyRegistrationResolver(
Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
}
return resolver;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestContext;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationToken;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.context.HttpRequestResponseHolder;
Expand All @@ -87,6 +85,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.springframework.security.config.Customizer.withDefaults;
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
Expand Down Expand Up @@ -161,11 +160,11 @@ public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() t
Saml2AuthenticationRequestContext context = authenticationRequestContext().build();
Saml2AuthenticationRequestContextResolver resolver =
CustomAuthenticationRequestContextResolver.resolver;
when(resolver.resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class)))
when(resolver.resolve(any(HttpServletRequest.class)))
.thenReturn(context);
this.mvc.perform(get("/saml2/authenticate/registration-id"))
.andExpect(status().isFound());
verify(resolver).resolve(any(HttpServletRequest.class), any(RelyingPartyRegistration.class));
verify(resolver).resolve(any(HttpServletRequest.class));
}

@Test
Expand Down Expand Up @@ -276,22 +275,11 @@ static class CustomAuthenticationRequestContextResolver extends WebSecurityConfi

@Override
protected void configure(HttpSecurity http) throws Exception {
ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter> processor
= new ObjectPostProcessor<Saml2WebSsoAuthenticationRequestFilter>() {
@Override
public <O extends Saml2WebSsoAuthenticationRequestFilter> O postProcess(O filter) {
filter.setAuthenticationRequestContextResolver(resolver);
return filter;
}
};

http
.authorizeRequests(authz -> authz
.anyRequest().authenticated()
)
.saml2Login(saml2 -> saml2
.addObjectPostProcessor(processor)
);
.saml2Login(withDefaults());
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
Expand Down Expand Up @@ -69,9 +70,8 @@
*/
public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter {

private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private final Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver;
private Saml2AuthenticationRequestFactory authenticationRequestFactory;
private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver();

private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/authenticate/{registrationId}");

Expand All @@ -83,21 +83,24 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
*/
@Deprecated
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
this(relyingPartyRegistrationRepository,
this(new DefaultSaml2AuthenticationRequestContextResolver(
new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)),
new org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationRequestFactory());
}

/**
* Construct a {@link Saml2WebSsoAuthenticationRequestFilter} with the provided parameters
*
* @param relyingPartyRegistrationRepository a repository for relying party configurations
* @param authenticationRequestContextResolver a strategy for formulating a {@link Saml2AuthenticationRequestContext}
* @since 5.4
*/
public Saml2WebSsoAuthenticationRequestFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
public Saml2WebSsoAuthenticationRequestFilter(
Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver,
Saml2AuthenticationRequestFactory authenticationRequestFactory) {
Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");

Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
Assert.notNull(authenticationRequestFactory, "authenticationRequestFactory cannot be null");
this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
this.authenticationRequestFactory = authenticationRequestFactory;
}

Expand All @@ -123,17 +126,6 @@ public void setRedirectMatcher(RequestMatcher redirectMatcher) {
this.redirectMatcher = redirectMatcher;
}

/**
* Use the given {@link Saml2AuthenticationRequestContextResolver} that creates a {@link Saml2AuthenticationRequestContext}
*
* @param authenticationRequestContextResolver the {@link Saml2AuthenticationRequestContextResolver} to use
* @since 5.4
*/
public void setAuthenticationRequestContextResolver(Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver) {
Assert.notNull(authenticationRequestContextResolver, "authenticationRequestContextResolver cannot be null");
this.authenticationRequestContextResolver = authenticationRequestContextResolver;
}

/**
* {@inheritDoc}
*/
Expand All @@ -147,14 +139,12 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
return;
}

String registrationId = matcher.getVariables().get("registrationId");
RelyingPartyRegistration relyingParty =
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
if (relyingParty == null) {
Saml2AuthenticationRequestContext context = this.authenticationRequestContextResolver.resolve(request);
if (context == null) {
response.sendError(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
Saml2AuthenticationRequestContext context = authenticationRequestContextResolver.resolve(request, relyingParty);
RelyingPartyRegistration relyingParty = context.getRelyingPartyRegistration();
if (relyingParty.getAssertingPartyDetails().getSingleSignOnServiceBinding() == Saml2MessageBinding.REDIRECT) {
sendRedirect(response, context);
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.saml2.provider.service.web;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import javax.servlet.http.HttpServletRequest;

import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;

/**
* A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the
* registration id from the request, querying a {@link RelyingPartyRegistrationRepository},
* and resolving any template values.
*
* @since 5.4
* @author Josh Cummings
*/
public final class DefaultRelyingPartyRegistrationResolver
implements Converter<HttpServletRequest, RelyingPartyRegistration> {

private static final char PATH_DELIMITER = '/';

private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
private final Converter<HttpServletRequest, String> registrationIdResolver = new RegistrationIdResolver();

public DefaultRelyingPartyRegistrationResolver
(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {

this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
}

@Override
public RelyingPartyRegistration convert(HttpServletRequest request) {
String registrationId = this.registrationIdResolver.convert(request);
if (registrationId == null) {
return null;
}
RelyingPartyRegistration relyingPartyRegistration =
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
if (relyingPartyRegistration == null) {
return null;
}

String applicationUri = getApplicationUri(request);
Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration);
String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId());
String assertionConsumerServiceLocation = templateResolver.apply(
relyingPartyRegistration.getAssertionConsumerServiceLocation());
return withRelyingPartyRegistration(relyingPartyRegistration)
.entityId(relyingPartyEntityId)
.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
.build();
}

private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
}

private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
String registrationId = relyingParty.getRegistrationId();
Map<String, String> uriVariables = new HashMap<>();
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
.replaceQuery(null)
.fragment(null)
.build();
String scheme = uriComponents.getScheme();
uriVariables.put("baseScheme", scheme == null ? "" : scheme);
String host = uriComponents.getHost();
uriVariables.put("baseHost", host == null ? "" : host);
// following logic is based on HierarchicalUriComponents#toUriString()
int port = uriComponents.getPort();
uriVariables.put("basePort", port == -1 ? "" : ":" + port);
String path = uriComponents.getPath();
if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) {
path = PATH_DELIMITER + path;
}
uriVariables.put("basePath", path == null ? "" : path);
uriVariables.put("baseUrl", uriComponents.toUriString());
uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");

return UriComponentsBuilder.fromUriString(template)
.buildAndExpand(uriVariables)
.toUriString();
}

private static String getApplicationUri(HttpServletRequest request) {
UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
.replacePath(request.getContextPath())
.replaceQuery(null)
.fragment(null)
.build();
return uriComponents.toUriString();
}

private static class RegistrationIdResolver implements Converter<HttpServletRequest, String> {
private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}");

@Override
public String convert(HttpServletRequest request) {
RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
return result.getVariables().get("registrationId");
}
}
}
Loading