diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java
index 05bc7063850..81e423b2020 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolver.java
@@ -27,6 +27,8 @@
import org.springframework.security.web.util.UrlUtils;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import javax.servlet.http.HttpServletRequest;
@@ -54,6 +56,7 @@
*/
public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2AuthorizationRequestResolver {
private static final String REGISTRATION_ID_URI_VARIABLE_NAME = "registrationId";
+ private static final char PATH_DELIMITER = '/';
private final ClientRegistrationRepository clientRegistrationRepository;
private final AntPathRequestMatcher authorizationRequestMatcher;
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
@@ -127,7 +130,7 @@ private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String re
") for Client Registration with Id: " + clientRegistration.getRegistrationId());
}
- String redirectUriStr = this.expandRedirectUri(request, clientRegistration, redirectUriAction);
+ String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction);
OAuth2AuthorizationRequest authorizationRequest = builder
.clientId(clientRegistration.getClientId())
@@ -149,20 +152,49 @@ private String resolveRegistrationId(HttpServletRequest request) {
return null;
}
- private String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
- // Supported URI variables -> baseUrl, action, registrationId
- // Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
+ /**
+ * Expands the {@link ClientRegistration#getRedirectUriTemplate()} with following provided variables:
+ * - baseUrl (e.g. https://localhost/app)
+ * - baseScheme (e.g. https)
+ * - baseHost (e.g. localhost)
+ * - basePort (e.g. :8080)
+ * - basePath (e.g. /app)
+ * - registrationId (e.g. google)
+ * - action (e.g. login)
+ *
+ * Null variables are provided as empty strings.
+ *
+ * Default redirectUriTemplate is: {@link org.springframework.security.config.oauth2.client}.CommonOAuth2Provider#DEFAULT_REDIRECT_URL
+ *
+ * @return expanded URI
+ */
+ private static String expandRedirectUri(HttpServletRequest request, ClientRegistration clientRegistration, String action) {
Map uriVariables = new HashMap<>();
uriVariables.put("registrationId", clientRegistration.getRegistrationId());
- String baseUrl = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
- .replaceQuery(null)
+
+ UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replacePath(request.getContextPath())
- .build()
- .toUriString();
- uriVariables.put("baseUrl", baseUrl);
- if (action != null) {
- uriVariables.put("action", action);
+ .replaceQuery(null)
+ .fragment(null)
+ .build();
+ String scheme = uriComponents.getScheme();
+ uriVariables.put("baseScheme", scheme == null ? "" : scheme);
+ String host = uriComponents.getHost();
+ uriVariables.put("baseHost", host == null ? "" : host);
+ // following logic is based on HierarchicalUriComponents#toUriString()
+ int port = uriComponents.getPort();
+ uriVariables.put("basePort", port == -1 ? "" : ":" + port);
+ String path = uriComponents.getPath();
+ if (StringUtils.hasLength(path)) {
+ if (path.charAt(0) != PATH_DELIMITER) {
+ path = PATH_DELIMITER + path;
+ }
}
+ uriVariables.put("basePath", path == null ? "" : path);
+ uriVariables.put("baseUrl", uriComponents.toUriString());
+
+ uriVariables.put("action", action == null ? "" : action);
+
return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
.buildAndExpand(uriVariables)
.toUriString();
diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java
index 74371198004..f0f4eb38c1b 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizationRequestResolver.java
@@ -30,8 +30,10 @@
import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;
@@ -63,8 +65,9 @@ public class DefaultServerOAuth2AuthorizationRequestResolver
/**
* The default pattern used to resolve the {@link ClientRegistration#getRegistrationId()}
*/
- public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{" + DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME
- + "}";
+ public static final String DEFAULT_AUTHORIZATION_REQUEST_PATTERN = "/oauth2/authorization/{" + DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME + "}";
+
+ private static final char PATH_DELIMITER = '/';
private final ServerWebExchangeMatcher authorizationRequestMatcher;
@@ -121,8 +124,7 @@ private Mono findByRegistrationId(ServerWebExchange exchange
private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchange,
ClientRegistration clientRegistration) {
- String redirectUriStr = this
- .expandRedirectUri(exchange.getRequest(), clientRegistration);
+ String redirectUriStr = expandRedirectUri(exchange.getRequest(), clientRegistration);
Map attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
@@ -153,23 +155,52 @@ else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizat
.build();
}
- private String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) {
- // Supported URI variables -> baseUrl, action, registrationId
- // Used in -> CommonOAuth2Provider.DEFAULT_REDIRECT_URL = "{baseUrl}/{action}/oauth2/code/{registrationId}"
+ /**
+ * Expands the {@link ClientRegistration#getRedirectUriTemplate()} with following provided variables:
+ * - baseUrl (e.g. https://localhost/app)
+ * - baseScheme (e.g. https)
+ * - baseHost (e.g. localhost)
+ * - basePort (e.g. :8080)
+ * - basePath (e.g. /app)
+ * - registrationId (e.g. google)
+ * - action (e.g. login)
+ *
+ * Null variables are provided as empty strings.
+ *
+ * Default redirectUriTemplate is: {@link org.springframework.security.config.oauth2.client}.CommonOAuth2Provider#DEFAULT_REDIRECT_URL
+ *
+ * @return expanded URI
+ */
+ private static String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) {
Map uriVariables = new HashMap<>();
uriVariables.put("registrationId", clientRegistration.getRegistrationId());
- String baseUrl = UriComponentsBuilder.fromUri(request.getURI())
+ UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI())
.replacePath(request.getPath().contextPath().value())
.replaceQuery(null)
- .build()
- .toUriString();
- uriVariables.put("baseUrl", baseUrl);
+ .fragment(null)
+ .build();
+ String scheme = uriComponents.getScheme();
+ uriVariables.put("baseScheme", scheme == null ? "" : scheme);
+ String host = uriComponents.getHost();
+ uriVariables.put("baseHost", host == null ? "" : host);
+ // following logic is based on HierarchicalUriComponents#toUriString()
+ int port = uriComponents.getPort();
+ uriVariables.put("basePort", port == -1 ? "" : ":" + port);
+ String path = uriComponents.getPath();
+ if (StringUtils.hasLength(path)) {
+ if (path.charAt(0) != PATH_DELIMITER) {
+ path = PATH_DELIMITER + path;
+ }
+ }
+ uriVariables.put("basePath", path == null ? "" : path);
+ uriVariables.put("baseUrl", uriComponents.toUriString());
+ String action = "";
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
- String loginAction = "login";
- uriVariables.put("action", loginAction);
+ action = "login";
}
+ uriVariables.put("action", action);
return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUriTemplate())
.buildAndExpand(uriVariables)
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java
index 03aacdcf256..665755e58ad 100644
--- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizationRequestResolverTests.java
@@ -41,15 +41,17 @@
public class DefaultOAuth2AuthorizationRequestResolverTests {
private ClientRegistration registration1;
private ClientRegistration registration2;
+ private ClientRegistration fineRedirectUriTemplateRegistration;
private ClientRegistration pkceRegistration;
private ClientRegistrationRepository clientRegistrationRepository;
- private String authorizationRequestBaseUri = "/oauth2/authorization";
+ private final String authorizationRequestBaseUri = "/oauth2/authorization";
private DefaultOAuth2AuthorizationRequestResolver resolver;
@Before
public void setUp() {
this.registration1 = TestClientRegistrations.clientRegistration().build();
this.registration2 = TestClientRegistrations.clientRegistration2().build();
+ this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build();
this.pkceRegistration = TestClientRegistrations.clientRegistration()
.registrationId("pkce-client-registration-id")
.clientId("pkce-client-id")
@@ -58,7 +60,7 @@ public void setUp() {
.build();
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(
- this.registration1, this.registration2, this.pkceRegistration);
+ this.registration1, this.registration2, this.fineRedirectUriTemplateRegistration, this.pkceRegistration);
this.resolver = new DefaultOAuth2AuthorizationRequestResolver(
this.clientRegistrationRepository, this.authorizationRequestBaseUri);
}
@@ -152,6 +154,80 @@ public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriEx
"http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
}
+ @Test
+ public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpRedirectUriWithExtraVarsExpanded() {
+ ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+ String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+ request.setServerPort(8080);
+ request.setServletPath(requestUri);
+
+ OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+ assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+ "http://localhost:8080/login/oauth2/code/" + clientRegistration.getRegistrationId());
+ }
+
+ @Test
+ public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenHttpsRedirectUriWithExtraVarsExpanded() {
+ ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+ String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+ request.setScheme("https");
+ request.setServerPort(8081);
+ request.setServletPath(requestUri);
+
+ OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+ assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+ "https://localhost:8081/login/oauth2/code/" + clientRegistration.getRegistrationId());
+ }
+
+ @Test
+ public void resolveWhenAuthorizationRequestIncludesPort80ThenExpandedRedirectUriWithExtraVarsExcludesPort() {
+ ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+ String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+ request.setScheme("http");
+ request.setServerPort(80);
+ request.setServletPath(requestUri);
+
+ OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+ assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+ "http://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
+ }
+
+ @Test
+ public void resolveWhenAuthorizationRequestIncludesPort443ThenExpandedRedirectUriWithExtraVarsExcludesPort() {
+ ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+ String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+ request.setScheme("https");
+ request.setServerPort(443);
+ request.setServletPath(requestUri);
+
+ OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+ assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+ "https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
+ }
+
+ @Test
+ public void resolveWhenAuthorizationRequestHasNoPortThenExpandedRedirectUriWithExtraVarsExcludesPort() {
+ ClientRegistration clientRegistration = this.fineRedirectUriTemplateRegistration;
+ String requestUri = this.authorizationRequestBaseUri + "/" + clientRegistration.getRegistrationId();
+ MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+ request.setScheme("https");
+ request.setServerPort(-1);
+ request.setServletPath(requestUri);
+
+ OAuth2AuthorizationRequest authorizationRequest = this.resolver.resolve(request);
+ assertThat(authorizationRequest.getRedirectUri()).isNotEqualTo(clientRegistration.getRedirectUriTemplate());
+ assertThat(authorizationRequest.getRedirectUri()).isEqualTo(
+ "https://localhost/login/oauth2/code/" + clientRegistration.getRegistrationId());
+ }
+
// gh-5520
@Test
public void resolveWhenAuthorizationRequestRedirectUriTemplatedThenRedirectUriExpandedExcludesQueryString() {
@@ -301,4 +377,19 @@ public void resolveWhenAuthorizationRequestWithValidPkceClientThenResolves() {
"code_challenge_method=S256&" +
"code_challenge=([a-zA-Z0-9\\-\\.\\_\\~]){43}");
}
+
+ private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() {
+ return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")
+ .redirectUriTemplate("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}")
+ .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+ .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+ .scope("read:user")
+ .authorizationUri("https://example.com/login/oauth/authorize")
+ .tokenUri("https://example.com/login/oauth/access_token")
+ .userInfoUri("https://api.example.com/user")
+ .userNameAttributeName("id")
+ .clientName("Fine Redirect Uri Template Client")
+ .clientId("fine-redirect-uri-template-client")
+ .clientSecret("client-secret");
+ }
}