diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java index df26460b4cd..8a8d89c2286 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,21 @@ package org.springframework.security.oauth2.client.web; +import java.io.Serializable; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; +import org.springframework.security.core.SpringSecurityCoreVersion; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; @@ -30,9 +38,13 @@ /** * An implementation of an {@link AuthorizationRequestRepository} that stores * {@link OAuth2AuthorizationRequest} in the {@code HttpSession}. + *

+ * NOTE: {@link OAuth2AuthorizationRequest}s expire after two minutes, the default + * duration can be configured via {@link #setAuthorizationRequestTimeToLive(Duration)}. * * @author Joe Grandja * @author Rob Winch + * @author Craig Andrews * @since 5.0 * @see AuthorizationRequestRepository * @see OAuth2AuthorizationRequest @@ -45,6 +57,12 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME; + private Clock clock = Clock.systemUTC(); + + private Duration authorizationRequestTimeToLive = Duration.ofSeconds(120); + + private int maxActiveAuthorizationRequestsPerRegistrationIdPerSession = 3; + @Override public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) { Assert.notNull(request, "request cannot be null"); @@ -52,8 +70,9 @@ public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest re if (stateParameter == null) { return null; } - Map authorizationRequests = this.getAuthorizationRequests(request); - return authorizationRequests.get(stateParameter); + Map authorizationRequests = this.getAuthorizationRequests(request); + OAuth2AuthorizationRequestReference authorizationRequestReference = authorizationRequests.get(stateParameter); + return (authorizationRequestReference != null) ? authorizationRequestReference.authorizationRequest : null; } @Override @@ -67,8 +86,18 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq } String state = authorizationRequest.getState(); Assert.hasText(state, "authorizationRequest.state cannot be empty"); - Map authorizationRequests = this.getAuthorizationRequests(request); - authorizationRequests.put(state, authorizationRequest); + Map authorizationRequests = this.getAuthorizationRequests(request); + authorizationRequests.put(state, new OAuth2AuthorizationRequestReference(authorizationRequest, + this.clock.instant().plus(this.authorizationRequestTimeToLive))); + for (String registrationId : authorizationRequests.values().stream().map((r) -> r.getRegistrationId()) + .distinct().collect(Collectors.toList())) { + List references = authorizationRequests.values().stream() + .filter((r) -> Objects.equals(registrationId, r.getRegistrationId())).collect(Collectors.toList()); + if (references.size() > this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession) { + references.stream().sorted((a, b) -> a.expiresAt.compareTo(b.expiresAt)).findFirst() + .map((r) -> r.getState()).ifPresent(authorizationRequests::remove); + } + } request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); } @@ -79,15 +108,16 @@ public OAuth2AuthorizationRequest removeAuthorizationRequest(HttpServletRequest if (stateParameter == null) { return null; } - Map authorizationRequests = this.getAuthorizationRequests(request); - OAuth2AuthorizationRequest originalRequest = authorizationRequests.remove(stateParameter); + Map authorizationRequests = this.getAuthorizationRequests(request); + OAuth2AuthorizationRequestReference authorizationRequestReference = authorizationRequests + .remove(stateParameter); if (!authorizationRequests.isEmpty()) { request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests); } else { request.getSession().removeAttribute(this.sessionAttributeName); } - return originalRequest; + return (authorizationRequestReference != null) ? authorizationRequestReference.authorizationRequest : null; } @Override @@ -113,14 +143,81 @@ private String getStateParameter(HttpServletRequest request) { * @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} * to an {@link OAuth2AuthorizationRequest}. */ - private Map getAuthorizationRequests(HttpServletRequest request) { + private Map getAuthorizationRequests(HttpServletRequest request) { HttpSession session = request.getSession(false); - Map authorizationRequests = (session != null) - ? (Map) session.getAttribute(this.sessionAttributeName) : null; + Map authorizationRequests = (session != null) + ? (Map) session.getAttribute(this.sessionAttributeName) + : null; if (authorizationRequests == null) { return new HashMap<>(); } + // remove expired entries + authorizationRequests.entrySet().removeIf((entry) -> entry.getValue().expiresAt.isBefore(this.clock.instant())); return authorizationRequests; } + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when setting the instant + * created for {@link OAuth2AuthorizationRequest}. + * @param clock the clock + * @since 5.5 + */ + void setClock(Clock clock) { + Assert.notNull(clock, "clock cannot be null"); + this.clock = clock; + } + + /** + * Sets the {@link Duration} for which {@link OAuth2AuthorizationRequest} should + * expire. + * @param authorizationRequestTimeToLive the {@link Duration} a + * {@link OAuth2AuthorizationRequest} is considered not expired. Must not be negative. + * @since 5.5 + */ + void setAuthorizationRequestTimeToLive(Duration authorizationRequestTimeToLive) { + Assert.notNull(authorizationRequestTimeToLive, "oAuth2AuthorizationRequestExpiresIn cannot be null"); + Assert.state(!authorizationRequestTimeToLive.isNegative(), + "oAuth2AuthorizationRequestExpiresIn cannot be negative"); + this.authorizationRequestTimeToLive = authorizationRequestTimeToLive; + } + + /** + * Sets the maximum number of {@link OAuth2AuthorizationRequest} that can be + * stored/active per registration id for a session. If the maximum number are present + * in a session when an attempt is made to save another one, then the oldest will be + * removed. + * @param maxActiveAuthorizationRequestsPerSession must not be negative. + */ + void setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession( + int maxActiveAuthorizationRequestsPerRegistrationIdPerSession) { + Assert.state(maxActiveAuthorizationRequestsPerRegistrationIdPerSession > 0, + "maxActiveAuthorizationRequestsPerRegistrationIdPerSession must be greater than zero"); + this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession = maxActiveAuthorizationRequestsPerRegistrationIdPerSession; + } + + private static final class OAuth2AuthorizationRequestReference implements Serializable { + + private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID; + + private final Instant expiresAt; + + private final OAuth2AuthorizationRequest authorizationRequest; + + private OAuth2AuthorizationRequestReference(OAuth2AuthorizationRequest authorizationRequest, + Instant expiresAt) { + Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); + this.expiresAt = expiresAt; + this.authorizationRequest = authorizationRequest; + } + + private String getRegistrationId() { + return this.authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID); + } + + private String getState() { + return this.authorizationRequest.getState(); + } + + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java index 66d6e1d90dc..a5ad1155e8c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,11 @@ package org.springframework.security.oauth2.client.web; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -36,6 +41,7 @@ * Tests for {@link HttpSessionOAuth2AuthorizationRequestRepository}. * * @author Joe Grandja + * @author Craig Andrews */ @RunWith(MockitoJUnitRunner.class) public class HttpSessionOAuth2AuthorizationRequestRepositoryTests { @@ -237,6 +243,102 @@ public void removeAuthorizationRequestWhenNotSavedThenNotRemoved() { assertThat(removedAuthorizationRequest).isNull(); } + @Test + public void removeAuthorizationRequestWhenExpired() { + final Duration expiresIn = Duration.ofMinutes(2); + this.authorizationRequestRepository.setAuthorizationRequestTimeToLive(expiresIn); + this.authorizationRequestRepository.setClock(Clock.fixed(Instant.ofEpochMilli(0), ZoneId.systemDefault())); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); + String state2 = "state-3344"; + this.authorizationRequestRepository + .setClock(Clock.fixed(Instant.ofEpochMilli(1).plus(expiresIn), ZoneId.systemDefault())); + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); + request.addParameter(OAuth2ParameterNames.STATE, state1); + OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest1).isNull(); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state2); + OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2); + } + + @Test + public void removeOldestAuthorizationRequestWhenMoreThanMax() { + String registrationId = "registration-id-1"; + this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); + String state3 = "state-4455"; + OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response); + request.addParameter(OAuth2ParameterNames.STATE, state1); + OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest1).isNull(); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state2); + OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state3); + OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3); + } + + @Test + public void doNotremoveOldestAuthorizationRequestWhenLessThanMax() { + this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + String state1 = "state-1122"; + OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-1")) + .build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response); + String state2 = "state-3344"; + OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-2")) + .build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response); + String state3 = "state-4455"; + OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3) + .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-3")) + .build(); + this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response); + request.addParameter(OAuth2ParameterNames.STATE, state1); + OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state2); + OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2); + request.removeParameter(OAuth2ParameterNames.STATE); + request.addParameter(OAuth2ParameterNames.STATE, state3); + OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository + .loadAuthorizationRequest(request); + assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3); + } + private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() { return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize") .clientId("client-id-1234").state("state-1234");