diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java
index df26460b4cd..8a8d89c2286 100644
--- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java
+++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,13 +16,21 @@
package org.springframework.security.oauth2.client.web;
+import java.io.Serializable;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
+import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
@@ -30,9 +38,13 @@
/**
* An implementation of an {@link AuthorizationRequestRepository} that stores
* {@link OAuth2AuthorizationRequest} in the {@code HttpSession}.
+ *
+ * NOTE: {@link OAuth2AuthorizationRequest}s expire after two minutes, the default
+ * duration can be configured via {@link #setAuthorizationRequestTimeToLive(Duration)}.
*
* @author Joe Grandja
* @author Rob Winch
+ * @author Craig Andrews
* @since 5.0
* @see AuthorizationRequestRepository
* @see OAuth2AuthorizationRequest
@@ -45,6 +57,12 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository
private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
+ private Clock clock = Clock.systemUTC();
+
+ private Duration authorizationRequestTimeToLive = Duration.ofSeconds(120);
+
+ private int maxActiveAuthorizationRequestsPerRegistrationIdPerSession = 3;
+
@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
@@ -52,8 +70,9 @@ public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest re
if (stateParameter == null) {
return null;
}
- Map authorizationRequests = this.getAuthorizationRequests(request);
- return authorizationRequests.get(stateParameter);
+ Map authorizationRequests = this.getAuthorizationRequests(request);
+ OAuth2AuthorizationRequestReference authorizationRequestReference = authorizationRequests.get(stateParameter);
+ return (authorizationRequestReference != null) ? authorizationRequestReference.authorizationRequest : null;
}
@Override
@@ -67,8 +86,18 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq
}
String state = authorizationRequest.getState();
Assert.hasText(state, "authorizationRequest.state cannot be empty");
- Map authorizationRequests = this.getAuthorizationRequests(request);
- authorizationRequests.put(state, authorizationRequest);
+ Map authorizationRequests = this.getAuthorizationRequests(request);
+ authorizationRequests.put(state, new OAuth2AuthorizationRequestReference(authorizationRequest,
+ this.clock.instant().plus(this.authorizationRequestTimeToLive)));
+ for (String registrationId : authorizationRequests.values().stream().map((r) -> r.getRegistrationId())
+ .distinct().collect(Collectors.toList())) {
+ List references = authorizationRequests.values().stream()
+ .filter((r) -> Objects.equals(registrationId, r.getRegistrationId())).collect(Collectors.toList());
+ if (references.size() > this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession) {
+ references.stream().sorted((a, b) -> a.expiresAt.compareTo(b.expiresAt)).findFirst()
+ .map((r) -> r.getState()).ifPresent(authorizationRequests::remove);
+ }
+ }
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
}
@@ -79,15 +108,16 @@ public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest
if (stateParameter == null) {
return null;
}
- Map authorizationRequests = this.getAuthorizationRequests(request);
- OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter);
+ Map authorizationRequests = this.getAuthorizationRequests(request);
+ OAuth2AuthorizationRequestReference authorizationRequestReference = authorizationRequests
+ .remove(stateParameter);
if (!authorizationRequests.isEmpty()) {
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
}
else {
request.getSession().removeAttribute(this.sessionAttributeName);
}
- return originalRequest;
+ return (authorizationRequestReference != null) ? authorizationRequestReference.authorizationRequest : null;
}
@Override
@@ -113,14 +143,81 @@ private String getStateParameter(HttpServletRequest request) {
* @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()}
* to an {@link OAuth2AuthorizationRequest}.
*/
- private Map getAuthorizationRequests(HttpServletRequest request) {
+ private Map getAuthorizationRequests(HttpServletRequest request) {
HttpSession session = request.getSession(false);
- Map authorizationRequests = (session != null)
- ? (Map) session.getAttribute(this.sessionAttributeName) : null;
+ Map authorizationRequests = (session != null)
+ ? (Map) session.getAttribute(this.sessionAttributeName)
+ : null;
if (authorizationRequests == null) {
return new HashMap<>();
}
+ // remove expired entries
+ authorizationRequests.entrySet().removeIf((entry) -> entry.getValue().expiresAt.isBefore(this.clock.instant()));
return authorizationRequests;
}
+ /**
+ * Sets the {@link Clock} used in {@link Instant#now(Clock)} when setting the instant
+ * created for {@link OAuth2AuthorizationRequest}.
+ * @param clock the clock
+ * @since 5.5
+ */
+ void setClock(Clock clock) {
+ Assert.notNull(clock, "clock cannot be null");
+ this.clock = clock;
+ }
+
+ /**
+ * Sets the {@link Duration} for which {@link OAuth2AuthorizationRequest} should
+ * expire.
+ * @param authorizationRequestTimeToLive the {@link Duration} a
+ * {@link OAuth2AuthorizationRequest} is considered not expired. Must not be negative.
+ * @since 5.5
+ */
+ void setAuthorizationRequestTimeToLive(Duration authorizationRequestTimeToLive) {
+ Assert.notNull(authorizationRequestTimeToLive, "oAuth2AuthorizationRequestExpiresIn cannot be null");
+ Assert.state(!authorizationRequestTimeToLive.isNegative(),
+ "oAuth2AuthorizationRequestExpiresIn cannot be negative");
+ this.authorizationRequestTimeToLive = authorizationRequestTimeToLive;
+ }
+
+ /**
+ * Sets the maximum number of {@link OAuth2AuthorizationRequest} that can be
+ * stored/active per registration id for a session. If the maximum number are present
+ * in a session when an attempt is made to save another one, then the oldest will be
+ * removed.
+ * @param maxActiveAuthorizationRequestsPerSession must not be negative.
+ */
+ void setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(
+ int maxActiveAuthorizationRequestsPerRegistrationIdPerSession) {
+ Assert.state(maxActiveAuthorizationRequestsPerRegistrationIdPerSession > 0,
+ "maxActiveAuthorizationRequestsPerRegistrationIdPerSession must be greater than zero");
+ this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession = maxActiveAuthorizationRequestsPerRegistrationIdPerSession;
+ }
+
+ private static final class OAuth2AuthorizationRequestReference implements Serializable {
+
+ private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;
+
+ private final Instant expiresAt;
+
+ private final OAuth2AuthorizationRequest authorizationRequest;
+
+ private OAuth2AuthorizationRequestReference(OAuth2AuthorizationRequest authorizationRequest,
+ Instant expiresAt) {
+ Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
+ this.expiresAt = expiresAt;
+ this.authorizationRequest = authorizationRequest;
+ }
+
+ private String getRegistrationId() {
+ return this.authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
+ }
+
+ private String getState() {
+ return this.authorizationRequest.getState();
+ }
+
+ }
+
}
diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java
index 66d6e1d90dc..a5ad1155e8c 100644
--- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java
+++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,6 +16,11 @@
package org.springframework.security.oauth2.client.web;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.time.ZoneId;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -36,6 +41,7 @@
* Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}.
*
* @author Joe Grandja
+ * @author Craig Andrews
*/
@RunWith(MockitoJUnitRunner.class)
public class HttpSessionOAuth2AuthorizationRequestRepositoryTests {
@@ -237,6 +243,102 @@ public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() {
assertThat(removedAuthorizationRequest).isNull();
}
+ @Test
+ public void removeAuthorizationRequestWhenExpired() {
+ final Duration expiresIn = Duration.ofMinutes(2);
+ this.authorizationRequestRepository.setAuthorizationRequestTimeToLive(expiresIn);
+ this.authorizationRequestRepository.setClock(Clock.fixed(Instant.ofEpochMilli(0), ZoneId.systemDefault()));
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ String state1 = "state-1122";
+ OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
+ String state2 = "state-3344";
+ this.authorizationRequestRepository
+ .setClock(Clock.fixed(Instant.ofEpochMilli(1).plus(expiresIn), ZoneId.systemDefault()));
+ OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
+ request.addParameter(OAuth2ParameterNames.STATE, state1);
+ OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
+ .loadAuthorizationRequest(request);
+ assertThat(loadedAuthorizationRequest1).isNull();
+ request.removeParameter(OAuth2ParameterNames.STATE);
+ request.addParameter(OAuth2ParameterNames.STATE, state2);
+ OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
+ .loadAuthorizationRequest(request);
+ assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
+ }
+
+ @Test
+ public void removeOldestAuthorizationRequestWhenMoreThanMax() {
+ String registrationId = "registration-id-1";
+ this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2);
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ String state1 = "state-1122";
+ OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1)
+ .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
+ String state2 = "state-3344";
+ OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2)
+ .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
+ String state3 = "state-4455";
+ OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3)
+ .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
+ request.addParameter(OAuth2ParameterNames.STATE, state1);
+ OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
+ .loadAuthorizationRequest(request);
+ assertThat(loadedAuthorizationRequest1).isNull();
+ request.removeParameter(OAuth2ParameterNames.STATE);
+ request.addParameter(OAuth2ParameterNames.STATE, state2);
+ OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
+ .loadAuthorizationRequest(request);
+ assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
+ request.removeParameter(OAuth2ParameterNames.STATE);
+ request.addParameter(OAuth2ParameterNames.STATE, state3);
+ OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
+ .loadAuthorizationRequest(request);
+ assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
+ }
+
+ @Test
+ public void doNotremoveOldestAuthorizationRequestWhenLessThanMax() {
+ this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2);
+ MockHttpServletRequest request = new MockHttpServletRequest();
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ String state1 = "state-1122";
+ OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1)
+ .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-1"))
+ .build();
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
+ String state2 = "state-3344";
+ OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2)
+ .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-2"))
+ .build();
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
+ String state3 = "state-4455";
+ OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3)
+ .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-3"))
+ .build();
+ this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
+ request.addParameter(OAuth2ParameterNames.STATE, state1);
+ OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
+ .loadAuthorizationRequest(request);
+ assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);
+ request.removeParameter(OAuth2ParameterNames.STATE);
+ request.addParameter(OAuth2ParameterNames.STATE, state2);
+ OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
+ .loadAuthorizationRequest(request);
+ assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
+ request.removeParameter(OAuth2ParameterNames.STATE);
+ request.addParameter(OAuth2ParameterNames.STATE, state3);
+ OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
+ .loadAuthorizationRequest(request);
+ assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
+ }
+
private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() {
return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize")
.clientId("client-id-1234").state("state-1234");