From f2010222c59adf9abe97cc587a93f51a519d1cc3 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 11 Feb 2022 06:40:41 -0500 Subject: [PATCH] Introduce OAuth2TokenGenerator Closes gh-414 --- .../OAuth2AuthorizationServerConfigurer.java | 15 ++ .../authorization/OAuth2ConfigurerUtils.java | 25 +- .../OAuth2TokenEndpointConfigurer.java | 31 +-- ...cClientRegistrationEndpointConfigurer.java | 4 +- .../DefaultOAuth2TokenContext.java | 80 +++++++ .../server/authorization/JwtGenerator.java | 166 ++++++++++++++ .../authorization/OAuth2TokenContext.java | 29 ++- .../authorization/OAuth2TokenGenerator.java | 44 ++++ .../authentication/JwtUtils.java | 101 -------- ...thorizationCodeAuthenticationProvider.java | 149 ++++++------ ...ientCredentialsAuthenticationProvider.java | 87 ++++--- ...th2RefreshTokenAuthenticationProvider.java | 152 ++++++------- .../oidc/authentication/JwtUtils.java | 76 ------- ...entRegistrationAuthenticationProvider.java | 84 +++++-- .../OAuth2AuthorizationCodeGrantTests.java | 30 ++- .../server/authorization/OidcTests.java | 78 ++++++- .../authorization/JwtGeneratorTests.java | 215 ++++++++++++++++++ ...zationCodeAuthenticationProviderTests.java | 95 +++++++- ...redentialsAuthenticationProviderTests.java | 47 +++- ...freshTokenAuthenticationProviderTests.java | 95 +++++++- ...gistrationAuthenticationProviderTests.java | 64 +++++- 21 files changed, 1229 insertions(+), 438 deletions(-) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenContext.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JwtGenerator.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenGenerator.java delete mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtUtils.java delete mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/JwtUtils.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JwtGeneratorTests.java diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index c4e66c7f6..a7e9e6979 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -37,8 +37,10 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.Transient; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; @@ -146,6 +148,19 @@ public OAuth2AuthorizationServerConfigurer providerSettings(ProviderSettings return this; } + /** + * Sets the token generator. + * + * @param tokenGenerator the token generator + * @return the {@link OAuth2AuthorizationServerConfigurer} for further configuration + * @since 0.2.3 + */ + public OAuth2AuthorizationServerConfigurer tokenGenerator(OAuth2TokenGenerator tokenGenerator) { + Assert.notNull(tokenGenerator, "tokenGenerator cannot be null"); + getBuilder().setSharedObject(OAuth2TokenGenerator.class, tokenGenerator); + return this; + } + /** * Configures OAuth 2.0 Client Authentication. * diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ConfigurerUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ConfigurerUtils.java index 5495c6c96..d0bc57330 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ConfigurerUtils.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ConfigurerUtils.java @@ -26,14 +26,17 @@ import org.springframework.context.ApplicationContext; import org.springframework.core.ResolvableType; import org.springframework.security.config.annotation.web.HttpSecurityBuilder; +import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationConsentService; import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.util.StringUtils; @@ -82,7 +85,25 @@ static > OAuth2AuthorizationConsentService getA return authorizationConsentService; } - static > JwtEncoder getJwtEncoder(B builder) { + @SuppressWarnings("unchecked") + static > OAuth2TokenGenerator getTokenGenerator(B builder) { + OAuth2TokenGenerator tokenGenerator = builder.getSharedObject(OAuth2TokenGenerator.class); + if (tokenGenerator == null) { + tokenGenerator = getOptionalBean(builder, OAuth2TokenGenerator.class); + if (tokenGenerator == null) { + JwtGenerator jwtGenerator = new JwtGenerator(getJwtEncoder(builder)); + OAuth2TokenCustomizer jwtCustomizer = getJwtCustomizer(builder); + if (jwtCustomizer != null) { + jwtGenerator.setJwtCustomizer(jwtCustomizer); + } + tokenGenerator = jwtGenerator; + } + builder.setSharedObject(OAuth2TokenGenerator.class, tokenGenerator); + } + return tokenGenerator; + } + + private static > JwtEncoder getJwtEncoder(B builder) { JwtEncoder jwtEncoder = builder.getSharedObject(JwtEncoder.class); if (jwtEncoder == null) { jwtEncoder = getOptionalBean(builder, JwtEncoder.class); @@ -107,7 +128,7 @@ static > JWKSource getJwkSourc } @SuppressWarnings("unchecked") - static > OAuth2TokenCustomizer getJwtCustomizer(B builder) { + private static > OAuth2TokenCustomizer getJwtCustomizer(B builder) { OAuth2TokenCustomizer jwtCustomizer = builder.getSharedObject(OAuth2TokenCustomizer.class); if (jwtCustomizer == null) { ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, JwtEncodingContext.class); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenEndpointConfigurer.java index c45424905..dbadad207 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenEndpointConfigurer.java @@ -28,10 +28,10 @@ import org.springframework.security.config.annotation.web.HttpSecurityBuilder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; -import org.springframework.security.oauth2.jwt.JwtEncoder; -import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; -import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationGrantAuthenticationToken; @@ -160,34 +160,19 @@ RequestMatcher getRequestMatcher() { private > List createDefaultAuthenticationProviders(B builder) { List authenticationProviders = new ArrayList<>(); - JwtEncoder jwtEncoder = OAuth2ConfigurerUtils.getJwtEncoder(builder); - OAuth2TokenCustomizer jwtCustomizer = OAuth2ConfigurerUtils.getJwtCustomizer(builder); + OAuth2AuthorizationService authorizationService = OAuth2ConfigurerUtils.getAuthorizationService(builder); + OAuth2TokenGenerator tokenGenerator = OAuth2ConfigurerUtils.getTokenGenerator(builder); OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = - new OAuth2AuthorizationCodeAuthenticationProvider( - OAuth2ConfigurerUtils.getAuthorizationService(builder), - jwtEncoder); - if (jwtCustomizer != null) { - authorizationCodeAuthenticationProvider.setJwtCustomizer(jwtCustomizer); - } + new OAuth2AuthorizationCodeAuthenticationProvider(authorizationService, tokenGenerator); authenticationProviders.add(authorizationCodeAuthenticationProvider); OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider = - new OAuth2RefreshTokenAuthenticationProvider( - OAuth2ConfigurerUtils.getAuthorizationService(builder), - jwtEncoder); - if (jwtCustomizer != null) { - refreshTokenAuthenticationProvider.setJwtCustomizer(jwtCustomizer); - } + new OAuth2RefreshTokenAuthenticationProvider(authorizationService, tokenGenerator); authenticationProviders.add(refreshTokenAuthenticationProvider); OAuth2ClientCredentialsAuthenticationProvider clientCredentialsAuthenticationProvider = - new OAuth2ClientCredentialsAuthenticationProvider( - OAuth2ConfigurerUtils.getAuthorizationService(builder), - jwtEncoder); - if (jwtCustomizer != null) { - clientCredentialsAuthenticationProvider.setJwtCustomizer(jwtCustomizer); - } + new OAuth2ClientCredentialsAuthenticationProvider(authorizationService, tokenGenerator); authenticationProviders.add(clientCredentialsAuthenticationProvider); return authenticationProviders; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationEndpointConfigurer.java index 248c81352..2a30def45 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationEndpointConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ > void init(B builder) { new OidcClientRegistrationAuthenticationProvider( OAuth2ConfigurerUtils.getRegisteredClientRepository(builder), OAuth2ConfigurerUtils.getAuthorizationService(builder), - OAuth2ConfigurerUtils.getJwtEncoder(builder)); + OAuth2ConfigurerUtils.getTokenGenerator(builder)); builder.authenticationProvider(postProcess(oidcClientRegistrationAuthenticationProvider)); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenContext.java new file mode 100644 index 000000000..1a57d85e5 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenContext.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Default implementation of {@link OAuth2TokenContext}. + * + * @author Joe Grandja + * @since 0.2.3 + * @see OAuth2TokenContext + */ +public final class DefaultOAuth2TokenContext implements OAuth2TokenContext { + private final Map context; + + private DefaultOAuth2TokenContext(Map context) { + this.context = Collections.unmodifiableMap(new HashMap<>(context)); + } + + @SuppressWarnings("unchecked") + @Nullable + @Override + public V get(Object key) { + return hasKey(key) ? (V) this.context.get(key) : null; + } + + @Override + public boolean hasKey(Object key) { + Assert.notNull(key, "key cannot be null"); + return this.context.containsKey(key); + } + + /** + * Returns a new {@link Builder}. + * + * @return the {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link DefaultOAuth2TokenContext}. + */ + public static final class Builder extends AbstractBuilder { + + private Builder() { + } + + /** + * Builds a new {@link DefaultOAuth2TokenContext}. + * + * @return the {@link DefaultOAuth2TokenContext} + */ + public DefaultOAuth2TokenContext build() { + return new DefaultOAuth2TokenContext(getContext()); + } + + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JwtGenerator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JwtGenerator.java new file mode 100644 index 000000000..b55b9e53f --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JwtGenerator.java @@ -0,0 +1,166 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.function.Consumer; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * An {@link OAuth2TokenGenerator} that generates a {@link Jwt} + * used for an {@link OAuth2AccessToken} or {@link OidcIdToken}. + * + * @author Joe Grandja + * @since 0.2.3 + * @see OAuth2TokenGenerator + * @see Jwt + * @see JwtEncoder + * @see OAuth2TokenCustomizer + * @see JwtEncodingContext + * @see OAuth2AccessToken + * @see OidcIdToken + */ +public final class JwtGenerator implements OAuth2TokenGenerator { + private final JwtEncoder jwtEncoder; + private OAuth2TokenCustomizer jwtCustomizer; + + /** + * Constructs a {@code JwtGenerator} using the provided parameters. + * + * @param jwtEncoder the jwt encoder + */ + public JwtGenerator(JwtEncoder jwtEncoder) { + Assert.notNull(jwtEncoder, "jwtEncoder cannot be null"); + this.jwtEncoder = jwtEncoder; + } + + @Nullable + @Override + public Jwt generate(OAuth2TokenContext context) { + if (context.getTokenType() == null || + (!OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType()) && + !OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue()))) { + return null; + } + + String issuer = null; + if (context.getProviderContext() != null) { + issuer = context.getProviderContext().getIssuer(); + } + RegisteredClient registeredClient = context.getRegisteredClient(); + + Instant issuedAt = Instant.now(); + Instant expiresAt; + if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) { + // TODO Allow configuration for ID Token time-to-live + expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); + } else { + expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive()); + } + + // @formatter:off + JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder(); + if (StringUtils.hasText(issuer)) { + claimsBuilder.issuer(issuer); + } + claimsBuilder + .subject(context.getPrincipal().getName()) + .audience(Collections.singletonList(registeredClient.getClientId())) + .issuedAt(issuedAt) + .expiresAt(expiresAt); + if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) { + claimsBuilder.notBefore(issuedAt); + if (!CollectionUtils.isEmpty(context.getAuthorizedScopes())) { + claimsBuilder.claim(OAuth2ParameterNames.SCOPE, context.getAuthorizedScopes()); + } + } else if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) { + claimsBuilder.claim(IdTokenClaimNames.AZP, registeredClient.getClientId()); + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType())) { + OAuth2AuthorizationRequest authorizationRequest = context.getAuthorization().getAttribute( + OAuth2AuthorizationRequest.class.getName()); + String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE); + if (StringUtils.hasText(nonce)) { + claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce); + } + } + // TODO Add 'auth_time' claim + } + // @formatter:on + + JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256); + + if (this.jwtCustomizer != null) { + // @formatter:off + JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(headersBuilder, claimsBuilder) + .registeredClient(context.getRegisteredClient()) + .principal(context.getPrincipal()) + .providerContext(context.getProviderContext()) + .authorizedScopes(context.getAuthorizedScopes()) + .tokenType(context.getTokenType()) + .authorizationGrantType(context.getAuthorizationGrantType()); + if (context.getAuthorization() != null) { + jwtContextBuilder.authorization(context.getAuthorization()); + } + if (context.getAuthorizationGrant() != null) { + jwtContextBuilder.authorizationGrant(context.getAuthorizationGrant()); + } + // @formatter:on + + JwtEncodingContext jwtContext = jwtContextBuilder.build(); + this.jwtCustomizer.customize(jwtContext); + } + + JoseHeader headers = headersBuilder.build(); + JwtClaimsSet claims = claimsBuilder.build(); + + Jwt jwt = this.jwtEncoder.encode(headers, claims); + + return jwt; + } + + /** + * Sets the {@link OAuth2TokenCustomizer} that customizes the + * {@link JwtEncodingContext.Builder#headers(Consumer) headers} and/or + * {@link JwtEncodingContext.Builder#claims(Consumer) claims} for the generated {@link Jwt}. + * + * @param jwtCustomizer the {@link OAuth2TokenCustomizer} that customizes the headers and/or claims for the generated {@code Jwt} + */ + public void setJwtCustomizer(OAuth2TokenCustomizer jwtCustomizer) { + Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); + this.jwtCustomizer = jwtCustomizer; + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenContext.java index e18828094..5bb3fc6e4 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenContext.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,15 +27,17 @@ import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.context.Context; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.context.ProviderContext; import org.springframework.util.Assert; /** - * A context that holds information associated to an OAuth 2.0 Token - * and is used by an {@link OAuth2TokenCustomizer} for customizing the token attributes. + * A context that holds information (to be) associated to an OAuth 2.0 Token + * and is used by an {@link OAuth2TokenGenerator} and {@link OAuth2TokenCustomizer}. * * @author Joe Grandja * @since 0.1.0 * @see Context + * @see OAuth2TokenGenerator * @see OAuth2TokenCustomizer */ public interface OAuth2TokenContext extends Context { @@ -59,6 +61,16 @@ default T getPrincipal() { return get(AbstractBuilder.PRINCIPAL_AUTHENTICATION_KEY); } + /** + * Returns the {@link ProviderContext provider context}. + * + * @return the {@link ProviderContext} + * @since 0.2.3 + */ + default ProviderContext getProviderContext() { + return get(ProviderContext.class); + } + /** * Returns the {@link OAuth2Authorization authorization}. * @@ -141,6 +153,17 @@ public B principal(Authentication principal) { return put(PRINCIPAL_AUTHENTICATION_KEY, principal); } + /** + * Sets the {@link ProviderContext provider context}. + * + * @param providerContext the {@link ProviderContext} + * @return the {@link AbstractBuilder} for further configuration + * @since 0.2.3 + */ + public B providerContext(ProviderContext providerContext) { + return put(ProviderContext.class, providerContext); + } + /** * Sets the {@link OAuth2Authorization authorization}. * diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenGenerator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenGenerator.java new file mode 100644 index 000000000..c2bb3a206 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenGenerator.java @@ -0,0 +1,44 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.OAuth2Token; + +/** + * Implementations of this interface are responsible for generating an {@link OAuth2Token} + * using the attributes contained in the {@link OAuth2TokenContext}. + * + * @author Joe Grandja + * @since 0.2.3 + * @see OAuth2Token + * @see OAuth2TokenContext + * @param the type of the OAuth 2.0 Token + */ +@FunctionalInterface +public interface OAuth2TokenGenerator { + + /** + * Generate an OAuth 2.0 Token using the attributes contained in the {@link OAuth2TokenContext}, + * or return {@code null} if the {@link OAuth2TokenContext#getTokenType()} is not supported. + * + * @param context the context containing the OAuth 2.0 Token attributes + * @return an {@link OAuth2Token} or {@code null} if the {@link OAuth2TokenContext#getTokenType()} is not supported + */ + @Nullable + T generate(OAuth2TokenContext context); + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtUtils.java deleted file mode 100644 index 8d72001e9..000000000 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtUtils.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright 2020-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.authentication; - -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Collections; -import java.util.Set; - -import org.springframework.security.authentication.AuthenticationProvider; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; -import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.security.oauth2.jwt.JoseHeader; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - -/** - * Utility methods used by the {@link AuthenticationProvider}'s when issuing {@link Jwt}'s. - * - * @author Joe Grandja - * @since 0.1.0 - */ -final class JwtUtils { - - private JwtUtils() { - } - - static JoseHeader.Builder headers() { - return JoseHeader.withAlgorithm(SignatureAlgorithm.RS256); - } - - static JwtClaimsSet.Builder accessTokenClaims(RegisteredClient registeredClient, - String issuer, String subject, Set authorizedScopes) { - - Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive()); - - // @formatter:off - JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder(); - if (StringUtils.hasText(issuer)) { - claimsBuilder.issuer(issuer); - } - claimsBuilder - .subject(subject) - .audience(Collections.singletonList(registeredClient.getClientId())) - .issuedAt(issuedAt) - .expiresAt(expiresAt) - .notBefore(issuedAt); - if (!CollectionUtils.isEmpty(authorizedScopes)) { - claimsBuilder.claim(OAuth2ParameterNames.SCOPE, authorizedScopes); - } - // @formatter:on - - return claimsBuilder; - } - - static JwtClaimsSet.Builder idTokenClaims(RegisteredClient registeredClient, - String issuer, String subject, String nonce) { - - Instant issuedAt = Instant.now(); - // TODO Allow configuration for ID Token time-to-live - Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); - - // @formatter:off - JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder(); - if (StringUtils.hasText(issuer)) { - claimsBuilder.issuer(issuer); - } - claimsBuilder - .subject(subject) - .audience(Collections.singletonList(registeredClient.getClientId())) - .issuedAt(issuedAt) - .expiresAt(expiresAt) - .claim(IdTokenClaimNames.AZP, registeredClient.getClientId()); - if (StringUtils.hasText(nonce)) { - claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce); - } - // TODO Add 'auth_time' claim - // @formatter:on - - return claimsBuilder; - } - -} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index ed18bfde2..925cd2363 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -22,7 +22,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; @@ -36,22 +35,26 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationCode; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; -import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder; @@ -70,13 +73,12 @@ * @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AuthorizationCodeRequestAuthenticationProvider * @see OAuth2AuthorizationService - * @see JwtEncoder - * @see OAuth2TokenCustomizer - * @see JwtEncodingContext - * @see Section 4.1 Authorization Code Grant - * @see Section 4.1.3 Access Token Request + * @see OAuth2TokenGenerator + * @see Section 4.1 Authorization Code Grant + * @see Section 4.1.3 Access Token Request */ public final class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider { + private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = @@ -84,21 +86,37 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); private final OAuth2AuthorizationService authorizationService; - private final JwtEncoder jwtEncoder; - private OAuth2TokenCustomizer jwtCustomizer = (context) -> {}; + private final OAuth2TokenGenerator tokenGenerator; private Supplier refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey; /** * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. * + * @deprecated Use {@link #OAuth2AuthorizationCodeAuthenticationProvider(OAuth2AuthorizationService, OAuth2TokenGenerator)} instead * @param authorizationService the authorization service * @param jwtEncoder the jwt encoder */ + @Deprecated public OAuth2AuthorizationCodeAuthenticationProvider(OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) { Assert.notNull(authorizationService, "authorizationService cannot be null"); Assert.notNull(jwtEncoder, "jwtEncoder cannot be null"); this.authorizationService = authorizationService; - this.jwtEncoder = jwtEncoder; + this.tokenGenerator = new JwtGenerator(jwtEncoder); + } + + /** + * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. + * + * @param authorizationService the authorization service + * @param tokenGenerator the token generator + * @since 0.2.3 + */ + public OAuth2AuthorizationCodeAuthenticationProvider(OAuth2AuthorizationService authorizationService, + OAuth2TokenGenerator tokenGenerator) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(tokenGenerator, "tokenGenerator cannot be null"); + this.authorizationService = authorizationService; + this.tokenGenerator = tokenGenerator; } /** @@ -106,11 +124,15 @@ public OAuth2AuthorizationCodeAuthenticationProvider(OAuth2AuthorizationService * {@link JwtEncodingContext.Builder#headers(Consumer) headers} and/or * {@link JwtEncodingContext.Builder#claims(Consumer) claims} for the generated {@link Jwt}. * + * @deprecated Use {@link JwtGenerator#setJwtCustomizer(OAuth2TokenCustomizer)} instead * @param jwtCustomizer the {@link OAuth2TokenCustomizer} that customizes the headers and/or claims for the generated {@code Jwt} */ + @Deprecated public void setJwtCustomizer(OAuth2TokenCustomizer jwtCustomizer) { Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); - this.jwtCustomizer = jwtCustomizer; + if (this.tokenGenerator instanceof JwtGenerator) { + ((JwtGenerator) this.tokenGenerator).setJwtCustomizer(jwtCustomizer); + } } /** @@ -165,96 +187,65 @@ public Authentication authenticate(Authentication authentication) throws Authent throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); } - String issuer = ProviderContextHolder.getProviderContext().getIssuer(); - Set authorizedScopes = authorization.getAttribute( - OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME); - - JoseHeader.Builder headersBuilder = JwtUtils.headers(); - JwtClaimsSet.Builder claimsBuilder = JwtUtils.accessTokenClaims( - registeredClient, issuer, authorization.getPrincipalName(), - authorizedScopes); - // @formatter:off - JwtEncodingContext context = JwtEncodingContext.with(headersBuilder, claimsBuilder) + DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder() .registeredClient(registeredClient) .principal(authorization.getAttribute(Principal.class.getName())) + .providerContext(ProviderContextHolder.getProviderContext()) .authorization(authorization) - .authorizedScopes(authorizedScopes) - .tokenType(OAuth2TokenType.ACCESS_TOKEN) + .authorizedScopes(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .authorizationGrant(authorizationCodeAuthentication) - .build(); + .authorizationGrant(authorizationCodeAuthentication); // @formatter:on - this.jwtCustomizer.customize(context); - - JoseHeader headers = context.getHeaders().build(); - JwtClaimsSet claims = context.getClaims().build(); - Jwt jwtAccessToken = this.jwtEncoder.encode(headers, claims); + OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization); + // ----- Access token ----- + OAuth2TokenContext tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.ACCESS_TOKEN).build(); + OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext); + if (generatedAccessToken == null) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, + "The token generator failed to generate the access token.", ERROR_URI); + throw new OAuth2AuthenticationException(error); + } OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - jwtAccessToken.getTokenValue(), jwtAccessToken.getIssuedAt(), - jwtAccessToken.getExpiresAt(), authorizedScopes); + generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(), + generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes()); + if (generatedAccessToken instanceof Jwt) { + authorizationBuilder.token(accessToken, (metadata) -> + metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((Jwt) generatedAccessToken).getClaims())); + } else { + authorizationBuilder.accessToken(accessToken); + } + // ----- Refresh token ----- OAuth2RefreshToken refreshToken = null; if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN) && // Do not issue refresh token to public client !clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) { refreshToken = generateRefreshToken(registeredClient.getTokenSettings().getRefreshTokenTimeToLive()); + authorizationBuilder.refreshToken(refreshToken); } - Jwt jwtIdToken = null; - if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) { - String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE); - - headersBuilder = JwtUtils.headers(); - claimsBuilder = JwtUtils.idTokenClaims( - registeredClient, issuer, authorization.getPrincipalName(), nonce); - - // @formatter:off - context = JwtEncodingContext.with(headersBuilder, claimsBuilder) - .registeredClient(registeredClient) - .principal(authorization.getAttribute(Principal.class.getName())) - .authorization(authorization) - .authorizedScopes(authorizedScopes) - .tokenType(ID_TOKEN_TOKEN_TYPE) - .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) - .authorizationGrant(authorizationCodeAuthentication) - .build(); - // @formatter:on - - this.jwtCustomizer.customize(context); - - headers = context.getHeaders().build(); - claims = context.getClaims().build(); - jwtIdToken = this.jwtEncoder.encode(headers, claims); - } - + // ----- ID token ----- OidcIdToken idToken; - if (jwtIdToken != null) { - idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(), - jwtIdToken.getExpiresAt(), jwtIdToken.getClaims()); + if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) { + tokenContext = tokenContextBuilder.tokenType(ID_TOKEN_TOKEN_TYPE).build(); + OAuth2Token generatedIdToken = this.tokenGenerator.generate(tokenContext); + if (!(generatedIdToken instanceof Jwt)) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, + "The token generator failed to generate the ID token.", ERROR_URI); + throw new OAuth2AuthenticationException(error); + } + idToken = new OidcIdToken(generatedIdToken.getTokenValue(), generatedIdToken.getIssuedAt(), + generatedIdToken.getExpiresAt(), ((Jwt) generatedIdToken).getClaims()); + authorizationBuilder.token(idToken, (metadata) -> + metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())); } else { idToken = null; } - // @formatter:off - OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization) - .token(accessToken, - (metadata) -> - metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims()) - ); - if (refreshToken != null) { - authorizationBuilder.refreshToken(refreshToken); - } - if (idToken != null) { - authorizationBuilder - .token(idToken, - (metadata) -> - metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())); - } authorization = authorizationBuilder.build(); - // @formatter:on // Invalidate the authorization code as it can only be used once authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken()); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java index e23c33f64..d19744b77 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java @@ -25,16 +25,20 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.OAuth2TokenType; -import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder; @@ -52,29 +56,44 @@ * @see OAuth2ClientCredentialsAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AuthorizationService - * @see JwtEncoder - * @see OAuth2TokenCustomizer - * @see JwtEncodingContext - * @see Section 4.4 Client Credentials Grant - * @see Section 4.4.2 Access Token Request + * @see OAuth2TokenGenerator + * @see Section 4.4 Client Credentials Grant + * @see Section 4.4.2 Access Token Request */ public final class OAuth2ClientCredentialsAuthenticationProvider implements AuthenticationProvider { + private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; private final OAuth2AuthorizationService authorizationService; - private final JwtEncoder jwtEncoder; - private OAuth2TokenCustomizer jwtCustomizer = (context) -> {}; + private final OAuth2TokenGenerator tokenGenerator; /** * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters. * + * @deprecated Use {@link #OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService, OAuth2TokenGenerator)} instead * @param authorizationService the authorization service * @param jwtEncoder the jwt encoder */ + @Deprecated public OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) { Assert.notNull(authorizationService, "authorizationService cannot be null"); Assert.notNull(jwtEncoder, "jwtEncoder cannot be null"); this.authorizationService = authorizationService; - this.jwtEncoder = jwtEncoder; + this.tokenGenerator = new JwtGenerator(jwtEncoder); + } + + /** + * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters. + * + * @param authorizationService the authorization service + * @param tokenGenerator the token generator + * @since 0.2.3 + */ + public OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService authorizationService, + OAuth2TokenGenerator tokenGenerator) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(tokenGenerator, "tokenGenerator cannot be null"); + this.authorizationService = authorizationService; + this.tokenGenerator = tokenGenerator; } /** @@ -82,11 +101,15 @@ public OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService * {@link JwtEncodingContext.Builder#headers(Consumer) headers} and/or * {@link JwtEncodingContext.Builder#claims(Consumer) claims} for the generated {@link Jwt}. * + * @deprecated Use {@link JwtGenerator#setJwtCustomizer(OAuth2TokenCustomizer)} instead * @param jwtCustomizer the {@link OAuth2TokenCustomizer} that customizes the headers and/or claims for the generated {@code Jwt} */ + @Deprecated public void setJwtCustomizer(OAuth2TokenCustomizer jwtCustomizer) { Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); - this.jwtCustomizer = jwtCustomizer; + if (this.tokenGenerator instanceof JwtGenerator) { + ((JwtGenerator) this.tokenGenerator).setJwtCustomizer(jwtCustomizer); + } } @Deprecated @@ -116,16 +139,11 @@ public Authentication authenticate(Authentication authentication) throws Authent authorizedScopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes()); } - String issuer = ProviderContextHolder.getProviderContext().getIssuer(); - - JoseHeader.Builder headersBuilder = JwtUtils.headers(); - JwtClaimsSet.Builder claimsBuilder = JwtUtils.accessTokenClaims( - registeredClient, issuer, clientPrincipal.getName(), authorizedScopes); - // @formatter:off - JwtEncodingContext context = JwtEncodingContext.with(headersBuilder, claimsBuilder) + OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder() .registeredClient(registeredClient) .principal(clientPrincipal) + .providerContext(ProviderContextHolder.getProviderContext()) .authorizedScopes(authorizedScopes) .tokenType(OAuth2TokenType.ACCESS_TOKEN) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) @@ -133,26 +151,30 @@ public Authentication authenticate(Authentication authentication) throws Authent .build(); // @formatter:on - this.jwtCustomizer.customize(context); - - JoseHeader headers = context.getHeaders().build(); - JwtClaimsSet claims = context.getClaims().build(); - Jwt jwtAccessToken = this.jwtEncoder.encode(headers, claims); - + OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext); + if (generatedAccessToken == null) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, + "The token generator failed to generate the access token.", ERROR_URI); + throw new OAuth2AuthenticationException(error); + } OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - jwtAccessToken.getTokenValue(), jwtAccessToken.getIssuedAt(), - jwtAccessToken.getExpiresAt(), authorizedScopes); + generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(), + generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes()); // @formatter:off - OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient) + OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(clientPrincipal.getName()) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .token(accessToken, - (metadata) -> - metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims())) - .attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes) - .build(); + .attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes); // @formatter:on + if (generatedAccessToken instanceof Jwt) { + authorizationBuilder.token(accessToken, (metadata) -> + metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((Jwt) generatedAccessToken).getClaims())); + } else { + authorizationBuilder.accessToken(accessToken); + } + + OAuth2Authorization authorization = authorizationBuilder.build(); this.authorizationService.save(authorization); @@ -163,4 +185,5 @@ public Authentication authenticate(Authentication authentication) throws Authent public boolean supports(Class authentication) { return OAuth2ClientCredentialsAuthenticationToken.class.isAssignableFrom(authentication); } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java index 0f2a21254..f5059ec72 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java @@ -34,23 +34,26 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; -import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; -import org.springframework.security.oauth2.server.authorization.config.TokenSettings; import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder; import org.springframework.util.Assert; @@ -66,33 +69,48 @@ * @see OAuth2RefreshTokenAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken * @see OAuth2AuthorizationService - * @see JwtEncoder - * @see OAuth2TokenCustomizer - * @see JwtEncodingContext - * @see Section 1.5 Refresh Token Grant - * @see Section 6 Refreshing an Access Token + * @see OAuth2TokenGenerator + * @see Section 1.5 Refresh Token Grant + * @see Section 6 Refreshing an Access Token */ public final class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider { + private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN); private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); private final OAuth2AuthorizationService authorizationService; - private final JwtEncoder jwtEncoder; - private OAuth2TokenCustomizer jwtCustomizer = (context) -> {}; + private final OAuth2TokenGenerator tokenGenerator; private Supplier refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey; /** * Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters. * + * @deprecated Use {@link #OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService, OAuth2TokenGenerator)} instead * @param authorizationService the authorization service * @param jwtEncoder the jwt encoder */ + @Deprecated public OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) { Assert.notNull(authorizationService, "authorizationService cannot be null"); Assert.notNull(jwtEncoder, "jwtEncoder cannot be null"); this.authorizationService = authorizationService; - this.jwtEncoder = jwtEncoder; + this.tokenGenerator = new JwtGenerator(jwtEncoder); + } + + /** + * Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters. + * + * @param authorizationService the authorization service + * @param tokenGenerator the token generator + * @since 0.2.3 + */ + public OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService authorizationService, + OAuth2TokenGenerator tokenGenerator) { + Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(tokenGenerator, "tokenGenerator cannot be null"); + this.authorizationService = authorizationService; + this.tokenGenerator = tokenGenerator; } /** @@ -100,11 +118,15 @@ public OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService autho * {@link JwtEncodingContext.Builder#headers(Consumer) headers} and/or * {@link JwtEncodingContext.Builder#claims(Consumer) claims} for the generated {@link Jwt}. * + * @deprecated Use {@link JwtGenerator#setJwtCustomizer(OAuth2TokenCustomizer)} instead * @param jwtCustomizer the {@link OAuth2TokenCustomizer} that customizes the headers and/or claims for the generated {@code Jwt} */ + @Deprecated public void setJwtCustomizer(OAuth2TokenCustomizer jwtCustomizer) { Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null"); - this.jwtCustomizer = jwtCustomizer; + if (this.tokenGenerator instanceof JwtGenerator) { + ((JwtGenerator) this.tokenGenerator).setJwtCustomizer(jwtCustomizer); + } } /** @@ -164,90 +186,65 @@ public Authentication authenticate(Authentication authentication) throws Authent scopes = authorizedScopes; } - String issuer = ProviderContextHolder.getProviderContext().getIssuer(); - - JoseHeader.Builder headersBuilder = JwtUtils.headers(); - JwtClaimsSet.Builder claimsBuilder = JwtUtils.accessTokenClaims( - registeredClient, issuer, authorization.getPrincipalName(), scopes); - // @formatter:off - JwtEncodingContext context = JwtEncodingContext.with(headersBuilder, claimsBuilder) + DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder() .registeredClient(registeredClient) .principal(authorization.getAttribute(Principal.class.getName())) + .providerContext(ProviderContextHolder.getProviderContext()) .authorization(authorization) - .authorizedScopes(authorizedScopes) - .tokenType(OAuth2TokenType.ACCESS_TOKEN) + .authorizedScopes(scopes) .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) - .authorizationGrant(refreshTokenAuthentication) - .build(); + .authorizationGrant(refreshTokenAuthentication); // @formatter:on - this.jwtCustomizer.customize(context); - - JoseHeader headers = context.getHeaders().build(); - JwtClaimsSet claims = context.getClaims().build(); - Jwt jwtAccessToken = this.jwtEncoder.encode(headers, claims); + OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization); + // ----- Access token ----- + OAuth2TokenContext tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.ACCESS_TOKEN).build(); + OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext); + if (generatedAccessToken == null) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, + "The token generator failed to generate the access token.", ERROR_URI); + throw new OAuth2AuthenticationException(error); + } OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - jwtAccessToken.getTokenValue(), jwtAccessToken.getIssuedAt(), - jwtAccessToken.getExpiresAt(), scopes); - - TokenSettings tokenSettings = registeredClient.getTokenSettings(); - - OAuth2RefreshToken currentRefreshToken = refreshToken.getToken(); - if (!tokenSettings.isReuseRefreshTokens()) { - currentRefreshToken = generateRefreshToken(tokenSettings.getRefreshTokenTimeToLive()); + generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(), + generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes()); + if (generatedAccessToken instanceof Jwt) { + authorizationBuilder.token(accessToken, (metadata) -> { + metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((Jwt) generatedAccessToken).getClaims()); + metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, false); + }); + } else { + authorizationBuilder.accessToken(accessToken); } - Jwt jwtIdToken = null; - if (authorizedScopes.contains(OidcScopes.OPENID)) { - headersBuilder = JwtUtils.headers(); - claimsBuilder = JwtUtils.idTokenClaims( - registeredClient, issuer, authorization.getPrincipalName(), null); - - // @formatter:off - context = JwtEncodingContext.with(headersBuilder, claimsBuilder) - .registeredClient(registeredClient) - .principal(authorization.getAttribute(Principal.class.getName())) - .authorization(authorization) - .authorizedScopes(authorizedScopes) - .tokenType(ID_TOKEN_TOKEN_TYPE) - .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) - .authorizationGrant(refreshTokenAuthentication) - .build(); - // @formatter:on - - this.jwtCustomizer.customize(context); - - headers = context.getHeaders().build(); - claims = context.getClaims().build(); - jwtIdToken = this.jwtEncoder.encode(headers, claims); + // ----- Refresh token ----- + OAuth2RefreshToken currentRefreshToken = refreshToken.getToken(); + if (!registeredClient.getTokenSettings().isReuseRefreshTokens()) { + currentRefreshToken = generateRefreshToken(registeredClient.getTokenSettings().getRefreshTokenTimeToLive()); + authorizationBuilder.refreshToken(currentRefreshToken); } + // ----- ID token ----- OidcIdToken idToken; - if (jwtIdToken != null) { - idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(), - jwtIdToken.getExpiresAt(), jwtIdToken.getClaims()); + if (authorizedScopes.contains(OidcScopes.OPENID)) { + tokenContext = tokenContextBuilder.tokenType(ID_TOKEN_TOKEN_TYPE).build(); + OAuth2Token generatedIdToken = this.tokenGenerator.generate(tokenContext); + if (!(generatedIdToken instanceof Jwt)) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, + "The token generator failed to generate the ID token.", ERROR_URI); + throw new OAuth2AuthenticationException(error); + } + idToken = new OidcIdToken(generatedIdToken.getTokenValue(), generatedIdToken.getIssuedAt(), + generatedIdToken.getExpiresAt(), ((Jwt) generatedIdToken).getClaims()); + authorizationBuilder.token(idToken, (metadata) -> + metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())); } else { idToken = null; } - // @formatter:off - OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization) - .token(accessToken, - (metadata) -> { - metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims()); - metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, false); - }) - .refreshToken(currentRefreshToken); - if (idToken != null) { - authorizationBuilder - .token(idToken, - (metadata) -> - metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())); - } authorization = authorizationBuilder.build(); - // @formatter:on this.authorizationService.save(authorization); @@ -271,4 +268,5 @@ private OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) { Instant expiresAt = issuedAt.plus(tokenTimeToLive); return new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt); } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/JwtUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/JwtUtils.java deleted file mode 100644 index ef9d88db2..000000000 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/JwtUtils.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright 2020-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.oidc.authentication; - -import java.time.Instant; -import java.util.Collections; -import java.util.Set; - -import org.springframework.security.authentication.AuthenticationProvider; -import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; -import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.security.oauth2.jwt.JoseHeader; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.util.CollectionUtils; -import org.springframework.util.StringUtils; - -/** - * TODO - * This class is mostly a straight copy from {@code org.springframework.security.oauth2.server.authorization.authentication.JwtUtils}. - * It should be consolidated when we introduce a token generator abstraction. - * - * Utility methods used by the {@link AuthenticationProvider}'s when issuing {@link Jwt}'s. - * - * @author Ovidiu Popa - * @since 0.2.1 - */ -final class JwtUtils { - - private JwtUtils() { - } - - static JoseHeader.Builder headers() { - return JoseHeader.withAlgorithm(SignatureAlgorithm.RS256); - } - - static JwtClaimsSet.Builder accessTokenClaims(RegisteredClient registeredClient, - String issuer, String subject, Set authorizedScopes) { - - Instant issuedAt = Instant.now(); - Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive()); - - // @formatter:off - JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder(); - if (StringUtils.hasText(issuer)) { - claimsBuilder.issuer(issuer); - } - claimsBuilder - .subject(subject) - .audience(Collections.singletonList(registeredClient.getClientId())) - .issuedAt(issuedAt) - .expiresAt(expiresAt) - .notBefore(issuedAt); - if (!CollectionUtils.isEmpty(authorizedScopes)) { - claimsBuilder.claim(OAuth2ParameterNames.SCOPE, authorizedScopes); - } - // @formatter:on - - return claimsBuilder; - } - -} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java index 1c50279b3..4886af967 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java @@ -38,6 +38,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -45,12 +46,15 @@ import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; -import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.config.ClientSettings; @@ -73,11 +77,12 @@ * @since 0.1.1 * @see RegisteredClientRepository * @see OAuth2AuthorizationService - * @see JwtEncoder + * @see OAuth2TokenGenerator * @see 3. Client Registration Endpoint * @see 4. Client Configuration Endpoint */ public final class OidcClientRegistrationAuthenticationProvider implements AuthenticationProvider { + private static final String ERROR_URI = "https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError"; private static final StringKeyGenerator CLIENT_ID_GENERATOR = new Base64StringKeyGenerator( Base64.getUrlEncoder().withoutPadding(), 32); private static final StringKeyGenerator CLIENT_SECRET_GENERATOR = new Base64StringKeyGenerator( @@ -86,7 +91,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe private static final String DEFAULT_CLIENT_CONFIGURATION_AUTHORIZED_SCOPE = "client.read"; private final RegisteredClientRepository registeredClientRepository; private final OAuth2AuthorizationService authorizationService; - private JwtEncoder jwtEncoder; + private OAuth2TokenGenerator tokenGenerator; /** * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters. @@ -110,7 +115,9 @@ public OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository r * @param registeredClientRepository the repository of registered clients * @param authorizationService the authorization service * @param jwtEncoder the jwt encoder + * @deprecated Use {@link #OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository, OAuth2AuthorizationService, OAuth2TokenGenerator)} instead */ + @Deprecated public OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository registeredClientRepository, OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) { Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); @@ -118,13 +125,31 @@ public OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository r Assert.notNull(jwtEncoder, "jwtEncoder cannot be null"); this.registeredClientRepository = registeredClientRepository; this.authorizationService = authorizationService; - this.jwtEncoder = jwtEncoder; + this.tokenGenerator = new JwtGenerator(jwtEncoder); + } + + /** + * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + * @param tokenGenerator the token generator + * @since 0.2.3 + */ + public OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService, OAuth2TokenGenerator tokenGenerator) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(tokenGenerator, "tokenGenerator cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.authorizationService = authorizationService; + this.tokenGenerator = tokenGenerator; } @Deprecated @Autowired(required = false) protected void setJwtEncoder(JwtEncoder jwtEncoder) { - this.jwtEncoder = jwtEncoder; + this.tokenGenerator = new JwtGenerator(jwtEncoder); } @Deprecated @@ -227,37 +252,52 @@ private OidcClientRegistrationAuthenticationToken registerClient(OidcClientRegis } private OAuth2Authorization registerAccessToken(RegisteredClient registeredClient) { - JoseHeader headers = JwtUtils.headers().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient, + registeredClient.getClientAuthenticationMethods().iterator().next(), registeredClient.getClientSecret()); Set authorizedScopes = new HashSet<>(); authorizedScopes.add(DEFAULT_CLIENT_CONFIGURATION_AUTHORIZED_SCOPE); authorizedScopes = Collections.unmodifiableSet(authorizedScopes); - String issuer = ProviderContextHolder.getProviderContext().getIssuer(); - JwtClaimsSet claims = JwtUtils.accessTokenClaims( - registeredClient, issuer, registeredClient.getClientId(), authorizedScopes) + // @formatter:off + OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder() + .registeredClient(registeredClient) + .principal(clientPrincipal) + .providerContext(ProviderContextHolder.getProviderContext()) + .authorizedScopes(authorizedScopes) + .tokenType(OAuth2TokenType.ACCESS_TOKEN) + .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .build(); + // @formatter:on - Jwt registrationAccessToken = this.jwtEncoder.encode(headers, claims); - + OAuth2Token registrationAccessToken = this.tokenGenerator.generate(tokenContext); + if (registrationAccessToken == null) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, + "The token generator failed to generate the registration access token.", ERROR_URI); + throw new OAuth2AuthenticationException(error); + } OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, registrationAccessToken.getTokenValue(), registrationAccessToken.getIssuedAt(), - registrationAccessToken.getExpiresAt(), authorizedScopes); + registrationAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes()); // @formatter:off - OAuth2Authorization registeredClientAuthorization = OAuth2Authorization.withRegisteredClient(registeredClient) + OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.withRegisteredClient(registeredClient) .principalName(registeredClient.getClientId()) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .token(accessToken, - (metadata) -> - metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, registrationAccessToken.getClaims())) - .attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes) - .build(); + .attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes); // @formatter:on + if (registrationAccessToken instanceof Jwt) { + authorizationBuilder.token(accessToken, (metadata) -> + metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((Jwt) registrationAccessToken).getClaims())); + } else { + authorizationBuilder.accessToken(accessToken); + } + + OAuth2Authorization authorization = authorizationBuilder.build(); - this.authorizationService.save(registeredClientAuthorization); + this.authorizationService.save(authorization); - return registeredClientAuthorization; + return authorization; } private OidcClientRegistration.Builder buildRegistration(RegisteredClient registeredClient) { @@ -445,7 +485,7 @@ private static void throwInvalidClientRegistration(String errorCode, String fiel OAuth2Error error = new OAuth2Error( errorCode, "Invalid Client Registration: " + fieldName, - "https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError"); + ERROR_URI); throw new OAuth2AuthenticationException(error); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index a3dd59e4d..6aac62097 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -84,11 +84,14 @@ import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationConsentService; import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken; @@ -184,6 +187,9 @@ public class OAuth2AuthorizationCodeGrantTests { @Autowired private JwtDecoder jwtDecoder; + @Autowired(required = false) + private OAuth2TokenGenerator tokenGenerator; + @BeforeClass public static void init() { JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); @@ -425,8 +431,8 @@ public void requestWhenConfidentialClientWithPkceAndMissingCodeVerifierThenBadRe } @Test - public void requestWhenCustomJwtEncoderThenUsed() throws Exception { - this.spring.register(AuthorizationServerConfigurationWithJwtEncoder.class).autowire(); + public void requestWhenCustomTokenGeneratorThenUsed() throws Exception { + this.spring.register(AuthorizationServerConfigurationWithTokenGenerator.class).autowire(); RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); this.registeredClientRepository.save(registeredClient); @@ -436,7 +442,10 @@ public void requestWhenCustomJwtEncoderThenUsed() throws Exception { this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI) .params(getTokenRequestParameters(registeredClient, authorization)) - .header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))); + .header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient))) + .andExpect(status().isOk()); + + verify(this.tokenGenerator).generate(any()); } @Test @@ -822,12 +831,25 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h @EnableWebSecurity @Import(OAuth2AuthorizationServerConfiguration.class) - static class AuthorizationServerConfigurationWithJwtEncoder extends AuthorizationServerConfiguration { + static class AuthorizationServerConfigurationWithTokenGenerator extends AuthorizationServerConfiguration { @Bean JwtEncoder jwtEncoder() { return jwtEncoder; } + + @Bean + OAuth2TokenGenerator tokenGenerator() { + JwtGenerator jwtGenerator = new JwtGenerator(jwtEncoder()); + jwtGenerator.setJwtCustomizer(jwtCustomizer()); + return spy(new OAuth2TokenGenerator() { + @Override + public Jwt generate(OAuth2TokenContext context) { + return jwtGenerator.generate(context); + } + }); + } + } @EnableWebSecurity diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java index 76f63a1b6..667ba8b86 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,6 +48,7 @@ import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.config.test.SpringTestRule; @@ -67,11 +68,16 @@ import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; @@ -79,6 +85,8 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.security.oauth2.server.authorization.jackson2.TestingAuthenticationTokenMixin; +import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; import org.springframework.util.LinkedMultiValueMap; @@ -90,6 +98,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; @@ -101,6 +113,7 @@ * Integration tests for OpenID Connect 1.0. * * @author Daniel Garnier-Moiroux + * @author Joe Grandja */ public class OidcTests { private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize"; @@ -132,6 +145,9 @@ public class OidcTests { @Autowired private JwtDecoder jwtDecoder; + @Autowired(required = false) + private OAuth2TokenGenerator tokenGenerator; + @BeforeClass public static void init() { JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); @@ -230,6 +246,25 @@ public void requestWhenAuthenticationRequestThenTokenResponseIncludesIdToken() t assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities); } + @Test + public void requestWhenCustomTokenGeneratorThenUsed() throws Exception { + this.spring.register(AuthorizationServerConfigurationWithTokenGenerator.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); + this.registeredClientRepository.save(registeredClient); + + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + this.authorizationService.save(authorization); + + this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI) + .params(getTokenRequestParameters(registeredClient, authorization)) + .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( + registeredClient.getClientId(), registeredClient.getClientSecret()))) + .andExpect(status().isOk()); + + verify(this.tokenGenerator, times(2)).generate(any()); + } + private static MultiValueMap getAuthorizationRequestParameters(RegisteredClient registeredClient) { MultiValueMap parameters = new LinkedMultiValueMap<>(); parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); @@ -339,6 +374,46 @@ static class ParametersMapper extends JdbcOAuth2AuthorizationService.OAuth2Autho } + @EnableWebSecurity + static class AuthorizationServerConfigurationWithTokenGenerator extends AuthorizationServerConfiguration { + + // @formatter:off + @Bean + public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + OAuth2AuthorizationServerConfigurer authorizationServerConfigurer = + new OAuth2AuthorizationServerConfigurer<>(); + http.apply(authorizationServerConfigurer); + + authorizationServerConfigurer + .tokenGenerator(tokenGenerator()); + + RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher(); + + http + .requestMatcher(endpointsMatcher) + .authorizeRequests(authorizeRequests -> + authorizeRequests.anyRequest().authenticated() + ) + .csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher)); + + return http.build(); + } + // @formatter:on + + @Bean + OAuth2TokenGenerator tokenGenerator() { + JwtGenerator jwtGenerator = new JwtGenerator(new NimbusJwsEncoder(jwkSource())); + jwtGenerator.setJwtCustomizer(jwtCustomizer()); + return spy(new OAuth2TokenGenerator() { + @Override + public Jwt generate(OAuth2TokenContext context) { + return jwtGenerator.generate(context); + } + }); + } + + } + @EnableWebSecurity @Import(OAuth2AuthorizationServerConfiguration.class) static class AuthorizationServerConfigurationWithIssuer extends AuthorizationServerConfiguration { @@ -368,4 +443,5 @@ ProviderSettings providerSettings() { return ProviderSettings.builder().issuer("https://not a valid uri").build(); } } + } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JwtGeneratorTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JwtGeneratorTests.java new file mode 100644 index 000000000..5c9dd224f --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JwtGeneratorTests.java @@ -0,0 +1,215 @@ +/* + * Copyright 2020-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import java.security.Principal; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; +import org.springframework.security.oauth2.server.authorization.context.ProviderContext; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link JwtGenerator}. + * + * @author Joe Grandja + */ +public class JwtGeneratorTests { + private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN); + private JwtEncoder jwtEncoder; + private OAuth2TokenCustomizer jwtCustomizer; + private JwtGenerator jwtGenerator; + private ProviderContext providerContext; + + @Before + public void setUp() { + this.jwtEncoder = mock(JwtEncoder.class); + this.jwtCustomizer = mock(OAuth2TokenCustomizer.class); + this.jwtGenerator = new JwtGenerator(this.jwtEncoder); + this.jwtGenerator.setJwtCustomizer(this.jwtCustomizer); + ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build(); + this.providerContext = new ProviderContext(providerSettings, null); + } + + @Test + public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new JwtGenerator(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwtEncoder cannot be null"); + } + + @Test + public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.jwtGenerator.setJwtCustomizer(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jwtCustomizer cannot be null"); + } + + @Test + public void generateWhenUnsupportedTokenTypeThenReturnNull() { + // @formatter:off + OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder() + .tokenType(new OAuth2TokenType("unsupported_token_type")) + .build(); + // @formatter:on + + assertThat(this.jwtGenerator.generate(tokenContext)).isNull(); + } + + @Test + public void generateWhenAccessTokenTypeThenReturnJwt() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationRequest.class.getName()); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null); + + // @formatter:off + OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder() + .registeredClient(registeredClient) + .principal(authorization.getAttribute(Principal.class.getName())) + .providerContext(this.providerContext) + .authorization(authorization) + .authorizedScopes(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)) + .tokenType(OAuth2TokenType.ACCESS_TOKEN) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrant(authentication) + .build(); + // @formatter:on + + assertGeneratedTokenType(tokenContext); + } + + @Test + public void generateWhenIdTokenTypeThenReturnJwt() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); + Map authenticationRequestAdditionalParameters = new HashMap<>(); + authenticationRequestAdditionalParameters.put(OidcParameterNames.NONCE, "nonce"); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + registeredClient, authenticationRequestAdditionalParameters).build(); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationRequest.class.getName()); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null); + + // @formatter:off + OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder() + .registeredClient(registeredClient) + .principal(authorization.getAttribute(Principal.class.getName())) + .providerContext(this.providerContext) + .authorization(authorization) + .authorizedScopes(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)) + .tokenType(ID_TOKEN_TOKEN_TYPE) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .authorizationGrant(authentication) + .build(); + // @formatter:on + + assertGeneratedTokenType(tokenContext); + } + + private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) { + this.jwtGenerator.generate(tokenContext); + + ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); + verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture()); + + JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue(); + assertThat(jwtEncodingContext.getHeaders()).isNotNull(); + assertThat(jwtEncodingContext.getClaims()).isNotNull(); + assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(tokenContext.getRegisteredClient()); + assertThat(jwtEncodingContext.getPrincipal()).isEqualTo(tokenContext.getPrincipal()); + assertThat(jwtEncodingContext.getAuthorization()).isEqualTo(tokenContext.getAuthorization()); + assertThat(jwtEncodingContext.getAuthorizedScopes()).isEqualTo(tokenContext.getAuthorizedScopes()); + assertThat(jwtEncodingContext.getTokenType()).isEqualTo(tokenContext.getTokenType()); + assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(tokenContext.getAuthorizationGrantType()); + assertThat(jwtEncodingContext.getAuthorizationGrant()).isEqualTo(tokenContext.getAuthorizationGrant()); + + ArgumentCaptor joseHeaderCaptor = ArgumentCaptor.forClass(JoseHeader.class); + ArgumentCaptor jwtClaimsSetCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class); + verify(this.jwtEncoder).encode(joseHeaderCaptor.capture(), jwtClaimsSetCaptor.capture()); + + JoseHeader joseHeader = joseHeaderCaptor.getValue(); + assertThat(joseHeader.getAlgorithm()).isEqualTo(SignatureAlgorithm.RS256); + + JwtClaimsSet jwtClaimsSet = jwtClaimsSetCaptor.getValue(); + assertThat(jwtClaimsSet.getIssuer().toExternalForm()).isEqualTo(tokenContext.getProviderContext().getIssuer()); + assertThat(jwtClaimsSet.getSubject()).isEqualTo(tokenContext.getAuthorization().getPrincipalName()); + assertThat(jwtClaimsSet.getAudience()).containsExactly(tokenContext.getRegisteredClient().getClientId()); + + Instant issuedAt = Instant.now(); + Instant expiresAt; + if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) { + expiresAt = issuedAt.plus(tokenContext.getRegisteredClient().getTokenSettings().getAccessTokenTimeToLive()); + } else { + expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); + } + assertThat(jwtClaimsSet.getIssuedAt()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1)); + assertThat(jwtClaimsSet.getExpiresAt()).isBetween(expiresAt.minusSeconds(1), expiresAt.plusSeconds(1)); + + if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) { + assertThat(jwtClaimsSet.getNotBefore()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1)); + + Set scopes = jwtClaimsSet.getClaim(OAuth2ParameterNames.SCOPE); + assertThat(scopes).isEqualTo(tokenContext.getAuthorizedScopes()); + } else { + assertThat(jwtClaimsSet.getClaim(IdTokenClaimNames.AZP)).isEqualTo(tokenContext.getRegisteredClient().getClientId()); + + OAuth2AuthorizationRequest authorizationRequest = tokenContext.getAuthorization().getAttribute( + OAuth2AuthorizationRequest.class.getName()); + String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE); + assertThat(jwtClaimsSet.getClaim(IdTokenClaimNames.NONCE)).isEqualTo(nonce); + } + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 48a28debe..60618e04a 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -49,9 +49,12 @@ import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; @@ -65,6 +68,7 @@ import static org.assertj.core.api.Assertions.entry; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -83,16 +87,24 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; private OAuth2TokenCustomizer jwtCustomizer; + private OAuth2TokenGenerator tokenGenerator; private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; @Before public void setUp() { this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); - this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( - this.authorizationService, this.jwtEncoder); this.jwtCustomizer = mock(OAuth2TokenCustomizer.class); - this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer); + JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder); + jwtGenerator.setJwtCustomizer(this.jwtCustomizer); + this.tokenGenerator = spy(new OAuth2TokenGenerator() { + @Override + public Jwt generate(OAuth2TokenContext context) { + return jwtGenerator.generate(context); + } + }); + this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( + this.authorizationService, this.tokenGenerator); ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build(); ProviderContextHolder.setProviderContext(new ProviderContext(providerSettings, null)); } @@ -111,11 +123,18 @@ public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentExcep @Test public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.authorizationService, null)) + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.authorizationService, (JwtEncoder) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("jwtEncoder cannot be null"); } + @Test + public void constructorWhenTokenGeneratorNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.authorizationService, (OAuth2TokenGenerator) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("tokenGenerator cannot be null"); + } + @Test public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null)) @@ -273,6 +292,74 @@ public void authenticateWhenExpiredCodeThenThrowOAuth2AuthenticationException() .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); } + @Test + public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationRequest.class.getName()); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null); + + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + + doAnswer(answer -> { + OAuth2TokenContext context = answer.getArgument(0); + if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) { + return null; + } else { + return answer.callRealMethod(); + } + }).when(this.tokenGenerator).generate(any()); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + assertThat(error.getDescription()).contains("The token generator failed to generate the access token."); + }); + } + + @Test + public void authenticateWhenIdTokenNotGeneratedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute( + OAuth2AuthorizationRequest.class.getName()); + OAuth2AuthorizationCodeAuthenticationToken authentication = + new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null); + + when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt()); + + doAnswer(answer -> { + OAuth2TokenContext context = answer.getArgument(0); + if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) { + return null; + } else { + return answer.callRealMethod(); + } + }).when(this.tokenGenerator).generate(any()); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + assertThat(error.getDescription()).contains("The token generator failed to generate the ID token."); + }); + } + @Test public void authenticateWhenValidCodeThenReturnAccessToken() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java index e1e96a3f1..98940f7c7 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java @@ -38,9 +38,12 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; @@ -50,7 +53,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -64,16 +69,24 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests { private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; private OAuth2TokenCustomizer jwtCustomizer; + private OAuth2TokenGenerator tokenGenerator; private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider; @Before public void setUp() { this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); - this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider( - this.authorizationService, this.jwtEncoder); this.jwtCustomizer = mock(OAuth2TokenCustomizer.class); - this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer); + JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder); + jwtGenerator.setJwtCustomizer(this.jwtCustomizer); + this.tokenGenerator = spy(new OAuth2TokenGenerator() { + @Override + public Jwt generate(OAuth2TokenContext context) { + return jwtGenerator.generate(context); + } + }); + this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider( + this.authorizationService, this.tokenGenerator); ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build(); ProviderContextHolder.setProviderContext(new ProviderContext(providerSettings, null)); } @@ -92,11 +105,18 @@ public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentExcep @Test public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService, null)) + assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService, (JwtEncoder) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("jwtEncoder cannot be null"); } + @Test + public void constructorWhenTokenGeneratorNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService, (OAuth2TokenGenerator) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("tokenGenerator cannot be null"); + } + @Test public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null)) @@ -193,6 +213,25 @@ public void authenticateWhenScopeRequestedThenAccessTokenContainsScope() { assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(requestedScope); } + @Test + public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2ClientCredentialsAuthenticationToken authentication = + new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null); + + doReturn(null).when(this.tokenGenerator).generate(any()); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + assertThat(error.getDescription()).contains("The token generator failed to generate the access token."); + }); + } + @Test public void authenticateWhenValidAuthenticationThenReturnAccessToken() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java index 514a7320d..9c0dc3add 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java @@ -47,9 +47,12 @@ import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; @@ -58,11 +61,12 @@ import org.springframework.security.oauth2.server.authorization.context.ProviderContext; import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.Assertions.entry; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; @@ -81,6 +85,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; private OAuth2TokenCustomizer jwtCustomizer; + private OAuth2TokenGenerator tokenGenerator; private OAuth2RefreshTokenAuthenticationProvider authenticationProvider; @Before @@ -88,10 +93,17 @@ public void setUp() { this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(Collections.singleton("scope1"))); - this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider( - this.authorizationService, this.jwtEncoder); this.jwtCustomizer = mock(OAuth2TokenCustomizer.class); - this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer); + JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder); + jwtGenerator.setJwtCustomizer(this.jwtCustomizer); + this.tokenGenerator = spy(new OAuth2TokenGenerator() { + @Override + public Jwt generate(OAuth2TokenContext context) { + return jwtGenerator.generate(context); + } + }); + this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider( + this.authorizationService, this.tokenGenerator); ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build(); ProviderContextHolder.setProviderContext(new ProviderContext(providerSettings, null)); } @@ -111,12 +123,19 @@ public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentExcep @Test public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, null)) + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, (JwtEncoder) null)) .isInstanceOf(IllegalArgumentException.class) .extracting(Throwable::getMessage) .isEqualTo("jwtEncoder cannot be null"); } + @Test + public void constructorWhenTokenGeneratorNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, (OAuth2TokenGenerator) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("tokenGenerator cannot be null"); + } + @Test public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null)) @@ -500,6 +519,70 @@ public void authenticateWhenRevokedRefreshTokenThenThrowOAuth2AuthenticationExce .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); } + @Test + public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getRefreshToken().getToken().getTokenValue()), + eq(OAuth2TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); + + doAnswer(answer -> { + OAuth2TokenContext context = answer.getArgument(0); + if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) { + return null; + } else { + return answer.callRealMethod(); + } + }).when(this.tokenGenerator).generate(any()); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + assertThat(error.getDescription()).contains("The token generator failed to generate the access token."); + }); + } + + @Test + public void authenticateWhenIdTokenNotGeneratedThenThrowOAuth2AuthenticationException() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getRefreshToken().getToken().getTokenValue()), + eq(OAuth2TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken( + registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); + + doAnswer(answer -> { + OAuth2TokenContext context = answer.getArgument(0); + if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) { + return null; + } else { + return answer.callRealMethod(); + } + }).when(this.tokenGenerator).generate(any()); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + assertThat(error.getDescription()).contains("The token generator failed to generate the ID token."); + }); + } + private static Jwt createJwt(Set scope) { Instant issuedAt = Instant.now(); Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java index aab87185e..1574f4cd7 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java @@ -48,8 +48,11 @@ import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.TestJoseHeaders; import org.springframework.security.oauth2.jwt.TestJwtClaimsSets; +import org.springframework.security.oauth2.server.authorization.JwtGenerator; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; @@ -66,8 +69,10 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -82,6 +87,7 @@ public class OidcClientRegistrationAuthenticationProviderTests { private RegisteredClientRepository registeredClientRepository; private OAuth2AuthorizationService authorizationService; private JwtEncoder jwtEncoder; + private OAuth2TokenGenerator tokenGenerator; private ProviderSettings providerSettings; private OidcClientRegistrationAuthenticationProvider authenticationProvider; @@ -90,10 +96,17 @@ public void setUp() { this.registeredClientRepository = mock(RegisteredClientRepository.class); this.authorizationService = mock(OAuth2AuthorizationService.class); this.jwtEncoder = mock(JwtEncoder.class); + JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder); + this.tokenGenerator = spy(new OAuth2TokenGenerator() { + @Override + public Jwt generate(OAuth2TokenContext context) { + return jwtGenerator.generate(context); + } + }); this.providerSettings = ProviderSettings.builder().issuer("https://provider.com").build(); ProviderContextHolder.setProviderContext(new ProviderContext(this.providerSettings, null)); this.authenticationProvider = new OidcClientRegistrationAuthenticationProvider( - this.registeredClientRepository, this.authorizationService, this.jwtEncoder); + this.registeredClientRepository, this.authorizationService, this.tokenGenerator); } @After @@ -118,10 +131,17 @@ public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentExcep @Test public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException() - .isThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(this.registeredClientRepository, this.authorizationService, null)) + .isThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(this.registeredClientRepository, this.authorizationService, (JwtEncoder) null)) .withMessage("jwtEncoder cannot be null"); } + @Test + public void constructorWhenTokenGeneratorNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(this.registeredClientRepository, this.authorizationService, (OAuth2TokenGenerator) null)) + .withMessage("tokenGenerator cannot be null"); + } + @Test public void supportsWhenTypeOidcClientRegistrationAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).isTrue(); @@ -464,6 +484,46 @@ public void authenticateWhenClientRegistrationRequestAndTokenEndpointAuthenticat .isEqualTo(SignatureAlgorithm.RS256.getName()); } + @Test + public void authenticateWhenRegistrationAccessTokenNotGeneratedThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwtClientRegistration(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), + jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + registeredClient, jwtAccessToken, jwt.getClaims()).build(); + when(this.authorizationService.findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + doReturn(null).when(this.tokenGenerator).generate(any()); + + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create")); + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR); + assertThat(error.getDescription()).contains("The token generator failed to generate the registration access token."); + }); + } + @Test public void authenticateWhenClientRegistrationRequestAndValidAccessTokenThenReturnClientRegistration() { Jwt jwt = createJwtClientRegistration();