Skip to content

Commit af5c55c

Browse files
committed
Polish AuthnRequest Customization Support
Having the application generate the AuthnRequest fresh allows Spring Security to back away more gracefully. Using a Consumer implies that the application will need to undo any values that Spring Security set that the application doesn't want. Also, if this does become a configuration burden, it can be simplified in a separate ticket by exposing the default Converter. Issue gh-8776
1 parent 3694485 commit af5c55c

File tree

4 files changed

+50
-36
lines changed

4 files changed

+50
-36
lines changed

config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.junit.Rule;
3636
import org.junit.Test;
3737
import org.opensaml.saml.saml2.core.Assertion;
38+
import org.opensaml.saml.saml2.core.AuthnRequest;
3839

3940
import org.springframework.beans.factory.annotation.Autowired;
4041
import org.springframework.context.ConfigurableApplicationContext;
@@ -89,6 +90,7 @@
8990
import static org.mockito.Mockito.when;
9091
import static org.springframework.security.config.Customizer.withDefaults;
9192
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
93+
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
9294
import static org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationRequestContexts.authenticationRequestContext;
9395
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
9496
import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
@@ -176,8 +178,8 @@ public void saml2LoginWhenCustomAuthenticationRequestContextResolverThenUses() t
176178
}
177179

178180
@Test
179-
public void authenticationRequestWhenAuthnRequestConsumerResolverThenUses() throws Exception {
180-
this.spring.register(CustomAuthnRequestConsumerResolver.class).autowire();
181+
public void authenticationRequestWhenAuthnRequestContextConverterThenUses() throws Exception {
182+
this.spring.register(CustomAuthenticationRequestContextConverterResolver.class).autowire();
181183

182184
MvcResult result = this.mvc.perform(get("/saml2/authenticate/registration-id"))
183185
.andReturn();
@@ -315,7 +317,7 @@ Saml2AuthenticationRequestContextResolver resolver() {
315317

316318
@EnableWebSecurity
317319
@Import(Saml2LoginConfigBeans.class)
318-
static class CustomAuthnRequestConsumerResolver extends WebSecurityConfigurerAdapter {
320+
static class CustomAuthenticationRequestContextConverterResolver extends WebSecurityConfigurerAdapter {
319321

320322
@Override
321323
protected void configure(HttpSecurity http) throws Exception {
@@ -330,8 +332,12 @@ protected void configure(HttpSecurity http) throws Exception {
330332
Saml2AuthenticationRequestFactory authenticationRequestFactory() {
331333
OpenSamlAuthenticationRequestFactory authenticationRequestFactory =
332334
new OpenSamlAuthenticationRequestFactory();
333-
authenticationRequestFactory.setAuthnRequestConsumerResolver(
334-
context -> authnRequest -> authnRequest.setForceAuthn(true));
335+
authenticationRequestFactory.setAuthenticationRequestContextConverter(
336+
context -> {
337+
AuthnRequest authnRequest = authnRequest();
338+
authnRequest.setForceAuthn(true);
339+
return authnRequest;
340+
});
335341
return authenticationRequestFactory;
336342
}
337343
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
import java.util.LinkedHashMap;
2626
import java.util.Map;
2727
import java.util.UUID;
28-
import java.util.function.Consumer;
29-
import java.util.function.Function;
3028

3129
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
3230
import org.joda.time.DateTime;
@@ -88,8 +86,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
8886
return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
8987
};
9088

91-
private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
92-
= context -> authnRequest -> {};
89+
private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter
90+
= this::createAuthnRequest;
9391

9492
/**
9593
* Creates an {@link OpenSamlAuthenticationRequestFactory}
@@ -124,7 +122,7 @@ public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
124122
*/
125123
@Override
126124
public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
127-
AuthnRequest authnRequest = createAuthnRequest(context);
125+
AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
128126
String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
129127
serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
130128
serialize(authnRequest);
@@ -139,7 +137,7 @@ public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2Authe
139137
*/
140138
@Override
141139
public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
142-
AuthnRequest authnRequest = createAuthnRequest(context);
140+
AuthnRequest authnRequest = this.authenticationRequestContextConverter.convert(context);
143141
String xml = serialize(authnRequest);
144142
Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
145143
String deflatedAndEncoded = samlEncode(samlDeflate(xml));
@@ -168,11 +166,9 @@ public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Sa
168166
}
169167

170168
private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
171-
AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
169+
return createAuthnRequest(context.getIssuer(),
172170
context.getDestination(), context.getAssertionConsumerServiceUrl(),
173171
this.protocolBindingResolver.convert(context));
174-
this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
175-
return authnRequest;
176172
}
177173

178174
private AuthnRequest createAuthnRequest
@@ -194,13 +190,13 @@ private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext contex
194190
/**
195191
* Set the {@link AuthnRequest} post-processor resolver
196192
*
197-
* @param authnRequestConsumerResolver
193+
* @param authenticationRequestContextConverter
198194
* @since 5.4
199195
*/
200-
public void setAuthnRequestConsumerResolver(
201-
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver) {
202-
Assert.notNull(authnRequestConsumerResolver, "authnRequestConsumerResolver cannot be null");
203-
this.authnRequestConsumerResolver = authnRequestConsumerResolver;
196+
public void setAuthenticationRequestContextConverter(
197+
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter) {
198+
Assert.notNull(authenticationRequestContextConverter, "authenticationRequestContextConverter cannot be null");
199+
this.authenticationRequestContextConverter = authenticationRequestContextConverter;
204200
}
205201

206202
/**

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
package org.springframework.security.saml2.provider.service.authentication;
1818

1919
import java.io.ByteArrayInputStream;
20-
import java.util.function.Consumer;
21-
import java.util.function.Function;
2220

2321
import org.junit.Assert;
2422
import org.junit.Before;
@@ -31,6 +29,7 @@
3129
import org.w3c.dom.Document;
3230
import org.w3c.dom.Element;
3331

32+
import org.springframework.core.convert.converter.Converter;
3433
import org.springframework.security.saml2.Saml2Exception;
3534
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
3635
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
@@ -47,6 +46,7 @@
4746
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
4847
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
4948
import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
49+
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.authnRequest;
5050
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
5151
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
5252
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@@ -63,8 +63,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
6363
private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
6464
private RelyingPartyRegistration relyingPartyRegistration;
6565

66-
private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory()
67-
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
66+
private AuthnRequestUnmarshaller unmarshaller;
6867

6968
@Rule
7069
public ExpectedException exception = ExpectedException.none();
@@ -84,6 +83,8 @@ public void setUp() {
8483
.assertionConsumerServiceUrl("https://issuer/sso");
8584
context = contextBuilder.build();
8685
factory = new OpenSamlAuthenticationRequestFactory();
86+
this.unmarshaller =(AuthnRequestUnmarshaller) getUnmarshallerFactory()
87+
.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
8788
}
8889

8990
@Test
@@ -182,29 +183,29 @@ public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalAr
182183

183184
@Test
184185
public void createPostAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
185-
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
186-
mock(Function.class);
187-
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
188-
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
186+
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
187+
mock(Converter.class);
188+
when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
189+
this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
189190

190191
this.factory.createPostAuthenticationRequest(this.context);
191-
verify(authnRequestConsumerResolver).apply(this.context);
192+
verify(authenticationRequestContextConverter).convert(this.context);
192193
}
193194

194195
@Test
195196
public void createRedirectAuthenticationRequestWhenAuthnRequestConsumerThenUses() {
196-
Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver =
197-
mock(Function.class);
198-
when(authnRequestConsumerResolver.apply(this.context)).thenReturn(authnRequest -> {});
199-
this.factory.setAuthnRequestConsumerResolver(authnRequestConsumerResolver);
197+
Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter =
198+
mock(Converter.class);
199+
when(authenticationRequestContextConverter.convert(this.context)).thenReturn(authnRequest());
200+
this.factory.setAuthenticationRequestContextConverter(authenticationRequestContextConverter);
200201

201202
this.factory.createRedirectAuthenticationRequest(this.context);
202-
verify(authnRequestConsumerResolver).apply(this.context);
203+
verify(authenticationRequestContextConverter).convert(this.context);
203204
}
204205

205206
@Test
206-
public void setAuthnRequestConsumerResolverWhenNullThenException() {
207-
assertThatCode(() -> this.factory.setAuthnRequestConsumerResolver(null))
207+
public void setAuthenticationRequestContextConverterWhenNullThenException() {
208+
assertThatCode(() -> this.factory.setAuthenticationRequestContextConverter(null))
208209
.isInstanceOf(IllegalArgumentException.class);
209210
}
210211

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import org.opensaml.saml.saml2.core.Attribute;
5454
import org.opensaml.saml.saml2.core.AttributeStatement;
5555
import org.opensaml.saml.saml2.core.AttributeValue;
56+
import org.opensaml.saml.saml2.core.AuthnRequest;
5657
import org.opensaml.saml.saml2.core.Conditions;
5758
import org.opensaml.saml.saml2.core.EncryptedAssertion;
5859
import org.opensaml.saml.saml2.core.EncryptedID;
@@ -86,7 +87,7 @@
8687
import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
8788
import static org.springframework.security.saml2.core.TestSaml2X509Credentials.assertingPartySigningCredential;
8889

89-
final class TestOpenSamlObjects {
90+
public final class TestOpenSamlObjects {
9091
static {
9192
OpenSamlInitializationService.initialize();
9293
}
@@ -188,6 +189,16 @@ static Conditions conditions() {
188189
return conditions;
189190
}
190191

192+
public static AuthnRequest authnRequest() {
193+
Issuer issuer = build(Issuer.DEFAULT_ELEMENT_NAME);
194+
issuer.setValue(ASSERTING_PARTY_ENTITY_ID);
195+
AuthnRequest authnRequest = build(AuthnRequest.DEFAULT_ELEMENT_NAME);
196+
authnRequest.setIssuer(issuer);
197+
authnRequest.setDestination(ASSERTING_PARTY_ENTITY_ID + "/SSO.saml2");
198+
authnRequest.setAssertionConsumerServiceURL(DESTINATION);
199+
return authnRequest;
200+
}
201+
191202
static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
192203
BasicCredential cred = getBasicCredential(credential);
193204
cred.setEntityId(entityId);

0 commit comments

Comments
 (0)