1
1
/*
2
- * Copyright 2002-2022 the original author or authors.
2
+ * Copyright 2002-2023 the original author or authors.
3
3
*
4
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
5
* you may not use this file except in compliance with the License.
16
16
17
17
package org .springframework .security .saml2 .provider .service .web ;
18
18
19
+ import java .io .ByteArrayInputStream ;
19
20
import java .io .ByteArrayOutputStream ;
20
21
import java .nio .charset .StandardCharsets ;
21
22
import java .util .Arrays ;
25
26
import java .util .zip .InflaterOutputStream ;
26
27
27
28
import jakarta .servlet .http .HttpServletRequest ;
29
+ import net .shibboleth .utilities .java .support .xml .ParserPool ;
30
+ import org .opensaml .core .config .ConfigurationService ;
31
+ import org .opensaml .core .xml .config .XMLObjectProviderRegistry ;
32
+ import org .opensaml .core .xml .config .XMLObjectProviderRegistrySupport ;
33
+ import org .opensaml .saml .saml2 .core .Response ;
34
+ import org .opensaml .saml .saml2 .core .impl .ResponseUnmarshaller ;
35
+ import org .w3c .dom .Document ;
36
+ import org .w3c .dom .Element ;
28
37
29
38
import org .springframework .http .HttpMethod ;
39
+ import org .springframework .security .saml2 .Saml2Exception ;
40
+ import org .springframework .security .saml2 .core .OpenSamlInitializationService ;
30
41
import org .springframework .security .saml2 .core .Saml2Error ;
31
42
import org .springframework .security .saml2 .core .Saml2ErrorCodes ;
32
43
import org .springframework .security .saml2 .core .Saml2ParameterNames ;
33
44
import org .springframework .security .saml2 .provider .service .authentication .AbstractSaml2AuthenticationRequest ;
34
45
import org .springframework .security .saml2 .provider .service .authentication .Saml2AuthenticationException ;
35
46
import org .springframework .security .saml2 .provider .service .authentication .Saml2AuthenticationToken ;
36
47
import org .springframework .security .saml2 .provider .service .registration .RelyingPartyRegistration ;
48
+ import org .springframework .security .saml2 .provider .service .registration .RelyingPartyRegistrationRepository ;
49
+ import org .springframework .security .saml2 .provider .service .web .RelyingPartyRegistrationPlaceholderResolvers .UriResolver ;
37
50
import org .springframework .security .web .authentication .AuthenticationConverter ;
51
+ import org .springframework .security .web .util .matcher .AntPathRequestMatcher ;
52
+ import org .springframework .security .web .util .matcher .OrRequestMatcher ;
53
+ import org .springframework .security .web .util .matcher .RequestMatcher ;
38
54
import org .springframework .util .Assert ;
39
55
40
56
/**
43
59
* {@link org.springframework.security.authentication.AuthenticationManager}.
44
60
*
45
61
* @author Josh Cummings
46
- * @since 5.4
62
+ * @since 6.1
47
63
*/
48
- public final class Saml2AuthenticationTokenConverter implements AuthenticationConverter {
64
+ public final class OpenSamlAuthenticationTokenConverter implements AuthenticationConverter {
65
+
66
+ static {
67
+ OpenSamlInitializationService .initialize ();
68
+ }
49
69
50
70
// MimeDecoder allows extra line-breaks as well as other non-alphabet values.
51
71
// This matches the behaviour of the commons-codec decoder.
52
72
private static final Base64 .Decoder BASE64 = Base64 .getMimeDecoder ();
53
73
54
74
private static final Base64Checker BASE_64_CHECKER = new Base64Checker ();
55
75
56
- private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver ;
76
+ private final RelyingPartyRegistrationRepository registrations ;
77
+
78
+ private RequestMatcher requestMatcher = new OrRequestMatcher (
79
+ new AntPathRequestMatcher ("/login/saml2/sso/{registrationId}" ),
80
+ new AntPathRequestMatcher ("/login/saml2/sso" ));
81
+
82
+ private final ParserPool parserPool ;
83
+
84
+ private final ResponseUnmarshaller unmarshaller ;
57
85
58
86
private Function <HttpServletRequest , AbstractSaml2AuthenticationRequest > loader ;
59
87
60
88
/**
61
- * Constructs a {@link Saml2AuthenticationTokenConverter } given a strategy for
62
- * resolving {@link RelyingPartyRegistration}s
63
- * @param relyingPartyRegistrationResolver the strategy for resolving
89
+ * Constructs a {@link OpenSamlAuthenticationTokenConverter } given a repository for
90
+ * {@link RelyingPartyRegistration}s
91
+ * @param registrations the repository for {@link RelyingPartyRegistration}s
64
92
* {@link RelyingPartyRegistration}s
65
93
*/
66
- public Saml2AuthenticationTokenConverter (RelyingPartyRegistrationResolver relyingPartyRegistrationResolver ) {
67
- Assert .notNull (relyingPartyRegistrationResolver , "relyingPartyRegistrationResolver cannot be null" );
68
- this .relyingPartyRegistrationResolver = relyingPartyRegistrationResolver ;
94
+ public OpenSamlAuthenticationTokenConverter (RelyingPartyRegistrationRepository registrations ) {
95
+ Assert .notNull (registrations , "relyingPartyRegistrationRepository cannot be null" );
96
+ XMLObjectProviderRegistry registry = ConfigurationService .get (XMLObjectProviderRegistry .class );
97
+ this .parserPool = registry .getParserPool ();
98
+ this .unmarshaller = (ResponseUnmarshaller ) XMLObjectProviderRegistrySupport .getUnmarshallerFactory ()
99
+ .getUnmarshaller (Response .DEFAULT_ELEMENT_NAME );
100
+ this .registrations = registrations ;
69
101
this .loader = new HttpSessionSaml2AuthenticationRequestRepository ()::loadAuthenticationRequest ;
70
102
}
71
103
104
+ /**
105
+ * Resolve an authentication request from the given {@link HttpServletRequest}.
106
+ *
107
+ * <p>
108
+ * First uses the configured {@link RequestMatcher} to deduce whether an
109
+ * authentication request is being made and optionally for which
110
+ * {@code registrationId}.
111
+ *
112
+ * <p>
113
+ * If there is an associated {@code <saml2:AuthnRequest>}, then the
114
+ * {@code registrationId} is looked up and used.
115
+ *
116
+ * <p>
117
+ * If a {@code registrationId} is found in the request, then it is looked up and used.
118
+ * In that case, if none is found a {@link Saml2AuthenticationException} is thrown.
119
+ *
120
+ * <p>
121
+ * Finally, if no {@code registrationId} is found in the request, then the code
122
+ * attempts to resolve the {@link RelyingPartyRegistration} from the SAML Response's
123
+ * Issuer.
124
+ * @param request the HTTP request
125
+ * @return the {@link Saml2AuthenticationToken} authentication request
126
+ * @throws Saml2AuthenticationException if the {@link RequestMatcher} specifies a
127
+ * non-existent {@code registrationId}
128
+ */
72
129
@ Override
73
130
public Saml2AuthenticationToken convert (HttpServletRequest request ) {
131
+ String serialized = request .getParameter (Saml2ParameterNames .SAML_RESPONSE );
132
+ if (serialized == null ) {
133
+ return null ;
134
+ }
135
+ RequestMatcher .MatchResult result = this .requestMatcher .matcher (request );
136
+ if (!result .isMatch ()) {
137
+ return null ;
138
+ }
139
+ Saml2AuthenticationToken token = tokenByAuthenticationRequest (request );
140
+ if (token == null ) {
141
+ token = tokenByRegistrationId (request , result );
142
+ }
143
+ if (token == null ) {
144
+ token = tokenByEntityId (request );
145
+ }
146
+ return token ;
147
+ }
148
+
149
+ private Saml2AuthenticationToken tokenByAuthenticationRequest (HttpServletRequest request ) {
74
150
AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest (request );
75
- String relyingPartyRegistrationId = (authenticationRequest != null )
76
- ? authenticationRequest .getRelyingPartyRegistrationId () : null ;
77
- RelyingPartyRegistration relyingPartyRegistration = this .relyingPartyRegistrationResolver .resolve (request ,
78
- relyingPartyRegistrationId );
79
- if (relyingPartyRegistration == null ) {
151
+ if (authenticationRequest == null ) {
152
+ return null ;
153
+ }
154
+ String registrationId = authenticationRequest .getRelyingPartyRegistrationId ();
155
+ RelyingPartyRegistration registration = this .registrations .findByRegistrationId (registrationId );
156
+ return tokenByRegistration (request , registration , authenticationRequest );
157
+ }
158
+
159
+ private Saml2AuthenticationToken tokenByRegistrationId (HttpServletRequest request ,
160
+ RequestMatcher .MatchResult result ) {
161
+ String registrationId = result .getVariables ().get ("registrationId" );
162
+ if (registrationId == null ) {
80
163
return null ;
81
164
}
82
- String saml2Response = request .getParameter (Saml2ParameterNames .SAML_RESPONSE );
83
- if (saml2Response == null ) {
165
+ RelyingPartyRegistration registration = this .registrations .findByRegistrationId (registrationId );
166
+ return tokenByRegistration (request , registration , null );
167
+ }
168
+
169
+ private Saml2AuthenticationToken tokenByEntityId (HttpServletRequest request ) {
170
+ String serialized = request .getParameter (Saml2ParameterNames .SAML_RESPONSE );
171
+ String decoded = new String (samlDecode (serialized ), StandardCharsets .UTF_8 );
172
+ Response response = parse (decoded );
173
+ String issuer = response .getIssuer ().getValue ();
174
+ RelyingPartyRegistration registration = this .registrations .findUniqueByAssertingPartyEntityId (issuer );
175
+ return tokenByRegistration (request , registration , null );
176
+ }
177
+
178
+ private Saml2AuthenticationToken tokenByRegistration (HttpServletRequest request ,
179
+ RelyingPartyRegistration registration , AbstractSaml2AuthenticationRequest authenticationRequest ) {
180
+ if (registration == null ) {
84
181
return null ;
85
182
}
86
- byte [] b = samlDecode (saml2Response );
87
- saml2Response = inflateIfRequired (request , b );
88
- return new Saml2AuthenticationToken (relyingPartyRegistration , saml2Response , authenticationRequest );
183
+ String serialized = request .getParameter (Saml2ParameterNames .SAML_RESPONSE );
184
+ String decoded = inflateIfRequired (request , samlDecode (serialized ));
185
+ UriResolver resolver = RelyingPartyRegistrationPlaceholderResolvers .uriResolver (request , registration );
186
+ registration = registration .mutate ().entityId (resolver .resolve (registration .getEntityId ()))
187
+ .assertionConsumerServiceLocation (resolver .resolve (registration .getAssertionConsumerServiceLocation ()))
188
+ .build ();
189
+ return new Saml2AuthenticationToken (registration , decoded , authenticationRequest );
89
190
}
90
191
91
192
/**
@@ -100,6 +201,15 @@ public void setAuthenticationRequestRepository(
100
201
this .loader = authenticationRequestRepository ::loadAuthenticationRequest ;
101
202
}
102
203
204
+ /**
205
+ * Use the given {@link RequestMatcher} to match the request.
206
+ * @param requestMatcher the {@link RequestMatcher} to use
207
+ */
208
+ public void setRequestMatcher (RequestMatcher requestMatcher ) {
209
+ Assert .notNull (requestMatcher , "requestMatcher cannot be null" );
210
+ this .requestMatcher = requestMatcher ;
211
+ }
212
+
103
213
private AbstractSaml2AuthenticationRequest loadAuthenticationRequest (HttpServletRequest request ) {
104
214
return this .loader .apply (request );
105
215
}
@@ -136,6 +246,18 @@ private String samlInflate(byte[] b) {
136
246
}
137
247
}
138
248
249
+ private Response parse (String request ) throws Saml2Exception {
250
+ try {
251
+ Document document = this .parserPool
252
+ .parse (new ByteArrayInputStream (request .getBytes (StandardCharsets .UTF_8 )));
253
+ Element element = document .getDocumentElement ();
254
+ return (Response ) this .unmarshaller .unmarshall (element );
255
+ }
256
+ catch (Exception ex ) {
257
+ throw new Saml2Exception ("Failed to deserialize LogoutRequest" , ex );
258
+ }
259
+ }
260
+
139
261
static class Base64Checker {
140
262
141
263
private static final int [] values = genValueMapping ();
0 commit comments