diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java index 16400dc922c..fdbe3f828ab 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,10 @@ import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import java.util.Collections; @@ -26,10 +29,33 @@ /** * An implementation of an {@link ReactiveOAuth2AuthorizedClientManager} - * that is capable of operating outside of a {@code ServerHttpRequest} context, + * that is capable of operating outside of the context of a {@link ServerWebExchange}, * e.g. in a scheduled/background thread and/or in the service-tier. * - *

This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}

+ *

(When operating within the context of a {@link ServerWebExchange}, + * use {@link DefaultReactiveOAuth2AuthorizedClientManager} instead.)

+ * + *

This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}.

+ * + *

Authorized Client Persistence

+ * + *

This client manager utilizes a {@link ReactiveOAuth2AuthorizedClientService} + * to persist {@link OAuth2AuthorizedClient}s.

+ * + *

By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} + * will be saved in the authorized client service. + * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationSuccessHandler} + * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.

+ * + *

By default, when an authorization attempt fails due to an + * {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} error, + * the previously saved {@link OAuth2AuthorizedClient} + * will be removed from the authorized client service. + * (The {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} + * error generally occurs when a refresh token that is no longer valid + * is used to retrieve a new access token.) + * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationFailureHandler} + * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.

* * @author Ankur Pathak * @author Phil Clay @@ -45,6 +71,8 @@ public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager private final ReactiveOAuth2AuthorizedClientService authorizedClientService; private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty(); private Function>> contextAttributesMapper = new DefaultContextAttributesMapper(); + private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; /** * Constructs an {@code AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters. @@ -59,6 +87,8 @@ public AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientService = authorizedClientService; + this.authorizationSuccessHandler = new SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(authorizedClientService); + this.authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(authorizedClientService); } @Override @@ -66,7 +96,7 @@ public Mono authorize(OAuth2AuthorizeRequest authorizeRe Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); return createAuthorizationContext(authorizeRequest) - .flatMap(this::authorizeAndSave); + .flatMap(authorizationContext -> authorize(authorizationContext, authorizeRequest.getPrincipal())); } private Mono createAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest) { @@ -90,12 +120,33 @@ private Mono createAuthorizationContext(OAuth2Author })); } - private Mono authorizeAndSave(OAuth2AuthorizationContext authorizationContext) { + /** + * Performs authorization, and notifies either the {@link #authorizationSuccessHandler} + * or {@link #authorizationFailureHandler}, depending on the authorization result. + * + * @param authorizationContext the context to authorize + * @param principal the principle to authorize + * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds + * and the {@link #authorizationSuccessHandler} has completed, + * or completes with an exception after the authorization attempt fails + * and the {@link #authorizationFailureHandler} has completed + */ + private Mono authorize( + OAuth2AuthorizationContext authorizationContext, + Authentication principal) { return this.authorizedClientProvider.authorize(authorizationContext) - .flatMap(authorizedClient -> this.authorizedClientService.saveAuthorizedClient( + // Notify the authorizationSuccessHandler of the successful authorization + .flatMap(authorizedClient -> authorizationSuccessHandler.onAuthorizationSuccess( authorizedClient, - authorizationContext.getPrincipal()) + principal, + Collections.emptyMap()) .thenReturn(authorizedClient)) + // Notify the authorizationFailureHandler of the failed authorization + .onErrorResume(OAuth2AuthorizationException.class, authorizationException -> authorizationFailureHandler.onAuthorizationFailure( + authorizationException, + principal, + Collections.emptyMap()) + .then(Mono.error(authorizationException))) .switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(authorizationContext.getAuthorizedClient()))); } @@ -121,6 +172,36 @@ public void setContextAttributesMapper(FunctionA {@link SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler} + * is used by default.

+ * + * @param authorizationSuccessHandler the handler that handles successful authorizations. + * @see SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler + * @since 5.3 + */ + public void setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { + Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null"); + this.authorizationSuccessHandler = authorizationSuccessHandler; + } + + /** + * Sets the handler that handles authorization failures. + * + *

A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} + * is used by default.

+ * + * @param authorizationFailureHandler the handler that handles authorization failures. + * @see RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler + * @since 5.3 + */ + public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { + Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); + this.authorizationFailureHandler = authorizationFailureHandler; + } + /** * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. */ @@ -134,4 +215,5 @@ public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) return Mono.fromCallable(() -> mapper.apply(authorizeRequest)); } } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java new file mode 100644 index 00000000000..0cbd6ee2c8b --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationException.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.util.Assert; + +/** + * This exception is thrown on the client side when an attempt to authenticate + * or authorize an OAuth 2.0 client fails. + * + * @author Phil Clay + * @since 5.3 + * @see OAuth2AuthorizedClient + */ +public class ClientAuthorizationException extends OAuth2AuthorizationException { + + private final String clientRegistrationId; + + /** + * Constructs a {@code ClientAuthorizationException} using the provided parameters. + * + * @param error the {@link OAuth2Error OAuth 2.0 Error} + * @param clientRegistrationId the identifier for the client's registration + */ + public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId) { + this(error, clientRegistrationId, error.toString()); + } + /** + * Constructs a {@code ClientAuthorizationException} using the provided parameters. + * + * @param error the {@link OAuth2Error OAuth 2.0 Error} + * @param clientRegistrationId the identifier for the client's registration + * @param message the exception message + */ + public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId, String message) { + super(error, message); + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + this.clientRegistrationId = clientRegistrationId; + } + + /** + * Constructs a {@code ClientAuthorizationException} using the provided parameters. + * + * @param error the {@link OAuth2Error OAuth 2.0 Error} + * @param clientRegistrationId the identifier for the client's registration + * @param cause the root cause + */ + public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId, Throwable cause) { + this(error, clientRegistrationId, error.toString(), cause); + } + + /** + * Constructs a {@code ClientAuthorizationException} using the provided parameters. + * + * @param error the {@link OAuth2Error OAuth 2.0 Error} + * @param clientRegistrationId the identifier for the client's registration + * @param message the exception message + * @param cause the root cause + */ + public ClientAuthorizationException(OAuth2Error error, String clientRegistrationId, String message, Throwable cause) { + super(error, message, cause); + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + this.clientRegistrationId = clientRegistrationId; + } + + /** + * Returns the identifier for the client's registration. + * + * @return the identifier for the client's registration + */ + public String getClientRegistrationId() { + return this.clientRegistrationId; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java index 0a0c81ea98b..d9b9e7a6a73 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientAuthorizationRequiredException.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,7 @@ */ package org.springframework.security.oauth2.client; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.util.Assert; /** * This exception is thrown when an OAuth 2.0 Client is required @@ -27,9 +25,8 @@ * @since 5.1 * @see OAuth2AuthorizedClient */ -public class ClientAuthorizationRequiredException extends OAuth2AuthorizationException { +public class ClientAuthorizationRequiredException extends ClientAuthorizationException { private static final String CLIENT_AUTHORIZATION_REQUIRED_ERROR_CODE = "client_authorization_required"; - private final String clientRegistrationId; /** * Constructs a {@code ClientAuthorizationRequiredException} using the provided parameters. @@ -38,17 +35,7 @@ public class ClientAuthorizationRequiredException extends OAuth2AuthorizationExc */ public ClientAuthorizationRequiredException(String clientRegistrationId) { super(new OAuth2Error(CLIENT_AUTHORIZATION_REQUIRED_ERROR_CODE, - "Authorization required for Client Registration Id: " + clientRegistrationId, null)); - Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); - this.clientRegistrationId = clientRegistrationId; - } - - /** - * Returns the identifier for the client's registration. - * - * @return the identifier for the client's registration - */ - public String getClientRegistrationId() { - return this.clientRegistrationId; + "Authorization required for Client Registration Id: " + clientRegistrationId, null), + clientRegistrationId); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java new file mode 100644 index 00000000000..8daf37ecbeb --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationFailureHandler.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import reactor.core.publisher.Mono; + +import java.util.Map; + +/** + * Handles when an OAuth 2.0 Authorized Client + * fails to authorize (or re-authorize) + * via the authorization server or resource server. + * + * @author Phil Clay + * @since 5.3 + */ +@FunctionalInterface +public interface ReactiveOAuth2AuthorizationFailureHandler { + + /** + * Called when an OAuth 2.0 Authorized Client + * fails to authorize (or re-authorize) + * via the authorization server or resource server. + * + * @param authorizationException the exception that contains details about what failed + * @param principal the {@code Principal} that was attempted to be authorized + * @param attributes an immutable {@code Map} of extra optional attributes present under certain conditions. + * For example, this might contain a {@link org.springframework.web.server.ServerWebExchange ServerWebExchange} + * if the authorization was performed within the context of a {@code ServerWebExchange}. + * @return an empty {@link Mono} that completes after this handler has finished handling the event. + */ + Mono onAuthorizationFailure( + OAuth2AuthorizationException authorizationException, + Authentication principal, + Map attributes); +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java new file mode 100644 index 00000000000..5376c80103b --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ReactiveOAuth2AuthorizationSuccessHandler.java @@ -0,0 +1,51 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.core.Authentication; +import reactor.core.publisher.Mono; + +import java.util.Map; + +/** + * Handles when an OAuth 2.0 Authorized Client + * has been successfully authorized (or re-authorized) + * via the authorization server. + * + * @author Phil Clay + * @since 5.3 + */ +@FunctionalInterface +public interface ReactiveOAuth2AuthorizationSuccessHandler { + + /** + * Called when an OAuth 2.0 Authorized Client + * has been successfully authorized (or re-authorized) + * via the authorization server. + * + * @param authorizedClient the client that was successfully authorized + * @param principal the {@code Principal} that was authorized + * @param attributes an immutable {@code Map} of extra optional attributes present under certain conditions. + * For example, this might contain a {@link org.springframework.web.server.ServerWebExchange ServerWebExchange} + * if the authorization was performed within the context of a {@code ServerWebExchange}. + * @return an empty {@link Mono} that completes after this handler has finished handling the event. + */ + Mono onAuthorizationSuccess( + OAuth2AuthorizedClient authorizedClient, + Authentication principal, + Map attributes); + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java new file mode 100644 index 00000000000..cbaf8c1d36a --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * An authorization failure handler that removes authorized clients from a + * {@link ServerOAuth2AuthorizedClientRepository} + * or a {@link ReactiveOAuth2AuthorizedClientService}. + * for specific OAuth 2.0 error codes. + * + * @author Phil Clay + * @since 5.3 + */ +public class RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler implements ReactiveOAuth2AuthorizationFailureHandler { + + /** + * The default OAuth2 error codes that will trigger removal of the authorized client. + * @see OAuth2ErrorCodes + */ + public static final Set DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( + /* + * Returned from resource servers when an access token provided is expired, revoked, + * malformed, or invalid for other reasons. + * + * Note that this is needed because the ServerOAuth2AuthorizedClientExchangeFilterFunction + * delegates this type of failure received from a resource server + * to this failure handler. + */ + OAuth2ErrorCodes.INVALID_TOKEN, + /* + * Returned from authorization servers when a refresh token is invalid, expired, revoked, + * does not match the redirection URI used in the authorization request, or was issued to another client. + */ + OAuth2ErrorCodes.INVALID_GRANT))); + + /** + * A delegate that removes clients from either a + * {@link ServerOAuth2AuthorizedClientRepository} + * or a + * {@link ReactiveOAuth2AuthorizedClientService} + * if the error code is one of the {@link #removeAuthorizedClientErrorCodes}. + */ + private final OAuth2AuthorizedClientRemover delegate; + + /** + * The OAuth2 Error Codes which will trigger removal of an authorized client. + * @see OAuth2ErrorCodes + */ + private final Set removeAuthorizedClientErrorCodes; + + @FunctionalInterface + private interface OAuth2AuthorizedClientRemover { + Mono removeAuthorizedClient( + String clientRegistrationId, + Authentication principal, + Map attributes); + } + + /** + * @param authorizedClientRepository The repository from which authorized clients will be removed + * if the error code is one of the {@link #DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES}. + */ + public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + this(authorizedClientRepository, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES); + } + + /** + * @param authorizedClientRepository The repository from which authorized clients will be removed + * if the error code is one of the {@code removeAuthorizedClientErrorCodes}. + * @param removeAuthorizedClientErrorCodes the OAuth2 Error Codes which will trigger removal of an authorized client. + * @see OAuth2ErrorCodes + */ + public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( + ServerOAuth2AuthorizedClientRepository authorizedClientRepository, + Set removeAuthorizedClientErrorCodes) { + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null"); + this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes)); + this.delegate = (clientRegistrationId, principal, attributes) -> + authorizedClientRepository.removeAuthorizedClient( + clientRegistrationId, + principal, + (ServerWebExchange) attributes.get(ServerWebExchange.class.getName())); + } + + /** + * @param authorizedClientService the service from which authorized clients will be removed + * if the error code is one of the {@code removeAuthorizedClientErrorCodes}. + */ + public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(ReactiveOAuth2AuthorizedClientService authorizedClientService) { + this(authorizedClientService, DEFAULT_REMOVE_AUTHORIZED_CLIENT_ERROR_CODES); + } + + /** + * @param authorizedClientService the service from which authorized clients will be removed + * if the error code is one of the {@code removeAuthorizedClientErrorCodes}. + * @param removeAuthorizedClientErrorCodes the OAuth2 Error Codes which will trigger removal of an authorized client. + * @see OAuth2ErrorCodes + */ + public RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler( + ReactiveOAuth2AuthorizedClientService authorizedClientService, + Set removeAuthorizedClientErrorCodes) { + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + Assert.notNull(removeAuthorizedClientErrorCodes, "removeAuthorizedClientErrorCodes cannot be null"); + this.removeAuthorizedClientErrorCodes = Collections.unmodifiableSet(new HashSet<>(removeAuthorizedClientErrorCodes)); + this.delegate = (clientRegistrationId, principal, attributes) -> + authorizedClientService.removeAuthorizedClient( + clientRegistrationId, + principal.getName()); + } + + @Override + public Mono onAuthorizationFailure( + OAuth2AuthorizationException authorizationException, + Authentication principal, + Map attributes) { + + if (authorizationException instanceof ClientAuthorizationException + && hasRemovalErrorCode(authorizationException)) { + + ClientAuthorizationException clientAuthorizationException = (ClientAuthorizationException) authorizationException; + return this.delegate.removeAuthorizedClient( + clientAuthorizationException.getClientRegistrationId(), + principal, + attributes); + } else { + return Mono.empty(); + } + } + + /** + * Returns true if the given exception has an error code that + * indicates that the authorized client should be removed. + * + * @param authorizationException the exception that caused the authorization failure + * @return true if the given exception has an error code that + * indicates that the authorized client should be removed. + */ + private boolean hasRemovalErrorCode(OAuth2AuthorizationException authorizationException) { + return this.removeAuthorizedClientErrorCodes.contains(authorizationException.getError().getErrorCode()); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler.java new file mode 100644 index 00000000000..e60d79c2a9c --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Map; + +/** + * An authorization success handler that saves authorized clients in a + * {@link ServerOAuth2AuthorizedClientRepository} + * or a {@link ReactiveOAuth2AuthorizedClientService}. + * + * @author Phil Clay + * @since 5.3 + */ +public class SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler implements ReactiveOAuth2AuthorizationSuccessHandler { + + /** + * A delegate that saves clients in either a + * {@link ServerOAuth2AuthorizedClientRepository} + * or a + * {@link ReactiveOAuth2AuthorizedClientService}. + */ + private final ReactiveOAuth2AuthorizationSuccessHandler delegate; + + /** + * @param authorizedClientRepository The repository in which authorized clients will be saved. + */ + public SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(final ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.delegate = (authorizedClient, principal, attributes) -> + authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + principal, + (ServerWebExchange) attributes.get(ServerWebExchange.class.getName())); + } + + /** + * @param authorizedClientService The service in which authorized clients will be saved. + */ + public SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(final ReactiveOAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.delegate = (authorizedClient, principal, attributes) -> + authorizedClientService.saveAuthorizedClient( + authorizedClient, + principal); + } + + @Override + public Mono onAuthorizationSuccess( + OAuth2AuthorizedClient authorizedClient, + Authentication principal, + Map attributes) { + return this.delegate.onAuthorizationSuccess( + authorizedClient, + principal, + attributes); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java new file mode 100644 index 00000000000..befe0e403e8 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java @@ -0,0 +1,229 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.ClientAuthorizationException; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.Set; + +import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; + +/** + * Abstract base class for all of the {@code WebClientReactive*TokenResponseClient}s + * that communicate to the Authorization Server's Token Endpoint. + * + *

Submits a form request body specific to the type of grant request.

+ * + *

Accepts a JSON response body containing an OAuth 2.0 Access token or error.

+ * + * @author Phil Clay + * @since 5.3 + * @param type of grant request + * @see RFC-6749 Token Endpoint + * @see WebClientReactiveAuthorizationCodeTokenResponseClient + * @see WebClientReactiveClientCredentialsTokenResponseClient + * @see WebClientReactivePasswordTokenResponseClient + * @see WebClientReactiveRefreshTokenTokenResponseClient + */ +abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient + implements ReactiveOAuth2AccessTokenResponseClient { + + private WebClient webClient = WebClient.builder().build(); + + @Override + public Mono getTokenResponse(T grantRequest) { + Assert.notNull(grantRequest, "grantRequest cannot be null"); + return Mono.defer(() -> this.webClient.post() + .uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri()) + .headers(headers -> populateTokenRequestHeaders(grantRequest, headers)) + .body(createTokenRequestBody(grantRequest)) + .exchange() + .flatMap(response -> readTokenResponse(grantRequest, response))); + } + + /** + * Returns the {@link ClientRegistration} for the given {@code grantRequest}. + * + * @param grantRequest the grant request + * @return the {@link ClientRegistration} for the given {@code grantRequest}. + */ + abstract ClientRegistration clientRegistration(T grantRequest); + + /** + * Populates the headers for the token request. + * + * @param grantRequest the grant request + * @param headers the headers to populate + */ + private void populateTokenRequestHeaders(T grantRequest, HttpHeaders headers) { + ClientRegistration clientRegistration = clientRegistration(grantRequest); + headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); + headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); + if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { + headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); + } + } + + /** + * Creates and returns the body for the token request. + * + *

This method pre-populates the body with some standard properties, + * and then delegates to {@link #populateTokenRequestBody(AbstractOAuth2AuthorizationGrantRequest, BodyInserters.FormInserter)} + * for subclasses to further populate the body before returning.

+ * + * @param grantRequest the grant request + * @return the body for the token request. + */ + private BodyInserters.FormInserter createTokenRequestBody(T grantRequest) { + BodyInserters.FormInserter body = BodyInserters + .fromFormData(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); + return populateTokenRequestBody(grantRequest, body); + } + + /** + * Populates the body of the token request. + * + *

By default, populates properties that are common to all grant types. + * Subclasses can extend this method to populate grant type specific properties.

+ * + * @param grantRequest the grant request + * @param body the body to populate + * @return the populated body + */ + BodyInserters.FormInserter populateTokenRequestBody(T grantRequest, BodyInserters.FormInserter body) { + ClientRegistration clientRegistration = clientRegistration(grantRequest); + if (!ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { + body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + } + if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { + body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + Set scopesToRequest = scopes(grantRequest); + if (!CollectionUtils.isEmpty(scopesToRequest)) { + body.with(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(scopesToRequest, " ")); + } + return body; + } + + /** + * Returns the scopes to include as a property in the token request. + * + * @param grantRequest the grant request + * @return the scopes to include as a property in the token request. + */ + abstract Set scopes(T grantRequest); + + /** + * Returns the scopes to include in the response if the authorization + * server returned no scopes in the response. + * + *

As per RFC-6749 Section 5.1 Successful Access Token Response, + * if AccessTokenResponse.scope is empty, then default to the scope + * originally requested by the client in the Token Request.

+ * + * @param grantRequest the grant request + * @return the scopes to include in the response if the authorization + * server returned no scopes. + */ + Set defaultScopes(T grantRequest) { + return scopes(grantRequest); + } + + /** + * Reads the token response from the response body. + * + * @param grantRequest the request for which the response was received. + * @param response the client response from which to read + * @return the token response from the response body. + */ + private Mono readTokenResponse(T grantRequest, ClientResponse response) { + return response.body(oauth2AccessTokenResponse()) + .onErrorMap(OAuth2AuthorizationException.class, e -> createClientAuthorizationException( + response, + clientRegistration(grantRequest).getRegistrationId(), + e)) + .map(tokenResponse -> populateTokenResponse(grantRequest, tokenResponse)); + } + + /** + * Wraps the given {@link OAuth2AuthorizationException} in a {@link ClientAuthorizationException} + * that provides response details, and a more descriptive exception message. + * + * @param response the token response + * @param clientRegistrationId the id of the {@link ClientRegistration} for which a token is being requested + * @param authorizationException the {@link OAuth2AuthorizationException} to wrap + * @return the {@link ClientAuthorizationException} that wraps the given {@link OAuth2AuthorizationException} + */ + private OAuth2AuthorizationException createClientAuthorizationException( + ClientResponse response, + String clientRegistrationId, + OAuth2AuthorizationException authorizationException) { + + String message = String.format("Error retrieving OAuth 2.0 Access Token (HTTP Status Code: %s) %s", + response.rawStatusCode(), + authorizationException.getError()); + + return new ClientAuthorizationException( + authorizationException.getError(), + clientRegistrationId, + message, + authorizationException); + } + + /** + * Populates the given {@link OAuth2AccessTokenResponse} with additional details + * from the grant request. + * + * @param grantRequest the request for which the response was received. + * @param tokenResponse the original token response + * @return a token response optionally populated with additional details from the request. + */ + OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) { + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { + Set defaultScopes = defaultScopes(grantRequest); + tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) + .scopes(defaultScopes) + .build(); + } + return tokenResponse; + } + + /** + * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response. + * + * @param webClient the {@link WebClient} used when requesting the Access Token Response + */ + public void setWebClient(WebClient webClient) { + Assert.notNull(webClient, "webClient cannot be null"); + this.webClient = webClient; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java index 76d49cc367b..1ef506c1ff6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,21 +15,16 @@ */ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.web.reactive.function.BodyInserters; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.util.Assert; -import reactor.core.publisher.Mono; -import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; +import java.util.Collections; +import java.util.Set; /** * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges" @@ -49,64 +44,37 @@ * @see Section 4.1.4 Access Token Response (Authorization Code Grant) * @see Section 4.2 Client Creates the Code Challenge */ -public class WebClientReactiveAuthorizationCodeTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { - private WebClient webClient = WebClient.builder() - .build(); +public class WebClientReactiveAuthorizationCodeTokenResponseClient extends + AbstractWebClientReactiveOAuth2AccessTokenResponseClient { - /** - * @param webClient the webClient to set - */ - public void setWebClient(WebClient webClient) { - Assert.notNull(webClient, "webClient cannot be null"); - this.webClient = webClient; + @Override + ClientRegistration clientRegistration(OAuth2AuthorizationCodeGrantRequest grantRequest) { + return grantRequest.getClientRegistration(); } @Override - public Mono getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) { - return Mono.defer(() -> { - ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); - OAuth2AuthorizationExchange authorizationExchange = authorizationGrantRequest.getAuthorizationExchange(); - String tokenUri = clientRegistration.getProviderDetails().getTokenUri(); - BodyInserters.FormInserter body = body(authorizationExchange, clientRegistration); + Set scopes(OAuth2AuthorizationCodeGrantRequest grantRequest) { + return Collections.emptySet(); + } - return this.webClient.post() - .uri(tokenUri) - .accept(MediaType.APPLICATION_JSON) - .headers(headers -> { - if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { - headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); - } - }) - .body(body) - .exchange() - .flatMap(response -> response.body(oauth2AccessTokenResponse())) - .map(response -> { - if (response.getAccessToken().getScopes().isEmpty()) { - response = OAuth2AccessTokenResponse.withResponse(response) - .scopes(authorizationExchange.getAuthorizationRequest().getScopes()) - .build(); - } - return response; - }); - }); + @Override + Set defaultScopes(OAuth2AuthorizationCodeGrantRequest grantRequest) { + return grantRequest.getAuthorizationExchange().getAuthorizationRequest().getScopes(); } - private static BodyInserters.FormInserter body(OAuth2AuthorizationExchange authorizationExchange, ClientRegistration clientRegistration) { + @Override + BodyInserters.FormInserter populateTokenRequestBody( + OAuth2AuthorizationCodeGrantRequest grantRequest, + BodyInserters.FormInserter body) { + super.populateTokenRequestBody(grantRequest, body); + OAuth2AuthorizationExchange authorizationExchange = grantRequest.getAuthorizationExchange(); OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse(); - BodyInserters.FormInserter body = BodyInserters - .fromFormData(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) - .with(OAuth2ParameterNames.CODE, authorizationResponse.getCode()); + body.with(OAuth2ParameterNames.CODE, authorizationResponse.getCode()); String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri(); - String codeVerifier = authorizationExchange.getAuthorizationRequest().getAttribute(PkceParameterNames.CODE_VERIFIER); if (redirectUri != null) { body.with(OAuth2ParameterNames.REDIRECT_URI, redirectUri); } - if (!ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { - body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - } - if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); - } + String codeVerifier = authorizationExchange.getAuthorizationRequest().getAttribute(PkceParameterNames.CODE_VERIFIER); if (codeVerifier != null) { body.with(PkceParameterNames.CODE_VERIFIER, codeVerifier); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java index 6acfd38547e..0d39b00c09b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,31 +15,14 @@ */ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferUtils; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; -import org.springframework.web.reactive.function.BodyInserters; -import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.client.WebClientResponseException; -import reactor.core.publisher.Mono; import java.util.Set; -import java.util.function.Consumer; - -import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; /** * An implementation of an {@link ReactiveOAuth2AccessTokenResponseClient} that "exchanges" - * an authorization code credential for an access token credential + * an client credential for an access token credential * at the Authorization Server's Token Endpoint. * * @author Rob Winch @@ -51,76 +34,17 @@ * @see Section 4.1.3 Access Token Request (Authorization Code Grant) * @see Section 4.1.4 Access Token Response (Authorization Code Grant) */ -public class WebClientReactiveClientCredentialsTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { - private WebClient webClient = WebClient.builder() - .build(); +public class WebClientReactiveClientCredentialsTokenResponseClient extends + AbstractWebClientReactiveOAuth2AccessTokenResponseClient { @Override - public Mono getTokenResponse(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) { - return Mono.defer(() -> { - ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); - - String tokenUri = clientRegistration.getProviderDetails().getTokenUri(); - BodyInserters.FormInserter body = body(authorizationGrantRequest); - - return this.webClient.post() - .uri(tokenUri) - .accept(MediaType.APPLICATION_JSON) - .headers(headers(clientRegistration)) - .body(body) - .exchange() - .flatMap(response -> { - HttpStatus status = HttpStatus.resolve(response.rawStatusCode()); - if (status == null || !status.is2xxSuccessful()) { - // extract the contents of this into a method named oauth2AccessTokenResponse but has an argument for the response - return response.bodyToFlux(DataBuffer.class) - .map(DataBufferUtils::release) - .then(Mono.error(WebClientResponseException.create(response.rawStatusCode(), - "Cannot get token, expected 2xx HTTP Status code", - null, - null, - null - ))); - } - return response.body(oauth2AccessTokenResponse()); }) - .map(response -> { - if (response.getAccessToken().getScopes().isEmpty()) { - response = OAuth2AccessTokenResponse.withResponse(response) - .scopes(authorizationGrantRequest.getClientRegistration().getScopes()) - .build(); - } - return response; - }); - }); + ClientRegistration clientRegistration(OAuth2ClientCredentialsGrantRequest grantRequest) { + return grantRequest.getClientRegistration(); } - private Consumer headers(ClientRegistration clientRegistration) { - return headers -> { - headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); - if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { - headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); - } - }; - } - - private static BodyInserters.FormInserter body(OAuth2ClientCredentialsGrantRequest authorizationGrantRequest) { - ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration(); - BodyInserters.FormInserter body = BodyInserters - .fromFormData(OAuth2ParameterNames.GRANT_TYPE, authorizationGrantRequest.getGrantType().getValue()); - Set scopes = clientRegistration.getScopes(); - if (!CollectionUtils.isEmpty(scopes)) { - String scope = StringUtils.collectionToDelimitedString(scopes, " "); - body.with(OAuth2ParameterNames.SCOPE, scope); - } - if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); - } - return body; + @Override + Set scopes(OAuth2ClientCredentialsGrantRequest grantRequest) { + return grantRequest.getClientRegistration().getScopes(); } - public void setWebClient(WebClient webClient) { - Assert.notNull(webClient, "webClient cannot be null"); - this.webClient = webClient; - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java index 41fe1216943..442e2543fce 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,29 +15,14 @@ */ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferUtils; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.Assert; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Mono; -import java.util.Collections; -import java.util.function.Consumer; - -import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; +import java.util.Set; /** * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} @@ -53,82 +38,26 @@ * @see Section 4.3.2 Access Token Request (Resource Owner Password Credentials Grant) * @see Section 4.3.3 Access Token Response (Resource Owner Password Credentials Grant) */ -public final class WebClientReactivePasswordTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { - private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - private WebClient webClient = WebClient.builder().build(); +public final class WebClientReactivePasswordTokenResponseClient extends + AbstractWebClientReactiveOAuth2AccessTokenResponseClient { @Override - public Mono getTokenResponse(OAuth2PasswordGrantRequest passwordGrantRequest) { - Assert.notNull(passwordGrantRequest, "passwordGrantRequest cannot be null"); - return Mono.defer(() -> { - ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); - return this.webClient.post() - .uri(clientRegistration.getProviderDetails().getTokenUri()) - .headers(tokenRequestHeaders(clientRegistration)) - .body(tokenRequestBody(passwordGrantRequest)) - .exchange() - .flatMap(response -> { - HttpStatus status = HttpStatus.resolve(response.rawStatusCode()); - if (status == null || !status.is2xxSuccessful()) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + - "HTTP Status Code " + response.rawStatusCode(), null); - return response - .bodyToMono(DataBuffer.class) - .map(DataBufferUtils::release) - .then(Mono.error(new OAuth2AuthorizationException(oauth2Error))); - } - return response.body(oauth2AccessTokenResponse()); - }) - .map(tokenResponse -> { - if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { - // As per spec, in Section 5.1 Successful Access Token Response - // https://tools.ietf.org/html/rfc6749#section-5.1 - // If AccessTokenResponse.scope is empty, then default to the scope - // originally requested by the client in the Token Request - tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) - .scopes(passwordGrantRequest.getClientRegistration().getScopes()) - .build(); - } - return tokenResponse; - }); - }); + ClientRegistration clientRegistration(OAuth2PasswordGrantRequest grantRequest) { + return grantRequest.getClientRegistration(); } - private static Consumer tokenRequestHeaders(ClientRegistration clientRegistration) { - return headers -> { - headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); - headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); - if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { - headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); - } - }; + @Override + Set scopes(OAuth2PasswordGrantRequest grantRequest) { + return grantRequest.getClientRegistration().getScopes(); } - private static BodyInserters.FormInserter tokenRequestBody(OAuth2PasswordGrantRequest passwordGrantRequest) { - ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration(); - BodyInserters.FormInserter body = BodyInserters.fromFormData( - OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue()); - body.with(OAuth2ParameterNames.USERNAME, passwordGrantRequest.getUsername()); - body.with(OAuth2ParameterNames.PASSWORD, passwordGrantRequest.getPassword()); - if (!CollectionUtils.isEmpty(passwordGrantRequest.getClientRegistration().getScopes())) { - body.with(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(passwordGrantRequest.getClientRegistration().getScopes(), " ")); - } - if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); - } - return body; + @Override + BodyInserters.FormInserter populateTokenRequestBody( + OAuth2PasswordGrantRequest grantRequest, + BodyInserters.FormInserter body) { + return super.populateTokenRequestBody(grantRequest, body) + .with(OAuth2ParameterNames.USERNAME, grantRequest.getUsername()) + .with(OAuth2ParameterNames.PASSWORD, grantRequest.getPassword()); } - /** - * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response. - * - * @param webClient the {@link WebClient} used when requesting the Access Token Response - */ - public void setWebClient(WebClient webClient) { - Assert.notNull(webClient, "webClient cannot be null"); - this.webClient = webClient; - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java index 6d6daa83d58..9ad787af7fe 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,29 +15,15 @@ */ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.core.io.buffer.DataBuffer; -import org.springframework.core.io.buffer.DataBufferUtils; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; -import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Mono; -import java.util.Collections; -import java.util.function.Consumer; - -import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; +import java.util.Set; /** * An implementation of a {@link ReactiveOAuth2AccessTokenResponseClient} @@ -52,66 +38,37 @@ * @see OAuth2AccessTokenResponse * @see Section 6 Refreshing an Access Token */ -public final class WebClientReactiveRefreshTokenTokenResponseClient implements ReactiveOAuth2AccessTokenResponseClient { - private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; - private WebClient webClient = WebClient.builder().build(); +public final class WebClientReactiveRefreshTokenTokenResponseClient extends + AbstractWebClientReactiveOAuth2AccessTokenResponseClient { @Override - public Mono getTokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { - Assert.notNull(refreshTokenGrantRequest, "refreshTokenGrantRequest cannot be null"); - return Mono.defer(() -> { - ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); - return this.webClient.post() - .uri(clientRegistration.getProviderDetails().getTokenUri()) - .headers(tokenRequestHeaders(clientRegistration)) - .body(tokenRequestBody(refreshTokenGrantRequest)) - .exchange() - .flatMap(response -> { - HttpStatus status = HttpStatus.resolve(response.rawStatusCode()); - if (status == null || !status.is2xxSuccessful()) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + - "HTTP Status Code " + response.rawStatusCode(), null); - return response - .bodyToMono(DataBuffer.class) - .map(DataBufferUtils::release) - .then(Mono.error(new OAuth2AuthorizationException(oauth2Error))); - } - return response.body(oauth2AccessTokenResponse()); - }) - .map(tokenResponse -> tokenResponse(refreshTokenGrantRequest, tokenResponse)); - }); + ClientRegistration clientRegistration(OAuth2RefreshTokenGrantRequest grantRequest) { + return grantRequest.getClientRegistration(); } - private static Consumer tokenRequestHeaders(ClientRegistration clientRegistration) { - return headers -> { - headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED); - headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON)); - if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) { - headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()); - } - }; + @Override + Set scopes(OAuth2RefreshTokenGrantRequest grantRequest) { + return grantRequest.getScopes(); } - private static BodyInserters.FormInserter tokenRequestBody(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { - ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); - BodyInserters.FormInserter body = BodyInserters.fromFormData( - OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); - body.with(OAuth2ParameterNames.REFRESH_TOKEN, - refreshTokenGrantRequest.getRefreshToken().getTokenValue()); - if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) { - body.with(OAuth2ParameterNames.SCOPE, - StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " ")); - } - if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { - body.with(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); - body.with(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); - } - return body; + @Override + Set defaultScopes(OAuth2RefreshTokenGrantRequest grantRequest) { + return grantRequest.getAccessToken().getScopes(); } - private static OAuth2AccessTokenResponse tokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest, - OAuth2AccessTokenResponse accessTokenResponse) { + @Override + BodyInserters.FormInserter populateTokenRequestBody( + OAuth2RefreshTokenGrantRequest grantRequest, + BodyInserters.FormInserter body) { + return super.populateTokenRequestBody(grantRequest, body) + .with(OAuth2ParameterNames.REFRESH_TOKEN, grantRequest.getRefreshToken().getTokenValue()); + } + + @Override + OAuth2AccessTokenResponse populateTokenResponse( + OAuth2RefreshTokenGrantRequest grantRequest, + OAuth2AccessTokenResponse accessTokenResponse) { + if (!CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes()) && accessTokenResponse.getRefreshToken() != null) { return accessTokenResponse; @@ -119,26 +76,13 @@ private static OAuth2AccessTokenResponse tokenResponse(OAuth2RefreshTokenGrantRe OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse.withResponse(accessTokenResponse); if (CollectionUtils.isEmpty(accessTokenResponse.getAccessToken().getScopes())) { - // As per spec, in Section 5.1 Successful Access Token Response - // https://tools.ietf.org/html/rfc6749#section-5.1 - // If AccessTokenResponse.scope is empty, then default to the scope - // originally requested by the client in the Token Request - tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes()); + tokenResponseBuilder.scopes(defaultScopes(grantRequest)); } if (accessTokenResponse.getRefreshToken() == null) { // Reuse existing refresh token - tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue()); + tokenResponseBuilder.refreshToken(grantRequest.getRefreshToken().getTokenValue()); } return tokenResponseBuilder.build(); } - /** - * Sets the {@link WebClient} used when requesting the OAuth 2.0 Access Token Response. - * - * @param webClient the {@link WebClient} used when requesting the Access Token Response - */ - public void setWebClient(WebClient webClient) { - Assert.notNull(webClient, "webClient cannot be null"); - this.webClient = webClient; - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java index 852e910e6f6..0917bdbfc92 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,11 +19,16 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -37,18 +42,52 @@ import java.util.function.Function; /** - * The default implementation of a {@link ReactiveOAuth2AuthorizedClientManager}. + * The default implementation of a {@link ReactiveOAuth2AuthorizedClientManager} + * for use within the context of a {@link ServerWebExchange}. + * + *

(When operating outside of the context of a {@link ServerWebExchange}, + * use {@link org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} instead.)

+ * + *

This is a reactive equivalent of {@link DefaultOAuth2AuthorizedClientManager}.

+ * + *

Authorized Client Persistence

+ * + *

This client manager utilizes a {@link ServerOAuth2AuthorizedClientRepository} + * to persist {@link OAuth2AuthorizedClient}s.

+ * + *

By default, when an authorization attempt succeeds, the {@link OAuth2AuthorizedClient} + * will be saved in the authorized client repository. + * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationSuccessHandler} + * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.

+ * + *

By default, when an authorization attempt fails due to an + * {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} error, + * the previously saved {@link OAuth2AuthorizedClient} + * will be removed from the authorized client repository. + * (The {@value org.springframework.security.oauth2.core.OAuth2ErrorCodes#INVALID_GRANT} + * error generally occurs when a refresh token that is no longer valid + * is used to retrieve a new access token.) + * This functionality can be changed by configuring a custom {@link ReactiveOAuth2AuthorizationFailureHandler} + * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.

* * @author Joe Grandja + * @author Phil Clay * @since 5.2 * @see ReactiveOAuth2AuthorizedClientManager * @see ReactiveOAuth2AuthorizedClientProvider */ public final class DefaultReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { + + private static final Mono currentServerWebExchangeMono = Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + private final ReactiveClientRegistrationRepository clientRegistrationRepository; private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty(); private Function>> contextAttributesMapper = new DefaultContextAttributesMapper(); + private ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; /** * Constructs a {@code DefaultReactiveOAuth2AuthorizedClientManager} using the provided parameters. @@ -62,6 +101,8 @@ public DefaultReactiveOAuth2AuthorizedClientManager(ReactiveClientRegistrationRe Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; + this.authorizationSuccessHandler = new SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(authorizedClientRepository); + this.authorizationFailureHandler = new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(authorizedClientRepository); } @Override @@ -70,57 +111,76 @@ public Mono authorize(OAuth2AuthorizeRequest authorizeRe String clientRegistrationId = authorizeRequest.getClientRegistrationId(); Authentication principal = authorizeRequest.getPrincipal(); - ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); - - return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) - .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange))) - .flatMap(authorizedClient -> { - // Re-authorize - return authorizationContext(authorizeRequest, authorizedClient) - .flatMap(this.authorizedClientProvider::authorize) - .flatMap(reauthorizedClient -> saveAuthorizedClient(reauthorizedClient, principal, serverWebExchange)) - // Default to the existing authorizedClient if the client was not re-authorized - .defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ? - authorizeRequest.getAuthorizedClient() : authorizedClient); - }) - .switchIfEmpty(Mono.deferWithContext(context -> - // Authorize - this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( - "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) - .flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration)) - .flatMap(this.authorizedClientProvider::authorize) - .flatMap(authorizedClient -> saveAuthorizedClient(authorizedClient, principal, serverWebExchange)) - .subscriberContext(context) - ) - ); + + return Mono.justOrEmpty(authorizeRequest.getAttribute(ServerWebExchange.class.getName())) + .switchIfEmpty(currentServerWebExchangeMono) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null"))) + .flatMap(serverWebExchange -> Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) + .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange))) + .flatMap(authorizedClient -> { + // Re-authorize + return authorizationContext(authorizeRequest, authorizedClient) + .flatMap(authorizationContext -> authorize(authorizationContext, principal, serverWebExchange)) + // Default to the existing authorizedClient if the client was not re-authorized + .defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ? + authorizeRequest.getAuthorizedClient() : authorizedClient); + }) + .switchIfEmpty(Mono.deferWithContext(context -> + // Authorize + this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( + "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) + .flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration)) + .flatMap(authorizationContext -> authorize(authorizationContext, principal, serverWebExchange)) + .subscriberContext(context) + ) + )); } private Mono loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) { - return Mono.justOrEmpty(serverWebExchange) - .switchIfEmpty(Mono.defer(() -> currentServerWebExchange())) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null"))) - .flatMap(exchange -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)); + return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange); } - private Mono saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, ServerWebExchange serverWebExchange) { - return Mono.justOrEmpty(serverWebExchange) - .switchIfEmpty(Mono.defer(() -> currentServerWebExchange())) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null"))) - .flatMap(exchange -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange) - .thenReturn(authorizedClient)); + /** + * Performs authorization, and notifies either the {@link #authorizationSuccessHandler} + * or {@link #authorizationFailureHandler}, depending on the authorization result. + * + * @param authorizationContext the context to authorize + * @param principal the principle to authorize + * @param serverWebExchange the currently active exchange + * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds + * and the {@link #authorizationSuccessHandler} has completed, + * or completes with an exception after the authorization attempt fails + * and the {@link #authorizationFailureHandler} has completed + */ + private Mono authorize( + OAuth2AuthorizationContext authorizationContext, + Authentication principal, + ServerWebExchange serverWebExchange) { + + return this.authorizedClientProvider.authorize(authorizationContext) + // Notify the authorizationSuccessHandler of the successful authorization + .flatMap(authorizedClient -> authorizationSuccessHandler.onAuthorizationSuccess( + authorizedClient, + principal, + createAttributes(serverWebExchange)) + .thenReturn(authorizedClient)) + // Notify the authorizationFailureHandler of the failed authorization + .onErrorResume(OAuth2AuthorizationException.class, authorizationException -> authorizationFailureHandler.onAuthorizationFailure( + authorizationException, + principal, + createAttributes(serverWebExchange)) + .then(Mono.error(authorizationException))); } - private static Mono currentServerWebExchange() { - return Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)); + private Map createAttributes(ServerWebExchange serverWebExchange) { + return Collections.singletonMap(ServerWebExchange.class.getName(), serverWebExchange); } private Mono authorizationContext(OAuth2AuthorizeRequest authorizeRequest, OAuth2AuthorizedClient authorizedClient) { return Mono.just(authorizeRequest) - .flatMap(this.contextAttributesMapper::apply) + .flatMap(this.contextAttributesMapper) .map(attrs -> OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(authorizeRequest.getPrincipal()) .attributes(attributes -> { @@ -134,7 +194,7 @@ private Mono authorizationContext(OAuth2AuthorizeReq private Mono authorizationContext(OAuth2AuthorizeRequest authorizeRequest, ClientRegistration clientRegistration) { return Mono.just(authorizeRequest) - .flatMap(this.contextAttributesMapper::apply) + .flatMap(this.contextAttributesMapper) .map(attrs -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .principal(authorizeRequest.getPrincipal()) .attributes(attributes -> { @@ -167,6 +227,36 @@ public void setContextAttributesMapper(FunctionA {@link SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler} + * is used by default.

+ * + * @param authorizationSuccessHandler the handler that handles successful authorizations. + * @see SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler + * @since 5.3 + */ + public void setAuthorizationSuccessHandler(ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler) { + Assert.notNull(authorizationSuccessHandler, "authorizationSuccessHandler cannot be null"); + this.authorizationSuccessHandler = authorizationSuccessHandler; + } + + /** + * Sets the handler that handles authorization failures. + * + *

A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} + * is used by default.

+ * + * @param authorizationFailureHandler the handler that handles authorization failures. + * @see RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler + * @since 5.3 + */ + public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { + Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); + this.authorizationFailureHandler = authorizationFailureHandler; + } + /** * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. */ @@ -176,7 +266,7 @@ public static class DefaultContextAttributesMapper implements Function> apply(OAuth2AuthorizeRequest authorizeRequest) { ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); return Mono.justOrEmpty(serverWebExchange) - .switchIfEmpty(Mono.defer(() -> currentServerWebExchange())) + .switchIfEmpty(currentServerWebExchangeMono) .flatMap(exchange -> { Map contextAttributes = Collections.emptyMap(); String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); @@ -190,4 +280,5 @@ public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) .defaultIfEmpty(Collections.emptyMap()); } } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 512083f9f27..29e0ae3aa9d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,26 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; +import org.springframework.http.HttpStatus; +import org.springframework.lang.Nullable; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationSuccessHandler; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler; +import org.springframework.security.oauth2.client.SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; @@ -37,15 +44,21 @@ import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.util.Assert; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; import org.springframework.web.reactive.function.client.ExchangeFunction; +import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.function.Consumer; @@ -54,8 +67,27 @@ * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the * token as a Bearer Token. * + *

Authentication and Authorization Failures

+ * + *

Since 5.3, this filter function has the ability to forward authentication (HTTP 401 Unauthorized) + * and authorization (HTTP 403 Forbidden) failures from an OAuth 2.0 Resource Server to a + * {@link ReactiveOAuth2AuthorizationFailureHandler}. + * A {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} can be used + * to remove the cached {@link OAuth2AuthorizedClient}, so that future requests will result + * in a new token being retrieved from an Authorization Server, and sent to the Resource Server.

+ * + *

If the {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository, ServerOAuth2AuthorizedClientRepository)} + * constructor is used, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} + * will be configured automatically.

+ * + *

If the {@link #ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager)} + * constructor is used, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} + * will be NOT be configured automatically. + * It is recommended that you configure one via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)}.

+ * * @author Rob Winch * @author Joe Grandja + * @author Phil Clay * @since 5.1 */ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction { @@ -77,7 +109,20 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER")); - private ReactiveOAuth2AuthorizedClientManager authorizedClientManager; + private final Mono currentAuthenticationMono = ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + + private final Mono clientRegistrationIdMono = currentAuthenticationMono + .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) + .cast(OAuth2AuthenticationToken.class) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); + + private final Mono currentServerWebExchangeMono = Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + + private final ReactiveOAuth2AuthorizedClientManager authorizedClientManager; private boolean defaultAuthorizedClientManager; @@ -91,33 +136,71 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @Deprecated private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; + private ClientResponseHandler clientResponseHandler; + + @FunctionalInterface + private interface ClientResponseHandler { + Mono handleResponse(ClientRequest request, Mono response); + } /** * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. * + *

When this constructor is used, authentication (HTTP 401) and authorization (HTTP 403) + * failures returned from a OAuth 2.0 Resource Server will NOT be forwarded to a + * {@link ReactiveOAuth2AuthorizationFailureHandler}. + * Therefore, future requests to the Resource Server will most likely use the same (most likely invalid) token, + * resulting in the same errors returned from the Resource Server. + * It is recommended to configure a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} + * via {@link #setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler)} + * so that authentication and authorization failures returned from a Resource Server + * will result in removing the authorized client, so that a new token is retrieved for future requests.

+ * * @since 5.2 * @param authorizedClientManager the {@link ReactiveOAuth2AuthorizedClientManager} which manages the authorized client(s) */ public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientManager authorizedClientManager) { Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); this.authorizedClientManager = authorizedClientManager; + this.clientResponseHandler = (request, responseMono) -> responseMono; } /** * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. * + *

Since 5.3, when this constructor is used, authentication (HTTP 401) + * and authorization (HTTP 403) failures returned from a OAuth 2.0 Resource Server + * will be forwarded to a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler}, + * which will potentially remove the {@link OAuth2AuthorizedClient} from the given + * {@link ServerOAuth2AuthorizedClientRepository}, depending on the OAuth error code returned. + * Authentication failures returned from an OAuth 2.0 Resource Server typically indicate + * that the token is invalid, and should not be used in future requests. + * Removing the authorized client from the repository will ensure that the existing + * token will not be sent for future requests to the Resource Server, + * and a new token is retrieved from Authorization Server and used for + * future requests to the Resource Server.

+ * * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + + ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler = + new RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler(authorizedClientRepository); + + this.authorizedClientManager = createDefaultAuthorizedClientManager( + clientRegistrationRepository, + authorizedClientRepository, + authorizationFailureHandler); + this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); this.defaultAuthorizedClientManager = true; } private static ReactiveOAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( ReactiveClientRegistrationRepository clientRegistrationRepository, - ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + ServerOAuth2AuthorizedClientRepository authorizedClientRepository, + ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder.builder() @@ -132,7 +215,8 @@ private static ReactiveOAuth2AuthorizedClientManager createDefaultAuthorizedClie UnAuthenticatedReactiveOAuth2AuthorizedClientManager unauthenticatedAuthorizedClientManager = new UnAuthenticatedReactiveOAuth2AuthorizedClientManager( clientRegistrationRepository, - (UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository); + (UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository, + authorizationFailureHandler); unauthenticatedAuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); return unauthenticatedAuthorizedClientManager; } @@ -140,6 +224,7 @@ private static ReactiveOAuth2AuthorizedClientManager createDefaultAuthorizedClie DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( clientRegistrationRepository, authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + authorizedClientManager.setAuthorizationFailureHandler(authorizationFailureHandler); return authorizedClientManager; } @@ -316,8 +401,13 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { public Mono filter(ClientRequest request, ExchangeFunction next) { return authorizedClient(request) .map(authorizedClient -> bearer(request, authorizedClient)) - .flatMap(next::exchange) - .switchIfEmpty(Mono.defer(() -> next.exchange(request))); + .flatMap(requestWithBearer -> exchangeAndHandleResponse(requestWithBearer, next)) + .switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next))); + } + + private Mono exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) { + return next.exchange(request) + .transform(responseMono -> clientResponseHandler.handleResponse(request, responseMono)); } private Mono authorizedClient(ClientRequest request) { @@ -330,80 +420,102 @@ private Mono authorizedClient(ClientRequest request) { } private Mono authorizeRequest(ClientRequest request) { - Mono authentication = currentAuthentication(); - - Mono clientRegistrationId = Mono.justOrEmpty(clientRegistrationId(request)) - .switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId)) - .switchIfEmpty(clientRegistrationId(authentication)); + Mono clientRegistrationId = effectiveClientRegistrationId(request); - Mono> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request)) - .switchIfEmpty(currentServerWebExchange()) - .map(Optional::of) - .defaultIfEmpty(Optional.empty()); + Mono> serverWebExchange = effectiveServerWebExchange(request); - return Mono.zip(clientRegistrationId, authentication, serverWebExchange) + return Mono.zip(clientRegistrationId, currentAuthenticationMono, serverWebExchange) .map(t3 -> { OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1()).principal(t3.getT2()); - if (t3.getT3().isPresent()) { - builder.attribute(ServerWebExchange.class.getName(), t3.getT3().get()); - } + t3.getT3().ifPresent(exchange -> builder.attribute(ServerWebExchange.class.getName(), exchange)); return builder.build(); }); } - private Mono reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { - Mono authentication = currentAuthentication(); + /** + * Returns a {@link Mono} the emits the {@code clientRegistrationId} + * that is active for the given request. + * + * @param request the request for which to retrieve the {@code clientRegistrationId} + * @return a mono the emits the {@code clientRegistrationId} + * that is active for the given request. + */ + private Mono effectiveClientRegistrationId(ClientRequest request) { + return Mono.justOrEmpty(clientRegistrationId(request)) + .switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId)) + .switchIfEmpty(clientRegistrationIdMono); + } - Mono> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request)) - .switchIfEmpty(currentServerWebExchange()) + /** + * Returns a {@link Mono} that emits an {@link Optional} for the {@link ServerWebExchange} + * that is active for the given request. + * + *

The returned {@link Mono} will never complete empty. + * Instead, it will emit an empty {@link Optional} if no exchange is active.

+ * + * @param request the request for which to retrieve the exchange + * @return a {@link Mono} that emits an {@link Optional} for the {@link ServerWebExchange} + * that is active for the given request. + */ + private Mono> effectiveServerWebExchange(ClientRequest request) { + return Mono.justOrEmpty(serverWebExchange(request)) + .switchIfEmpty(currentServerWebExchangeMono) .map(Optional::of) .defaultIfEmpty(Optional.empty()); + } - return Mono.zip(authentication, serverWebExchange) + private Mono reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { + Mono> serverWebExchange = effectiveServerWebExchange(request); + + return Mono.zip(currentAuthenticationMono, serverWebExchange) .map(t2 -> { OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient).principal(t2.getT1()); - if (t2.getT2().isPresent()) { - builder.attribute(ServerWebExchange.class.getName(), t2.getT2().get()); - } + t2.getT2().ifPresent(exchange -> builder.attribute(ServerWebExchange.class.getName(), exchange)); return builder.build(); }); } - private Mono currentAuthentication() { - return ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(ANONYMOUS_USER_TOKEN); - } - - private Mono clientRegistrationId(Mono authentication) { - return authentication - .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) - .cast(OAuth2AuthenticationToken.class) - .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); - } - - private Mono currentServerWebExchange() { - return Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)); - } - private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { return ClientRequest.from(request) .headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) .build(); } + /** + * Sets the handler that handles authentication and authorization failures when communicating + * to the OAuth2 Resource Server. + * + *

For example, a {@link RemoveAuthorizedClientReactiveOAuth2AuthorizationFailureHandler} + * is typically used to remove the cached {@link OAuth2AuthorizedClient}, + * so that the same token is no longer used in future requests to the Resource Server.

+ * + *

The failure handler used by default depends on which constructor was used + * to construct this {@link ServerOAuth2AuthorizedClientExchangeFilterFunction}. + * See the constructors for more details.

+ * + * @param authorizationFailureHandler the handler that handles authentication and authorization failures. + * @since 5.3 + */ + public void setAuthorizationFailureHandler(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { + Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); + this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); + } + private static class UnAuthenticatedReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { private final ReactiveClientRegistrationRepository clientRegistrationRepository; private final UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private final ReactiveOAuth2AuthorizationSuccessHandler authorizationSuccessHandler; + private final ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; private UnAuthenticatedReactiveOAuth2AuthorizedClientManager( ReactiveClientRegistrationRepository clientRegistrationRepository, - UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository, + ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; + this.authorizationSuccessHandler = new SaveAuthorizedClientReactiveOAuth2AuthorizationSuccessHandler(authorizedClientRepository); + this.authorizationFailureHandler = authorizationFailureHandler; } @Override @@ -418,8 +530,7 @@ public Mono authorize(OAuth2AuthorizeRequest authorizeRe .flatMap(authorizedClient -> { // Re-authorize return Mono.just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal).build()) - .flatMap(this.authorizedClientProvider::authorize) - .flatMap(reauthorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, null).thenReturn(reauthorizedClient)) + .flatMap(authorizationContext -> authorize(authorizationContext, principal)) // Default to the existing authorizedClient if the client was not re-authorized .defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ? authorizeRequest.getAuthorizedClient() : authorizedClient); @@ -430,15 +541,184 @@ public Mono authorize(OAuth2AuthorizeRequest authorizeRe .switchIfEmpty(Mono.error(() -> new IllegalArgumentException( "Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) .flatMap(clientRegistration -> Mono.just(OAuth2AuthorizationContext.withClientRegistration(clientRegistration).principal(principal).build())) - .flatMap(this.authorizedClientProvider::authorize) - .flatMap(authorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null).thenReturn(authorizedClient)) + .flatMap(authorizationContext -> authorize(authorizationContext, principal)) .subscriberContext(context) )); } + /** + * Performs authorization, and notifies either the {@link #authorizationSuccessHandler} + * or {@link #authorizationFailureHandler}, depending on the authorization result. + * + * @param authorizationContext the context to authorize + * @param principal the principle to authorize + * @return a {@link Mono} that emits the authorized client after the authorization attempt succeeds + * and the {@link #authorizationSuccessHandler} has completed, + * or completes with an exception after the authorization attempt fails + * and the {@link #authorizationFailureHandler} has completed + */ + private Mono authorize( + OAuth2AuthorizationContext authorizationContext, + Authentication principal) { + + return this.authorizedClientProvider.authorize(authorizationContext) + // Notify the authorizationSuccessHandler of the successful authorization + .flatMap(authorizedClient -> authorizationSuccessHandler.onAuthorizationSuccess( + authorizedClient, + principal, + Collections.emptyMap()) + .thenReturn(authorizedClient)) + // Notify the authorizationFailureHandler of the failed authorization + .onErrorResume(OAuth2AuthorizationException.class, authorizationException -> authorizationFailureHandler.onAuthorizationFailure( + authorizationException, + principal, + Collections.emptyMap()) + .then(Mono.error(authorizationException))); + } + private void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); this.authorizedClientProvider = authorizedClientProvider; } } + + /** + * Forwards authentication and authorization failures to a + * {@link ReactiveOAuth2AuthorizationFailureHandler}. + * + * @since 5.3 + */ + private class AuthorizationFailureForwarder implements ClientResponseHandler { + + /** + * A map of HTTP Status Code to OAuth 2.0 Error codes for + * HTTP status codes that should be interpreted as + * authentication or authorization failures. + */ + private final Map httpStatusToOAuth2ErrorCodeMap; + + /** + * The {@link ReactiveOAuth2AuthorizationFailureHandler} to notify + * when a authentication/authorization failure occurs. + */ + private final ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; + + private AuthorizationFailureForwarder(ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler) { + Assert.notNull(authorizationFailureHandler, "authorizationFailureHandler cannot be null"); + this.authorizationFailureHandler = authorizationFailureHandler; + + Map httpStatusToOAuth2Error = new HashMap<>(); + httpStatusToOAuth2Error.put(HttpStatus.UNAUTHORIZED.value(), OAuth2ErrorCodes.INVALID_TOKEN); + httpStatusToOAuth2Error.put(HttpStatus.FORBIDDEN.value(), OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + this.httpStatusToOAuth2ErrorCodeMap = Collections.unmodifiableMap(httpStatusToOAuth2Error); + } + + @Override + public Mono handleResponse( + ClientRequest request, + Mono responseMono) { + + return responseMono + .flatMap(response -> handleHttpStatus(request, response.rawStatusCode(), null) + .thenReturn(response)) + .onErrorResume(WebClientResponseException.class, e -> handleHttpStatus(request, e.getRawStatusCode(), e) + .then(Mono.error(e))) + .onErrorResume(OAuth2AuthorizationException.class, e -> handleAuthorizationException(request, e) + .then(Mono.error(e))); + } + + /** + * Handles the given http status code returned from a resource server + * by notifying the authorization failure handler if the http status + * code is in the {@link #httpStatusToOAuth2ErrorCodeMap}. + * + * @param request the request being processed + * @param httpStatusCode the http status returned by the resource server + * @param exception The root cause exception for the failure (nullable) + * @return a {@link Mono} that completes empty after the authorization failure handler completes. + */ + private Mono handleHttpStatus(ClientRequest request, int httpStatusCode, @Nullable Exception exception) { + return Mono.justOrEmpty(httpStatusToOAuth2ErrorCodeMap.get(httpStatusCode)) + .flatMap(oauth2ErrorCode -> { + Mono> serverWebExchange = effectiveServerWebExchange(request); + + Mono clientRegistrationId = effectiveClientRegistrationId(request); + + return Mono.zip(currentAuthenticationMono, serverWebExchange, clientRegistrationId) + .flatMap(tuple3 -> notifyAuthorizationFailure( + tuple3.getT1(), // Authentication principal + tuple3.getT2().orElse(null), // ServerWebExchange exchange + createAuthorizationException( + tuple3.getT3(), // String clientRegistrationId + oauth2ErrorCode, + exception))); + }); + } + + /** + * Handles the given OAuth2AuthorizationException that occurred downstream + * by notifying the authorization failure handler. + * + * @param request the request being processed + * @param exception the authorization exception to include in the failure event. + * @return a {@link Mono} that completes empty after the authorization failure handler completes. + */ + private Mono handleAuthorizationException(ClientRequest request, OAuth2AuthorizationException exception) { + Mono> serverWebExchange = effectiveServerWebExchange(request); + + return Mono.zip(currentAuthenticationMono, serverWebExchange) + .flatMap(tuple2 -> notifyAuthorizationFailure( + tuple2.getT1(), // Authentication principal + tuple2.getT2().orElse(null), // ServerWebExchange exchange + exception)); + } + + /** + * Creates an authorization exception using the given parameters. + * + * @param clientRegistrationId the client registration id of the client that failed authentication/authorization. + * @param oauth2ErrorCode the OAuth2 error code to use in the authorization failure event + * @param exception The root cause exception for the failure (nullable) + * @return an authorization exception using the given parameters. + */ + private ClientAuthorizationException createAuthorizationException( + String clientRegistrationId, + String oauth2ErrorCode, + @Nullable Exception exception) { + return new ClientAuthorizationException( + new OAuth2Error( + oauth2ErrorCode, + null, + "https://tools.ietf.org/html/rfc6750#section-3.1"), + clientRegistrationId, + exception); + } + + + /** + * Notifies the authorization failure handler of the failed authorization. + * + * @param principal the principal associated with the failed authorization attempt + * @param exchange the currently active exchange + * @param exception the authorization exception to include in the failure event. + * @return a {@link Mono} that completes empty after the authorization failure handler completes. + */ + private Mono notifyAuthorizationFailure( + Authentication principal, + ServerWebExchange exchange, + OAuth2AuthorizationException exception) { + + return authorizationFailureHandler.onAuthorizationFailure( + exception, + principal, + createAttributes(exchange)); + } + + private Map createAttributes(ServerWebExchange exchange) { + if (exchange == null) { + return Collections.emptyMap(); + } + return Collections.singletonMap(ServerWebExchange.class.getName(), exchange); + } + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java index ab9c7ff382a..01d38442d24 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,9 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -34,6 +37,7 @@ import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -59,6 +63,7 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { private OAuth2AuthorizedClient authorizedClient; private ArgumentCaptor authorizationContextCaptor; private PublisherProbe saveAuthorizedClientProbe; + private PublisherProbe removeAuthorizedClientProbe; @SuppressWarnings("unchecked") @Before @@ -67,6 +72,8 @@ public void setup() { this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class); this.saveAuthorizedClientProbe = PublisherProbe.empty(); when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(this.saveAuthorizedClientProbe.mono()); + this.removeAuthorizedClientProbe = PublisherProbe.empty(); + when(this.authorizedClientService.removeAuthorizedClient(any(), any())).thenReturn(this.removeAuthorizedClientProbe.mono()); this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); this.contextAttributesMapper = mock(Function.class); when(this.contextAttributesMapper.apply(any())).thenReturn(Mono.empty()); @@ -109,6 +116,20 @@ public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException( .hasMessage("contextAttributesMapper cannot be null"); } + @Test + public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationSuccessHandler cannot be null"); + } + + @Test + public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationFailureHandler cannot be null"); + } + @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) @@ -187,6 +208,214 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { verify(this.authorizedClientService).saveAuthorizedClient( eq(this.authorizedClient), eq(this.principal)); this.saveAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderAndCustomSuccessHandlerThenInvokeCustomSuccessHandler() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + PublisherProbe authorizationSuccessHandlerProbe = PublisherProbe.empty(); + this.authorizedClientManager.setAuthorizationSuccessHandler((client, principal, attributes) -> authorizationSuccessHandlerProbe.mono()); + + Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + StepVerifier.create(authorizedClient) + .expectNext(this.authorizedClient) + .verifyComplete(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + authorizationSuccessHandlerProbe.assertWasSubscribed(); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); + verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); + } + + @Test + public void authorizeWhenInvalidTokenThenRemoveAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + ClientAuthorizationException exception = new ClientAuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService).removeAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName())); + this.removeAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); + } + + @Test + public void authorizeWhenInvalidGrantThenRemoveAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + ClientAuthorizationException exception = new ClientAuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService).removeAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName())); + this.removeAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); + } + + @Test + public void authorizeWhenServerErrorThenDoNotRemoveAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + ClientAuthorizationException exception = new ClientAuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); + } + + @Test + public void authorizeWhenOAuth2AuthorizationExceptionThenDoNotRemoveAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + OAuth2AuthorizationException exception = new OAuth2AuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); + } + + @Test + public void authorizeWhenOAuth2AuthorizationExceptionAndCustomFailureHandlerThenInvokeCustomFailureHandler() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + OAuth2AuthorizationException exception = new OAuth2AuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + PublisherProbe authorizationFailureHandlerProbe = PublisherProbe.empty(); + this.authorizedClientManager.setAuthorizationFailureHandler((client, principal, attributes) -> authorizationFailureHandlerProbe.mono()); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + authorizationFailureHandlerProbe.assertWasSubscribed(); + verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); + verify(this.authorizedClientService, never()).saveAuthorizedClient(any(), any()); } @SuppressWarnings("unchecked") @@ -222,6 +451,7 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { verify(this.authorizedClientService).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal)); this.saveAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); } @SuppressWarnings("unchecked") @@ -277,6 +507,7 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() { verify(this.authorizedClientService).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal)); this.saveAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); } @SuppressWarnings("unchecked") @@ -302,6 +533,7 @@ public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() { verify(this.authorizedClientService).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal)); this.saveAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java index d4104bae168..f4d17dcc1ba 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,10 +24,10 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -178,7 +178,7 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value())); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) - .isInstanceOf(OAuth2AuthorizationException.class) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client")) .hasMessageContaining("unauthorized_client"); } @@ -189,7 +189,7 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(HttpStatus.INTERNAL_SERVER_ERROR.value())); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) - .isInstanceOf(OAuth2AuthorizationException.class) + .isInstanceOf(ClientAuthorizationException.class) .hasMessageContaining("server_error"); } @@ -204,7 +204,7 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) - .isInstanceOf(OAuth2AuthorizationException.class) + .isInstanceOf(ClientAuthorizationException.class) .hasMessageContaining("invalid_token_response"); } @@ -307,7 +307,7 @@ public void getTokenResponseWhenOAuth2AuthorizationRequestContainsPkceParameters this.tokenResponseClient.getTokenResponse(pkceAuthorizationCodeGrantRequest()).block(); String body = this.server.takeRequest().getBody().readUtf8(); - assertThat(body).isEqualTo("grant_type=authorization_code&code=code&redirect_uri=%7BbaseUrl%7D%2F%7Baction%7D%2Foauth2%2Fcode%2F%7BregistrationId%7D&client_id=client-id&code_verifier=code-verifier-1234"); + assertThat(body).isEqualTo("grant_type=authorization_code&client_id=client-id&code=code&redirect_uri=%7BbaseUrl%7D%2F%7Baction%7D%2Foauth2%2Fcode%2F%7BregistrationId%7D&code_verifier=code-verifier-1234"); } private OAuth2AuthorizationCodeGrantRequest pkceAuthorizationCodeGrantRequest() { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java index c4d92d629c0..56549e59752 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import org.junit.Test; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -32,6 +33,7 @@ import org.springframework.web.reactive.function.client.WebClientResponseException; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.*; /** @@ -103,7 +105,7 @@ public void getTokenResponseWhenPostThenSuccess() throws Exception { assertThat(response.getAccessToken()).isNotNull(); assertThat(actualRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); - assertThat(body).isEqualTo("grant_type=client_credentials&scope=read%3Auser&client_id=client-id&client_secret=client-secret"); + assertThat(body).isEqualTo("grant_type=client_credentials&client_id=client-id&client_secret=client-secret&scope=read%3Auser"); } @Test @@ -147,15 +149,19 @@ public void setWebClientCustomThenCustomClientIsUsed() { verify(customClient, atLeastOnce()).post(); } - @Test(expected = WebClientResponseException.class) - // gh-6089 + @Test public void getTokenResponseWhenInvalidResponse() throws WebClientResponseException { ClientRegistration registration = this.clientRegistration.build(); enqueueUnexpectedResponse(); OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); - OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); + assertThatThrownBy(() -> this.client.getTokenResponse(request).block()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .hasMessageContaining("[invalid_token_response]") + .hasMessageContaining("Empty OAuth 2.0 Access Token Response") + .hasMessageContaining("HTTP Status Code: 301"); + } private void enqueueUnexpectedResponse(){ diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java index 93a17f46738..f1da94589f7 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,11 +24,11 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import java.time.Instant; @@ -148,8 +148,10 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu this.clientRegistrationBuilder.build(), this.username, this.password); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response") + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .hasMessageContaining("[invalid_token_response]") + .hasMessageContaining("An error occurred parsing the Access Token response") + .hasMessageContaining("HTTP Status Code: 200") .hasCauseInstanceOf(Throwable.class); } @@ -186,9 +188,10 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti this.clientRegistrationBuilder.build(), this.username, this.password); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("HTTP Status Code 400"); + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client")) + .hasMessageContaining("[unauthorized_client]") + .hasMessageContaining("Error retrieving OAuth 2.0 Access Token") + .hasMessageContaining("HTTP Status Code: 400"); } @Test @@ -199,9 +202,10 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE this.clientRegistrationBuilder.build(), this.username, this.password); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("HTTP Status Code 500"); + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .hasMessageContaining("[invalid_token_response]") + .hasMessageContaining("Empty OAuth 2.0 Access Token Response") + .hasMessageContaining("HTTP Status Code: 500"); } private MockResponse jsonResponse(String json) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java index 0be683ae6c4..2eb3a680e2f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,11 +24,11 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; @@ -153,8 +153,9 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred parsing the Access Token response") + .isInstanceOf(ClientAuthorizationException.class) + .hasMessageContaining("[invalid_token_response]") + .hasMessageContaining("An error occurred parsing the Access Token response") .hasCauseInstanceOf(Throwable.class); } @@ -191,9 +192,9 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("HTTP Status Code 400"); + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("unauthorized_client")) + .hasMessageContaining("[unauthorized_client]") + .hasMessageContaining("HTTP Status Code: 400"); } @Test @@ -204,9 +205,10 @@ public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationE this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest).block()) - .isInstanceOf(OAuth2AuthorizationException.class) - .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") - .hasMessageContaining("HTTP Status Code 500"); + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token_response")) + .hasMessageContaining("[invalid_token_response]") + .hasMessageContaining("Empty OAuth 2.0 Access Token Response") + .hasMessageContaining("HTTP Status Code: 500"); } private MockResponse jsonResponse(String json) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java index 1e1bcbb3bce..d4034712155 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java @@ -23,6 +23,7 @@ import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; @@ -31,6 +32,9 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -45,6 +49,7 @@ import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.*; @@ -67,6 +72,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests { private ArgumentCaptor authorizationContextCaptor; private PublisherProbe loadAuthorizedClientProbe; private PublisherProbe saveAuthorizedClientProbe; + private PublisherProbe removeAuthorizedClientProbe; @SuppressWarnings("unchecked") @Before @@ -81,6 +87,9 @@ public void setup() { this.saveAuthorizedClientProbe = PublisherProbe.empty(); when(this.authorizedClientRepository.saveAuthorizedClient( any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(this.saveAuthorizedClientProbe.mono()); + this.removeAuthorizedClientProbe = PublisherProbe.empty(); + when(this.authorizedClientRepository.removeAuthorizedClient( + any(String.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(this.removeAuthorizedClientProbe.mono()); this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty()); this.contextAttributesMapper = mock(Function.class); @@ -119,6 +128,20 @@ public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException .hasMessage("authorizedClientProvider cannot be null"); } + @Test + public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationSuccessHandler cannot be null"); + } + + @Test + public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationFailureHandler cannot be null"); + } + @Test public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) @@ -204,8 +227,211 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { verify(this.authorizedClientRepository).saveAuthorizedClient( eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange)); this.saveAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderAndCustomSuccessHandlerThenInvokeCustomSuccessHandler() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + when(this.authorizedClientProvider.authorize( + any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + PublisherProbe authorizationSuccessHandlerProbe = PublisherProbe.empty(); + this.authorizedClientManager.setAuthorizationSuccessHandler((client, principal, attributes) -> authorizationSuccessHandlerProbe.mono()); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + authorizationSuccessHandlerProbe.assertWasSubscribed(); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); + verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); } + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenInvalidTokenThenRemoveAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + ClientAuthorizationException exception = new ClientAuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize( + any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientRepository).removeAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange)); + this.removeAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenInvalidGrantThenRemoveAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + ClientAuthorizationException exception = new ClientAuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize( + any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientRepository).removeAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.serverWebExchange)); + this.removeAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenServerErrorThenDoNotRemoveAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + ClientAuthorizationException exception = new ClientAuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, null, null), + this.clientRegistration.getRegistrationId()); + + when(this.authorizedClientProvider.authorize( + any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenOAuth2AuthorizationExceptionThenDoNotRemoveAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + OAuth2AuthorizationException exception = new OAuth2AuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); + + when(this.authorizedClientProvider.authorize( + any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenOAuth2AuthorizationExceptionAndCustomFailureHandlerThenInvokeCustomFailureHandler() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + + OAuth2AuthorizationException exception = new OAuth2AuthorizationException( + new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); + + when(this.authorizedClientProvider.authorize( + any(OAuth2AuthorizationContext.class))).thenReturn(Mono.error(exception)); + + PublisherProbe authorizationFailureHandlerProbe = PublisherProbe.empty(); + this.authorizedClientManager.setAuthorizationFailureHandler((client, principal, attributes) -> authorizationFailureHandlerProbe.mono()); + + assertThatCode(() -> this.authorizedClientManager.authorize(authorizeRequest) + .subscriberContext(this.context).block()) + .isEqualTo(exception); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + authorizationFailureHandlerProbe.assertWasSubscribed(); + verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); + } @SuppressWarnings("unchecked") @Test public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { @@ -239,6 +465,7 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { verify(this.authorizedClientRepository).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange)); this.saveAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); } @Test @@ -332,6 +559,7 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() { verify(this.authorizedClientRepository).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal), eq(this.serverWebExchange)); this.saveAuthorizedClientProbe.assertWasSubscribed(); + verify(this.authorizedClientRepository, never()).removeAuthorizedClient(any(), any(), any()); } @Test diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java new file mode 100644 index 00000000000..6d8c2f0d12f --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionITests.java @@ -0,0 +1,330 @@ +/* + * Copyright 2002-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web.reactive.function.client; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.client.WebClientResponseException; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; + +public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests { + + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + private ServerOAuth2AuthorizedClientExchangeFilterFunction authorizedClientFilter; + private MockWebServer server; + private String serverUrl; + private WebClient webClient; + private Authentication authentication; + private MockServerWebExchange exchange; + + @Before + public void setUp() throws Exception { + this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); + final ServerOAuth2AuthorizedClientRepository delegate = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository( + new InMemoryReactiveOAuth2AuthorizedClientService(this.clientRegistrationRepository)); + this.authorizedClientRepository = spy(new ServerOAuth2AuthorizedClientRepository() { + + @Override + public Mono loadAuthorizedClient( + String clientRegistrationId, + Authentication principal, ServerWebExchange exchange) { + return delegate.loadAuthorizedClient(clientRegistrationId, principal, exchange); + } + + @Override + public Mono saveAuthorizedClient( + OAuth2AuthorizedClient authorizedClient, + Authentication principal, ServerWebExchange exchange) { + return delegate.saveAuthorizedClient(authorizedClient, principal, exchange); + } + + @Override + public Mono removeAuthorizedClient( + String clientRegistrationId, + Authentication principal, ServerWebExchange exchange) { + return delegate.removeAuthorizedClient(clientRegistrationId, principal, exchange); + } + + }); + this.authorizedClientFilter = new ServerOAuth2AuthorizedClientExchangeFilterFunction( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.server = new MockWebServer(); + this.server.start(); + this.serverUrl = this.server.url("/").toString(); + this.webClient = WebClient.builder() + .filter(this.authorizedClientFilter) + .build(); + this.authentication = new TestingAuthenticationToken("principal", "password"); + this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/").build()).build(); + } + + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration)); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(2); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange)); + assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration); + } + + @Test + public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"refreshed-access-token\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration)); + + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofHours(1)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "expired-access-token", issuedAt, expiresAt, new HashSet<>(Arrays.asList("read", "write"))); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, this.authentication.getName(), accessToken, refreshToken); + doReturn(Mono.just(authorizedClient)).when(this.authorizedClientRepository).loadAuthorizedClient( + eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange)); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(2); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange)); + OAuth2AuthorizedClient refreshedAuthorizedClient = authorizedClientCaptor.getValue(); + assertThat(refreshedAuthorizedClient.getClientRegistration()).isSameAs(clientRegistration); + assertThat(refreshedAuthorizedClient.getAccessToken().getTokenValue()).isEqualTo("refreshed-access-token"); + } + + @Test + public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + + // Client 1 + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration1 = TestClientRegistrations.clientCredentials() + .registrationId("client-1").tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration1.getRegistrationId()))).thenReturn(Mono.just(clientRegistration1)); + + // Client 2 + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration2 = TestClientRegistrations.clientCredentials() + .registrationId("client-2").tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))).thenReturn(Mono.just(clientRegistration2)); + + this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration1.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .flatMap(response -> this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration2.getRegistrationId())) + .retrieve() + .bodyToMono(String.class)) + .subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)) + .block(); + + assertThat(this.server.getRequestCount()).isEqualTo(4); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository, times(2)).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange)); + assertThat(authorizedClientCaptor.getAllValues().get(0).getClientRegistration()).isSameAs(clientRegistration1); + assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2); + } + + /** + * When a non-expired {@link OAuth2AuthorizedClient} exists + * but the resource server returns 401, + * then remove the {@link OAuth2AuthorizedClient} from the repository. + */ + @Test + public void requestWhenUnauthorizedThenReAuthorize() { + String accessTokenResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + String clientResponse = "{\n" + + " \"attribute1\": \"value1\",\n" + + " \"attribute2\": \"value2\"\n" + + "}\n"; + this.server.enqueue(new MockResponse().setResponseCode(HttpStatus.UNAUTHORIZED.value())); + this.server.enqueue(jsonResponse(accessTokenResponse)); + this.server.enqueue(jsonResponse(clientResponse)); + + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl).build(); + when(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))).thenReturn(Mono.just(clientRegistration)); + + OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken(); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, this.authentication.getName(), accessToken, refreshToken); + doReturn(Mono.just(authorizedClient)) + .doReturn(Mono.empty()) + .when(this.authorizedClientRepository).loadAuthorizedClient( + eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange)); + + Mono requestMono = this.webClient + .get() + .uri(this.serverUrl) + .attributes(clientRegistrationId(clientRegistration.getRegistrationId())) + .retrieve() + .bodyToMono(String.class) + .subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)); + + // first try should fail, and remove the cached authorized client + assertThatCode(requestMono::block) + .isInstanceOfSatisfying(WebClientResponseException.class, e -> assertThat(e.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED)); + + assertThat(this.server.getRequestCount()).isEqualTo(1); + + verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); + verify(this.authorizedClientRepository).removeAuthorizedClient( + eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.exchange)); + + // second try should retrieve the authorized client and succeed + requestMono.block(); + + assertThat(this.server.getRequestCount()).isEqualTo(3); + + ArgumentCaptor authorizedClientCaptor = ArgumentCaptor.forClass(OAuth2AuthorizedClient.class); + verify(this.authorizedClientRepository).saveAuthorizedClient( + authorizedClientCaptor.capture(), eq(this.authentication), eq(this.exchange)); + assertThat(authorizedClientCaptor.getValue().getClientRegistration()).isSameAs(clientRegistration); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 4156a34670e..5b86a163f2a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.client.web.reactive.function.client; import org.junit.Before; @@ -27,6 +26,7 @@ import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.FormHttpMessageWriter; @@ -39,11 +39,15 @@ import org.springframework.mock.http.client.reactive.MockClientHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.oauth2.client.ClientAuthorizationException; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizationFailureHandler; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; @@ -59,6 +63,9 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -67,11 +74,15 @@ import org.springframework.util.StringUtils; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ExchangeFunction; +import org.springframework.web.reactive.function.client.WebClientResponseException; import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; +import reactor.test.publisher.PublisherProbe; import reactor.util.context.Context; import java.net.URI; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -82,8 +93,16 @@ import java.util.Optional; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.entry; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; import static org.springframework.http.HttpMethod.GET; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId; import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient; @@ -109,6 +128,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private ReactiveOAuth2AccessTokenResponseClient passwordTokenResponseClient; + @Mock + private ReactiveOAuth2AuthorizationFailureHandler authorizationFailureHandler; + + @Captor + private ArgumentCaptor authorizationExceptionCaptor; + + @Captor + private ArgumentCaptor authenticationCaptor; + + @Captor + private ArgumentCaptor> attributesCaptor; + private ServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); @Captor @@ -414,6 +445,240 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() { assertThat(getBody(request0)).isEmpty(); } + @Test + public void filterWhenUnauthorizedThenInvokeFailureHandler() { + function.setAuthorizationFailureHandler(authorizationFailureHandler); + + PublisherProbe publisherProbe = PublisherProbe.empty(); + when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.UNAUTHORIZED.value()); + + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); + + assertThat(publisherProbe.wasSubscribed()).isTrue(); + + verify(authorizationFailureHandler).onAuthorizationFailure( + authorizationExceptionCaptor.capture(), + authenticationCaptor.capture(), + attributesCaptor.capture()); + + assertThat(authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { + assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); + assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token"); + assertThat(e).hasNoCause(); + assertThat(e).hasMessageContaining("[invalid_token]"); + }); + assertThat(authenticationCaptor.getValue()) + .isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(attributesCaptor.getValue()) + .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); + } + + @Test + public void filterWhenUnauthorizedWithWebClientExceptionThenInvokeFailureHandler() { + function.setAuthorizationFailureHandler(authorizationFailureHandler); + + PublisherProbe publisherProbe = PublisherProbe.empty(); + when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + WebClientResponseException exception = WebClientResponseException.create( + HttpStatus.UNAUTHORIZED.value(), + HttpStatus.UNAUTHORIZED.getReasonPhrase(), + HttpHeaders.EMPTY, + new byte[0], + StandardCharsets.UTF_8); + + ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception); + + assertThatCode(() -> this.function.filter(request, throwingExchangeFunction) + .subscriberContext(serverWebExchange()) + .block()) + .isEqualTo(exception); + + assertThat(publisherProbe.wasSubscribed()).isTrue(); + + verify(authorizationFailureHandler).onAuthorizationFailure( + authorizationExceptionCaptor.capture(), + authenticationCaptor.capture(), + attributesCaptor.capture()); + + assertThat(authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { + assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); + assertThat(e.getError().getErrorCode()).isEqualTo("invalid_token"); + assertThat(e).hasCause(exception); + assertThat(e).hasMessageContaining("[invalid_token]"); + }); + assertThat(authenticationCaptor.getValue()) + .isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(attributesCaptor.getValue()) + .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); + } + + @Test + public void filterWhenForbiddenThenInvokeFailureHandler() { + function.setAuthorizationFailureHandler(authorizationFailureHandler); + + PublisherProbe publisherProbe = PublisherProbe.empty(); + when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.FORBIDDEN.value()); + + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); + + assertThat(publisherProbe.wasSubscribed()).isTrue(); + + verify(authorizationFailureHandler).onAuthorizationFailure( + authorizationExceptionCaptor.capture(), + authenticationCaptor.capture(), + attributesCaptor.capture()); + + assertThat(authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { + assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); + assertThat(e.getError().getErrorCode()).isEqualTo("insufficient_scope"); + assertThat(e).hasNoCause(); + assertThat(e).hasMessageContaining("[insufficient_scope]"); + }); + assertThat(authenticationCaptor.getValue()) + .isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(attributesCaptor.getValue()) + .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); + } + + @Test + public void filterWhenForbiddenWithWebClientExceptionThenInvokeFailureHandler() { + function.setAuthorizationFailureHandler(authorizationFailureHandler); + + PublisherProbe publisherProbe = PublisherProbe.empty(); + when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + WebClientResponseException exception = WebClientResponseException.create( + HttpStatus.FORBIDDEN.value(), + HttpStatus.FORBIDDEN.getReasonPhrase(), + HttpHeaders.EMPTY, + new byte[0], + StandardCharsets.UTF_8); + + ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception); + + assertThatCode(() -> this.function.filter(request, throwingExchangeFunction) + .subscriberContext(serverWebExchange()) + .block()) + .isEqualTo(exception); + + assertThat(publisherProbe.wasSubscribed()).isTrue(); + + verify(authorizationFailureHandler).onAuthorizationFailure( + authorizationExceptionCaptor.capture(), + authenticationCaptor.capture(), + attributesCaptor.capture()); + + assertThat(authorizationExceptionCaptor.getValue()) + .isInstanceOfSatisfying(ClientAuthorizationException.class, e -> { + assertThat(e.getClientRegistrationId()).isEqualTo(registration.getRegistrationId()); + assertThat(e.getError().getErrorCode()).isEqualTo("insufficient_scope"); + assertThat(e).hasCause(exception); + assertThat(e).hasMessageContaining("[insufficient_scope]"); + }); + assertThat(authenticationCaptor.getValue()) + .isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(attributesCaptor.getValue()) + .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); + } + + @Test + public void filterWhenAuthorizationExceptionThenInvokeFailureHandler() { + function.setAuthorizationFailureHandler(authorizationFailureHandler); + + PublisherProbe publisherProbe = PublisherProbe.empty(); + when(authorizationFailureHandler.onAuthorizationFailure(any(), any(), any())).thenReturn(publisherProbe.mono()); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + OAuth2AuthorizationException exception = new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null)); + + ExchangeFunction throwingExchangeFunction = r -> Mono.error(exception); + + assertThatCode(() -> this.function.filter(request, throwingExchangeFunction) + .subscriberContext(serverWebExchange()) + .block()) + .isEqualTo(exception); + + assertThat(publisherProbe.wasSubscribed()).isTrue(); + + verify(authorizationFailureHandler).onAuthorizationFailure( + authorizationExceptionCaptor.capture(), + authenticationCaptor.capture(), + attributesCaptor.capture()); + + assertThat(authorizationExceptionCaptor.getValue()) + .isSameAs(exception); + assertThat(authenticationCaptor.getValue()) + .isInstanceOf(AnonymousAuthenticationToken.class); + assertThat(attributesCaptor.getValue()) + .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); + } + + @Test + public void filterWhenOtherHttpStatusShouldNotInvokeFailureHandler() { + function.setAuthorizationFailureHandler(authorizationFailureHandler); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, + "principalName", this.accessToken, refreshToken); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) + .attributes(oauth2AuthorizedClient(authorizedClient)) + .build(); + + when(exchange.getResponse().rawStatusCode()).thenReturn(HttpStatus.BAD_REQUEST.value()); + + this.function.filter(request, this.exchange) + .subscriberContext(serverWebExchange()) + .block(); + + verify(authorizationFailureHandler, never()).onAuthorizationFailure(any(), any(), any()); + } + @Test public void filterWhenPasswordClientNotAuthorizedThenGetNewToken() { TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java index 2c135bfa8df..a894c6d6ebe 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2AuthorizationException.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package org.springframework.security.oauth2.core; +import org.springframework.util.Assert; + /** * Base exception for OAuth 2.0 Authorization errors. * @@ -30,7 +32,19 @@ public class OAuth2AuthorizationException extends RuntimeException { * @param error the {@link OAuth2Error OAuth 2.0 Error} */ public OAuth2AuthorizationException(OAuth2Error error) { - super(error.toString()); + this(error, error.toString()); + } + + /** + * Constructs an {@code OAuth2AuthorizationException} using the provided parameters. + * + * @param error the {@link OAuth2Error OAuth 2.0 Error} + * @param message the exception message + * @since 5.3 + */ + public OAuth2AuthorizationException(OAuth2Error error, String message) { + super(message); + Assert.notNull(error, "error must not be null"); this.error = error; } @@ -41,7 +55,20 @@ public OAuth2AuthorizationException(OAuth2Error error) { * @param cause the root cause */ public OAuth2AuthorizationException(OAuth2Error error, Throwable cause) { - super(error.toString(), cause); + this(error, error.toString(), cause); + } + + /** + * Constructs an {@code OAuth2AuthorizationException} using the provided parameters. + * + * @param error the {@link OAuth2Error OAuth 2.0 Error} + * @param message the exception message + * @param cause the root cause + * @since 5.3 + */ + public OAuth2AuthorizationException(OAuth2Error error, String message, Throwable cause) { + super(message, cause); + Assert.notNull(error, "error must not be null"); this.error = error; } diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java index 973443ae2fe..7a8f7a16779 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/OAuth2ErrorCodes.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2017 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,6 +53,27 @@ public interface OAuth2ErrorCodes { */ String INVALID_SCOPE = "invalid_scope"; + /** + * {@code insufficient_scope} - The request requires higher privileges than + * provided by the access token. + * The resource server SHOULD respond with the HTTP 403 (Forbidden) + * status code and MAY include the "scope" attribute with the scope + * necessary to access the protected resource. + * + * @see RFC-6750 - Section 3.1 - Error Codes + */ + String INSUFFICIENT_SCOPE = "insufficient_scope"; + + /** + * {@code invalid_token} - The access token provided is expired, revoked, + * malformed, or invalid for other reasons. + * The resource SHOULD respond with the HTTP 401 (Unauthorized) status code. + * The client MAY request a new access token and retry the protected resource request. + * + * @see RFC-6750 - Section 3.1 - Error Codes + */ + String INVALID_TOKEN = "invalid_token"; + /** * {@code server_error} - The authorization server encountered an * unexpected condition that prevented it from fulfilling the request. diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java index b9c5c9bfd05..b6de896d320 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2AccessTokenResponseBodyExtractor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,6 +58,10 @@ public Mono extract(ReactiveHttpInputMessage inputMes ParameterizedTypeReference> type = new ParameterizedTypeReference>() {}; BodyExtractor>, ReactiveHttpInputMessage> delegate = BodyExtractors.toMono(type); return delegate.extract(inputMessage, context) + .onErrorMap(e -> new OAuth2AuthorizationException( + invalidTokenResponse("An error occurred parsing the Access Token response: " + e.getMessage()), e)) + .switchIfEmpty(Mono.error(() -> new OAuth2AuthorizationException( + invalidTokenResponse("Empty OAuth 2.0 Access Token Response")))) .map(OAuth2AccessTokenResponseBodyExtractor::parse) .flatMap(OAuth2AccessTokenResponseBodyExtractor::oauth2AccessTokenResponse) .map(OAuth2AccessTokenResponseBodyExtractor::oauth2AccessTokenResponse); @@ -68,12 +72,19 @@ private static TokenResponse parse(Map json) { return TokenResponse.parse(new JSONObject(json)); } catch (ParseException pe) { - OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, - "An error occurred parsing the Access Token response: " + pe.getMessage(), null); + OAuth2Error oauth2Error = invalidTokenResponse( + "An error occurred parsing the Access Token response: " + pe.getMessage()); throw new OAuth2AuthorizationException(oauth2Error, pe); } } + private static OAuth2Error invalidTokenResponse(String message) { + return new OAuth2Error( + INVALID_TOKEN_RESPONSE_ERROR_CODE, + message, + null); + } + private static Mono oauth2AccessTokenResponse(TokenResponse tokenResponse) { if (tokenResponse.indicatesSuccess()) { return Mono.just(tokenResponse) diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java index 3290a0518d1..48f7cd19da5 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/web/reactive/function/OAuth2BodyExtractorsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.mock.http.client.reactive.MockClientHttpResponse; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.web.reactive.function.BodyExtractor; import reactor.core.publisher.Mono; @@ -92,8 +93,23 @@ public void oauth2AccessTokenResponseWhenInvalidJsonThenException() { Mono result = extractor.extract(response, this.context); - assertThatCode(() -> result.block()) - .isInstanceOf(RuntimeException.class); + assertThatCode(result::block) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("An error occurred parsing the Access Token response"); + } + + @Test + public void oauth2AccessTokenResponseWhenEmptyThenException() { + BodyExtractor, ReactiveHttpInputMessage> extractor = OAuth2BodyExtractors + .oauth2AccessTokenResponse(); + + MockClientHttpResponse response = new MockClientHttpResponse(HttpStatus.OK); + + Mono result = extractor.extract(response, this.context); + + assertThatCode(result::block) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("Empty OAuth 2.0 Access Token Response"); } @Test