Skip to content

Commit 88bfd88

Browse files
committed
Introduce authenticationFailureHandler for OAuth2AuthorizationRequestRedirectFilter | gh-13793
1 parent 34cb9ab commit 88bfd88

File tree

2 files changed

+72
-22
lines changed

2 files changed

+72
-22
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,13 +25,15 @@
2525

2626
import org.springframework.core.log.LogMessage;
2727
import org.springframework.http.HttpStatus;
28+
import org.springframework.security.core.AuthenticationException;
2829
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
2930
import org.springframework.security.oauth2.client.registration.ClientRegistration;
3031
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
3132
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3233
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
3334
import org.springframework.security.web.DefaultRedirectStrategy;
3435
import org.springframework.security.web.RedirectStrategy;
36+
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
3537
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
3638
import org.springframework.security.web.savedrequest.RequestCache;
3739
import org.springframework.security.web.util.ThrowableAnalyzer;
@@ -97,19 +99,7 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
9799

98100
private RequestCache requestCache = new HttpSessionRequestCache();
99101

100-
private AuthorizationFailureHandler failureHandler = (request, response, ex) -> {
101-
LogMessage message = LogMessage.format("Authorization Request failed: %s", ex);
102-
if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) {
103-
// Log an invalid registrationId at WARN level to allow these errors to be
104-
// tuned separately from other errors
105-
this.logger.warn(message, ex);
106-
}
107-
else {
108-
this.logger.error(message, ex);
109-
}
110-
response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(),
111-
HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
112-
};
102+
private AuthenticationFailureHandler authenticationFailureHandler = this::unsuccessfulRedirectForAuthorization;
113103

114104
/**
115105
* Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided
@@ -177,8 +167,16 @@ public final void setRequestCache(RequestCache requestCache) {
177167
this.requestCache = requestCache;
178168
}
179169

180-
public final void setFailureHandler(AuthorizationFailureHandler failureHandler) {
181-
this.failureHandler = failureHandler;
170+
/**
171+
* Sets the {@link AuthenticationFailureHandler} used to handle errors redirecting to
172+
* the Authorization Server's Authorization Endpoint.
173+
* @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used
174+
* to handle errors redirecting to the Authorization Server's Authorization Endpoint
175+
* @since 6.3
176+
*/
177+
public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
178+
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
179+
this.authenticationFailureHandler = authenticationFailureHandler;
182180
}
183181

184182
@Override
@@ -192,7 +190,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
192190
}
193191
}
194192
catch (Exception ex) {
195-
this.failureHandler.onAuthorizationFailure(request, response, ex);
193+
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
194+
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
196195
return;
197196
}
198197
try {
@@ -217,7 +216,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
217216
this.sendRedirectForAuthorization(request, response, authorizationRequest);
218217
}
219218
catch (Exception failed) {
220-
this.failureHandler.onAuthorizationFailure(request, response, failed);
219+
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
220+
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
221221
}
222222
return;
223223
}
@@ -240,6 +240,22 @@ private void sendRedirectForAuthorization(HttpServletRequest request, HttpServle
240240
authorizationRequest.getAuthorizationRequestUri());
241241
}
242242

243+
private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
244+
AuthenticationException ex) throws IOException {
245+
Throwable cause = ex.getCause();
246+
LogMessage message = LogMessage.format("Authorization Request failed: %s", cause);
247+
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
248+
// Log an invalid registrationId at WARN level to allow these errors to be
249+
// tuned separately from other errors
250+
this.logger.warn(message, ex);
251+
}
252+
else {
253+
this.logger.error(message, ex);
254+
}
255+
response.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(),
256+
HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
257+
}
258+
243259
private static final class DefaultThrowableAnalyzer extends ThrowableAnalyzer {
244260

245261
@Override
@@ -253,8 +269,12 @@ protected void initExtractorMap() {
253269

254270
}
255271

256-
public interface AuthorizationFailureHandler {
257-
void onAuthorizationFailure(HttpServletRequest request, HttpServletResponse response,
258-
Exception ex) throws IOException;
272+
private static final class OAuth2AuthorizationRequestException extends AuthenticationException {
273+
274+
public OAuth2AuthorizationRequestException(Throwable cause) {
275+
super(cause.getMessage(), cause);
276+
}
277+
259278
}
279+
260280
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -119,6 +119,11 @@ public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentExcepti
119119
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
120120
}
121121

122+
@Test
123+
public void setAuthenticationFailureHandlerIsNullThenThrowIllegalArgumentException() {
124+
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null));
125+
}
126+
122127
@Test
123128
public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
124129
String requestUri = "/path";
@@ -144,6 +149,31 @@ public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalS
144149
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
145150
}
146151

152+
@Test
153+
public void doFilterWhenAuthorizationRequestWithInvalidClientAndCustomFailureHandlerThenCustomError()
154+
throws Exception {
155+
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
156+
+ this.registration1.getRegistrationId() + "-invalid";
157+
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
158+
request.setServletPath(requestUri);
159+
MockHttpServletResponse response = new MockHttpServletResponse();
160+
FilterChain filterChain = mock(FilterChain.class);
161+
this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> {
162+
Throwable cause = ex.getCause();
163+
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
164+
response1.sendError(HttpStatus.BAD_REQUEST.value(), HttpStatus.BAD_REQUEST.getReasonPhrase());
165+
}
166+
else {
167+
response1.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(),
168+
HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
169+
}
170+
});
171+
this.filter.doFilter(request, response, filterChain);
172+
verifyNoMoreInteractions(filterChain);
173+
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
174+
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
175+
}
176+
147177
@Test
148178
public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception {
149179
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"

0 commit comments

Comments
 (0)