diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java index 05bc7063850..81e423b2020 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java @@ -27,6 +27,8 @@ import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import javax.servlet.http.HttpServletRequest; @@ -54,6 +56,7 @@ */ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver { private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId"; + private static final char PATH_DELIMITER = '/'; private final ClientRegistrationRepository clientRegistrationRepository; private final AntPathRequestMatcher authorizationRequestMatcher; private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder()); @@ -127,7 +130,7 @@ private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String re ") for Client Registration with Id: " + clientRegistration.getRegistrationId()); } - String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction); + String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction); OAuth2AuthorizationRequest authorizationRequest = builder .clientId(clientRegistration.getClientId()) @@ -149,20 +152,49 @@ private String resolveRegistrationId(HttpServletRequest request) { return null; } - private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) { - // Supported URI variables -> baseUrl, action, registrationId - // Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}" + /** + * Expands the {@link ClientRegistration#getRedirectUriTemplate()} with following provided variables:
+ * - baseUrl (e.g. https://localhost/app)
+ * - baseScheme (e.g. https)
+ * - baseHost (e.g. localhost)
+ * - basePort (e.g. :8080)
+ * - basePath (e.g. /app)
+ * - registrationId (e.g. google)
+ * - action (e.g. login)
+ *

+ * Null variables are provided as empty strings. + *

+ * Default redirectUriTemplate is: {@link org.springframework.security.config.oauth2.client}.CommonOAuth2Provider#DEFAULT_REDIRECT_URL + * + * @return expanded URI + */ + private static String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) { Map uriVariables = new HashMap<>(); uriVariables.put("registrationId", clientRegistration.getRegistrationId()); - String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) - .replaceQuery(null) + + UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) .replacePath(request.getContextPath()) - .build() - .toUriString(); - uriVariables.put("baseUrl", baseUrl); - if (action != null) { - uriVariables.put("action", action); + .replaceQuery(null) + .fragment(null) + .build(); + String scheme = uriComponents.getScheme(); + uriVariables.put("baseScheme", scheme == null ? "" : scheme); + String host = uriComponents.getHost(); + uriVariables.put("baseHost", host == null ? "" : host); + // following logic is based on HierarchicalUriComponents#toUriString() + int port = uriComponents.getPort(); + uriVariables.put("basePort", port == -1 ? "" : ":" + port); + String path = uriComponents.getPath(); + if (StringUtils.hasLength(path)) { + if (path.charAt(0) != PATH_DELIMITER) { + path = PATH_DELIMITER + path; + } } + uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("baseUrl", uriComponents.toUriString()); + + uriVariables.put("action", action == null ? "" : action); + return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate()) .buildAndExpand(uriVariables) .toUriString(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java index 74371198004..f0f4eb38c1b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java @@ -30,8 +30,10 @@ import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import org.springframework.web.server.ResponseStatusException; import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import reactor.core.publisher.Mono; @@ -63,8 +65,9 @@ public class DefaultServerOAuth2AuthorizationRequestResolver /** * The default pattern used to resolve the {@link ClientRegistration#getRegistrationId()} */ - public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{" + DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME - + "}"; + public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{" + DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME + "}"; + + private static final char PATH_DELIMITER = '/'; private final ServerWebExchangeMatcher authorizationRequestMatcher; @@ -121,8 +124,7 @@ private Mono findByRegistrationId(ServerWebExchange exchange private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchange, ClientRegistration clientRegistration) { - String redirectUriStr = this - .expandRedirectUri(exchange.getRequest(), clientRegistration); + String redirectUriStr = expandRedirectUri(exchange.getRequest(), clientRegistration); Map attributes = new HashMap<>(); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); @@ -153,23 +155,52 @@ else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizat .build(); } - private String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) { - // Supported URI variables -> baseUrl, action, registrationId - // Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}" + /** + * Expands the {@link ClientRegistration#getRedirectUriTemplate()} with following provided variables:
+ * - baseUrl (e.g. https://localhost/app)
+ * - baseScheme (e.g. https)
+ * - baseHost (e.g. localhost)
+ * - basePort (e.g. :8080)
+ * - basePath (e.g. /app)
+ * - registrationId (e.g. google)
+ * - action (e.g. login)
+ *

+ * Null variables are provided as empty strings. + *

+ * Default redirectUriTemplate is: {@link org.springframework.security.config.oauth2.client}.CommonOAuth2Provider#DEFAULT_REDIRECT_URL + * + * @return expanded URI + */ + private static String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) { Map uriVariables = new HashMap<>(); uriVariables.put("registrationId", clientRegistration.getRegistrationId()); - String baseUrl = UriComponentsBuilder.fromUri(request.getURI()) + UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI()) .replacePath(request.getPath().contextPath().value()) .replaceQuery(null) - .build() - .toUriString(); - uriVariables.put("baseUrl", baseUrl); + .fragment(null) + .build(); + String scheme = uriComponents.getScheme(); + uriVariables.put("baseScheme", scheme == null ? "" : scheme); + String host = uriComponents.getHost(); + uriVariables.put("baseHost", host == null ? "" : host); + // following logic is based on HierarchicalUriComponents#toUriString() + int port = uriComponents.getPort(); + uriVariables.put("basePort", port == -1 ? "" : ":" + port); + String path = uriComponents.getPath(); + if (StringUtils.hasLength(path)) { + if (path.charAt(0) != PATH_DELIMITER) { + path = PATH_DELIMITER + path; + } + } + uriVariables.put("basePath", path == null ? "" : path); + uriVariables.put("baseUrl", uriComponents.toUriString()); + String action = ""; if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - String loginAction = "login"; - uriVariables.put("action", loginAction); + action = "login"; } + uriVariables.put("action", action); return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate()) .buildAndExpand(uriVariables) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java index 03aacdcf256..665755e58ad 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java @@ -41,15 +41,17 @@ public class DefaultOAuth2AuthorizationRequestResolverTests { private ClientRegistration registration1; private ClientRegistration registration2; + private ClientRegistration fineRedirectUriTemplateRegistration; private ClientRegistration pkceRegistration; private ClientRegistrationRepository clientRegistrationRepository; - private String authorizationRequestBaseUri = "/oauth2/authorization"; + private final String authorizationRequestBaseUri = "/oauth2/authorization"; private DefaultOAuth2AuthorizationRequestResolver resolver; @Before public void setUp() { this.registration1 = TestClientRegistrations.clientRegistration().build(); this.registration2 = TestClientRegistrations.clientRegistration2().build(); + this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build(); this.pkceRegistration = TestClientRegistrations.clientRegistration() .registrationId("pkce-client-registration-id") .clientId("pkce-client-id") @@ -58,7 +60,7 @@ public void setUp() { .build(); this.clientRegistrationRepository = new InMemoryClientRegistrationRepository( - this.registration1, this.registration2, this.pkceRegistration); + this.registration1, this.registration2, this.fineRedirectUriTemplateRegistration, this.pkceRegistration); this.resolver = new DefaultOAuth2AuthorizationRequestResolver( this.clientRegistrationRepository, this.authorizationRequestBaseUri); } @@ -152,6 +154,80 @@ public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriEx "http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); } + @Test + public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpRedirectUriWithExtraVarsExpanded() { + ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServerPort(8080); + request.setServletPath(requestUri); + + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo( + "http://localhost:8080/login/oauth2/code/" + clientRegistration.getRegistrationId()); + } + + @Test + public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpsRedirectUriWithExtraVarsExpanded() { + ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setScheme("https"); + request.setServerPort(8081); + request.setServletPath(requestUri); + + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo( + "https://localhost:8081/login/oauth2/code/" + clientRegistration.getRegistrationId()); + } + + @Test + public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriWithExtraVarsExcludesPort() { + ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setScheme("http"); + request.setServerPort(80); + request.setServletPath(requestUri); + + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo( + "http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); + } + + @Test + public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriWithExtraVarsExcludesPort() { + ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setScheme("https"); + request.setServerPort(443); + request.setServletPath(requestUri); + + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo( + "https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); + } + + @Test + public void resolveWhenAuthorizationRequestHasNoPortThenExpandedRedirectUriWithExtraVarsExcludesPort() { + ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration; + String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setScheme("https"); + request.setServerPort(-1); + request.setServletPath(requestUri); + + OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request); + assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate()); + assertThat(authorizationRequest.getRedirectUri()).isEqualTo( + "https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId()); + } + // gh-5520 @Test public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpandedExcludesQueryString() { @@ -301,4 +377,19 @@ public void resolveWhenAuthorizationRequestWithValidPkceClientThenResolves() { "code_challenge_method=S256&" + "code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}"); } + + private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() { + return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration") + .redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .scope("read:user") + .authorizationUri("https://example.com/login/oauth/authorize") + .tokenUri("https://example.com/login/oauth/access_token") + .userInfoUri("https://api.example.com/user") + .userNameAttributeName("id") + .clientName("Fine Redirect Uri Template Client") + .clientId("fine-redirect-uri-template-client") + .clientSecret("client-secret"); + } }