From 1591bd158a10bd8f56dd597359bcc397b1c54999 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 30 Jan 2023 10:47:36 -0500 Subject: [PATCH 1/5] Add OpenID Connect 1.0 Logout Endpoint Closes gh-266 --- .../java/sample/jpa/entity/client/Client.java | 12 +- .../AuthorizationRepository.java | 8 +- .../JpaOAuth2AuthorizationService.java | 7 +- .../client/JpaRegisteredClientRepository.java | 6 +- .../gettingStarted/SecurityConfigTests.java | 5 +- .../src/test/java/sample/jpa/JpaTests.java | 5 +- .../java/sample/util/RegisteredClients.java | 3 +- docs/src/docs/asciidoc/guides/how-to-jpa.adoc | 1 + .../InMemoryOAuth2AuthorizationService.java | 13 +- .../JdbcOAuth2AuthorizationService.java | 10 +- .../JdbcRegisteredClientRepository.java | 10 +- .../client/RegisteredClient.java | 67 +++++- .../AuthorizationServerContextFilter.java | 23 +- ...OAuth2AuthorizationEndpointConfigurer.java | 11 +- .../OAuth2AuthorizationServerConfigurer.java | 30 ++- .../configurers/OAuth2ConfigurerUtils.java | 30 ++- .../web/configurers/OidcConfigurer.java | 16 +- .../OidcLogoutEndpointConfigurer.java | 218 +++++++++++++++++ .../context/AuthorizationServerContext.java | 15 +- .../oidc/OidcClientMetadataClaimAccessor.java | 15 +- .../oidc/OidcClientMetadataClaimNames.java | 11 +- .../oidc/OidcClientRegistration.java | 38 ++- .../oidc/OidcProviderConfiguration.java | 17 +- .../OidcProviderMetadataClaimAccessor.java | 14 +- .../oidc/OidcProviderMetadataClaimNames.java | 9 +- ...entRegistrationAuthenticationProvider.java | 11 +- .../OidcLogoutAuthenticationProvider.java | 163 +++++++++++++ .../OidcLogoutAuthenticationToken.java | 170 ++++++++++++++ ...ClientOidcClientRegistrationConverter.java | 7 +- ...lientRegistrationHttpMessageConverter.java | 3 +- .../oidc/web/OidcLogoutEndpointFilter.java | 221 ++++++++++++++++++ ...dcProviderConfigurationEndpointFilter.java | 3 +- .../OidcLogoutAuthenticationConverter.java | 111 +++++++++ .../settings/AuthorizationServerSettings.java | 28 ++- .../settings/ConfigurationSettingNames.java | 8 +- .../authorization/token/JwtGenerator.java | 31 ++- .../OAuth2AuthorizationEndpointFilter.java | 19 ++ .../oauth2-registered-client-schema.sql | 1 + ...MemoryOAuth2AuthorizationServiceTests.java | 28 ++- .../JdbcOAuth2AuthorizationServiceTests.java | 34 ++- .../JdbcRegisteredClientRepositoryTests.java | 7 +- .../client/RegisteredClientTests.java | 43 +++- .../client/TestRegisteredClients.java | 2 + .../OidcProviderConfigurationTests.java | 3 +- .../TestAuthorizationServerContext.java | 14 +- .../oidc/OidcClientRegistrationTests.java | 41 +++- .../oidc/OidcProviderConfigurationTests.java | 18 +- ...gistrationAuthenticationProviderTests.java | 78 ++++++- ...RegistrationHttpMessageConverterTests.java | 8 +- ...viderConfigurationEndpointFilterTests.java | 5 +- .../AuthorizationServerSettingsTests.java | 16 +- .../token/JwtGeneratorTests.java | 31 ++- ...Auth2AuthorizationEndpointFilterTests.java | 33 +++ ...custom-oauth2-registered-client-schema.sql | 1 + .../config/AuthorizationServerConfig.java | 3 +- .../sample/config/DefaultSecurityConfig.java | 15 +- .../config/AuthorizationServerConfig.java | 3 +- .../sample/config/DefaultSecurityConfig.java | 15 +- .../config/AuthorizationServerConfig.java | 3 +- .../sample/config/DefaultSecurityConfig.java | 15 +- .../java/sample/config/SecurityConfig.java | 24 +- .../src/main/resources/templates/index.html | 11 +- 62 files changed, 1740 insertions(+), 81 deletions(-) create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcLogoutEndpointConfigurer.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationToken.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java create mode 100644 oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java diff --git a/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/entity/client/Client.java b/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/entity/client/Client.java index b4cfb529a..d8885c876 100644 --- a/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/entity/client/Client.java +++ b/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/entity/client/Client.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,6 +39,8 @@ public class Client { @Column(length = 1000) private String redirectUris; @Column(length = 1000) + private String postLogoutRedirectUris; + @Column(length = 1000) private String scopes; @Column(length = 2000) private String clientSettings; @@ -118,6 +120,14 @@ public void setRedirectUris(String redirectUris) { this.redirectUris = redirectUris; } + public String getPostLogoutRedirectUris() { + return this.postLogoutRedirectUris; + } + + public void setPostLogoutRedirectUris(String postLogoutRedirectUris) { + this.postLogoutRedirectUris = postLogoutRedirectUris; + } + public String getScopes() { return scopes; } diff --git a/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/repository/authorization/AuthorizationRepository.java b/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/repository/authorization/AuthorizationRepository.java index d7dbd7a33..5f5429856 100644 --- a/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/repository/authorization/AuthorizationRepository.java +++ b/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/repository/authorization/AuthorizationRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2022-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,10 +30,12 @@ public interface AuthorizationRepository extends JpaRepository findByAuthorizationCodeValue(String authorizationCode); Optional findByAccessTokenValue(String accessToken); Optional findByRefreshTokenValue(String refreshToken); + Optional findByOidcIdTokenValue(String idToken); @Query("select a from Authorization a where a.state = :token" + " or a.authorizationCodeValue = :token" + " or a.accessTokenValue = :token" + - " or a.refreshTokenValue = :token" + " or a.refreshTokenValue = :token" + + " or a.oidcIdTokenValue = :token" ) - Optional findByStateOrAuthorizationCodeValueOrAccessTokenValueOrRefreshTokenValue(@Param("token") String token); + Optional findByStateOrAuthorizationCodeValueOrAccessTokenValueOrRefreshTokenValueOrOidcIdTokenValue(@Param("token") String token); } diff --git a/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/service/authorization/JpaOAuth2AuthorizationService.java b/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/service/authorization/JpaOAuth2AuthorizationService.java index 14da3c4df..a4428a87a 100644 --- a/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/service/authorization/JpaOAuth2AuthorizationService.java +++ b/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/service/authorization/JpaOAuth2AuthorizationService.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2022-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,6 +35,7 @@ import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; @@ -88,7 +89,7 @@ public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) Optional result; if (tokenType == null) { - result = this.authorizationRepository.findByStateOrAuthorizationCodeValueOrAccessTokenValueOrRefreshTokenValue(token); + result = this.authorizationRepository.findByStateOrAuthorizationCodeValueOrAccessTokenValueOrRefreshTokenValueOrOidcIdTokenValue(token); } else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) { result = this.authorizationRepository.findByState(token); } else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) { @@ -97,6 +98,8 @@ public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) result = this.authorizationRepository.findByAccessTokenValue(token); } else if (OAuth2ParameterNames.REFRESH_TOKEN.equals(tokenType.getValue())) { result = this.authorizationRepository.findByRefreshTokenValue(token); + } else if (OidcParameterNames.ID_TOKEN.equals(tokenType.getValue())) { + result = this.authorizationRepository.findByOidcIdTokenValue(token); } else { result = Optional.empty(); } diff --git a/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/service/client/JpaRegisteredClientRepository.java b/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/service/client/JpaRegisteredClientRepository.java index 1db232aaa..20f14e915 100644 --- a/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/service/client/JpaRegisteredClientRepository.java +++ b/docs/src/docs/asciidoc/examples/src/main/java/sample/jpa/service/client/JpaRegisteredClientRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2022 the original author or authors. + * Copyright 2022-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,6 +78,8 @@ private RegisteredClient toObject(Client client) { client.getAuthorizationGrantTypes()); Set redirectUris = StringUtils.commaDelimitedListToSet( client.getRedirectUris()); + Set postLogoutRedirectUris = StringUtils.commaDelimitedListToSet( + client.getPostLogoutRedirectUris()); Set clientScopes = StringUtils.commaDelimitedListToSet( client.getScopes()); @@ -94,6 +96,7 @@ private RegisteredClient toObject(Client client) { authorizationGrantTypes.forEach(grantType -> grantTypes.add(resolveAuthorizationGrantType(grantType)))) .redirectUris((uris) -> uris.addAll(redirectUris)) + .postLogoutRedirectUris((uris) -> uris.addAll(postLogoutRedirectUris)) .scopes((scopes) -> scopes.addAll(clientScopes)); Map clientSettingsMap = parseMap(client.getClientSettings()); @@ -124,6 +127,7 @@ private Client toEntity(RegisteredClient registeredClient) { entity.setClientAuthenticationMethods(StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethods)); entity.setAuthorizationGrantTypes(StringUtils.collectionToCommaDelimitedString(authorizationGrantTypes)); entity.setRedirectUris(StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())); + entity.setPostLogoutRedirectUris(StringUtils.collectionToCommaDelimitedString(registeredClient.getPostLogoutRedirectUris())); entity.setScopes(StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())); entity.setClientSettings(writeMap(registeredClient.getClientSettings().getSettings())); entity.setTokenSettings(writeMap(registeredClient.getTokenSettings().getSettings())); diff --git a/docs/src/docs/asciidoc/examples/src/test/java/sample/gettingStarted/SecurityConfigTests.java b/docs/src/docs/asciidoc/examples/src/test/java/sample/gettingStarted/SecurityConfigTests.java index 22993e032..ed04a0438 100644 --- a/docs/src/docs/asciidoc/examples/src/test/java/sample/gettingStarted/SecurityConfigTests.java +++ b/docs/src/docs/asciidoc/examples/src/test/java/sample/gettingStarted/SecurityConfigTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -102,7 +102,8 @@ public void oidcLoginWhenGettingStartedConfigUsedThenSuccess() throws Exception assertThatAuthorization(refreshToken, null).isNotNull(); String idToken = (String) tokenResponse.get(OidcParameterNames.ID_TOKEN); - assertThatAuthorization(idToken, OidcParameterNames.ID_TOKEN).isNull(); // id_token is not searchable + assertThatAuthorization(idToken, OidcParameterNames.ID_TOKEN).isNotNull(); + assertThatAuthorization(idToken, null).isNotNull(); OAuth2Authorization authorization = findAuthorization(accessToken, OAuth2ParameterNames.ACCESS_TOKEN); assertThat(authorization.getToken(idToken)).isNotNull(); diff --git a/docs/src/docs/asciidoc/examples/src/test/java/sample/jpa/JpaTests.java b/docs/src/docs/asciidoc/examples/src/test/java/sample/jpa/JpaTests.java index 8e95fbb6d..a11b78df2 100644 --- a/docs/src/docs/asciidoc/examples/src/test/java/sample/jpa/JpaTests.java +++ b/docs/src/docs/asciidoc/examples/src/test/java/sample/jpa/JpaTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -117,7 +117,8 @@ public void oidcLoginWhenJpaCoreServicesAutowiredThenUsed() throws Exception { assertThatAuthorization(refreshToken, null).isNotNull(); String idToken = (String) tokenResponse.get(OidcParameterNames.ID_TOKEN); - assertThatAuthorization(idToken, OidcParameterNames.ID_TOKEN).isNull(); // id_token is not searchable + assertThatAuthorization(idToken, OidcParameterNames.ID_TOKEN).isNotNull(); + assertThatAuthorization(idToken, null).isNotNull(); OAuth2Authorization authorization = findAuthorization(accessToken, OAuth2ParameterNames.ACCESS_TOKEN); assertThat(authorization.getToken(idToken)).isNotNull(); diff --git a/docs/src/docs/asciidoc/examples/src/test/java/sample/util/RegisteredClients.java b/docs/src/docs/asciidoc/examples/src/test/java/sample/util/RegisteredClients.java index c4af4b722..28d96b854 100644 --- a/docs/src/docs/asciidoc/examples/src/test/java/sample/util/RegisteredClients.java +++ b/docs/src/docs/asciidoc/examples/src/test/java/sample/util/RegisteredClients.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ public static RegisteredClient messagingClient() { .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .redirectUri("http://127.0.0.1:8080/authorized") + .postLogoutRedirectUri("http://127.0.0.1:8080/index") .scope(OidcScopes.OPENID) .scope("message.read") .scope("message.write") diff --git a/docs/src/docs/asciidoc/guides/how-to-jpa.adoc b/docs/src/docs/asciidoc/guides/how-to-jpa.adoc index 89babeede..69b7376de 100644 --- a/docs/src/docs/asciidoc/guides/how-to-jpa.adoc +++ b/docs/src/docs/asciidoc/guides/how-to-jpa.adoc @@ -45,6 +45,7 @@ CREATE TABLE client ( clientAuthenticationMethods varchar(1000) NOT NULL, authorizationGrantTypes varchar(1000) NOT NULL, redirectUris varchar(1000) DEFAULT NULL, + postLogoutRedirectUris varchar(1000) DEFAULT NULL, scopes varchar(1000) NOT NULL, clientSettings varchar(2000) NOT NULL, tokenSettings varchar(2000) NOT NULL, diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java index 1cf2b91ae..2d2ca5fb1 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,8 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.util.Assert; /** @@ -150,6 +152,7 @@ private static boolean hasToken(OAuth2Authorization authorization, String token, return matchesState(authorization, token) || matchesAuthorizationCode(authorization, token) || matchesAccessToken(authorization, token) || + matchesIdToken(authorization, token) || matchesRefreshToken(authorization, token); } else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) { return matchesState(authorization, token); @@ -157,6 +160,8 @@ private static boolean hasToken(OAuth2Authorization authorization, String token, return matchesAuthorizationCode(authorization, token); } else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) { return matchesAccessToken(authorization, token); + } else if (OidcParameterNames.ID_TOKEN.equals(tokenType.getValue())) { + return matchesIdToken(authorization, token); } else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) { return matchesRefreshToken(authorization, token); } @@ -185,6 +190,12 @@ private static boolean matchesRefreshToken(OAuth2Authorization authorization, St return refreshToken != null && refreshToken.getToken().getTokenValue().equals(token); } + private static boolean matchesIdToken(OAuth2Authorization authorization, String token) { + OAuth2Authorization.Token idToken = + authorization.getToken(OidcIdToken.class); + return idToken != null && idToken.getToken().getTokenValue().equals(token); + } + private static final class MaxSizeHashMap extends LinkedHashMap { private final int maxSize; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java index e053a99c5..bb9276715 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,6 +53,7 @@ import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.jackson2.OAuth2AuthorizationServerJackson2Module; @@ -112,11 +113,12 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic private static final String PK_FILTER = "id = ?"; private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorization_code_value = ? OR " + - "access_token_value = ? OR refresh_token_value = ?"; + "access_token_value = ? OR oidc_id_token_value = ? OR refresh_token_value = ?"; private static final String STATE_FILTER = "state = ?"; private static final String AUTHORIZATION_CODE_FILTER = "authorization_code_value = ?"; private static final String ACCESS_TOKEN_FILTER = "access_token_value = ?"; + private static final String ID_TOKEN_FILTER = "oidc_id_token_value = ?"; private static final String REFRESH_TOKEN_FILTER = "refresh_token_value = ?"; // @formatter:off @@ -240,6 +242,7 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t parameters.add(new SqlParameterValue(Types.VARCHAR, token)); parameters.add(mapToSqlParameter("authorization_code_value", token)); parameters.add(mapToSqlParameter("access_token_value", token)); + parameters.add(mapToSqlParameter("oidc_id_token_value", token)); parameters.add(mapToSqlParameter("refresh_token_value", token)); return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters); } else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) { @@ -251,6 +254,9 @@ public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType t } else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) { parameters.add(mapToSqlParameter("access_token_value", token)); return findBy(ACCESS_TOKEN_FILTER, parameters); + } else if (OidcParameterNames.ID_TOKEN.equals(tokenType.getValue())) { + parameters.add(mapToSqlParameter("oidc_id_token_value", token)); + return findBy(ID_TOKEN_FILTER, parameters); } else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) { parameters.add(mapToSqlParameter("refresh_token_value", token)); return findBy(REFRESH_TOKEN_FILTER, parameters); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java index 32e2442b0..3da2d3703 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -77,6 +77,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor + "client_authentication_methods, " + "authorization_grant_types, " + "redirect_uris, " + + "post_logout_redirect_uris, " + "scopes, " + "client_settings," + "token_settings"; @@ -90,13 +91,13 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor // @formatter:off private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME - + "(" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + + "(" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; // @formatter:on // @formatter:off private static final String UPDATE_REGISTERED_CLIENT_SQL = "UPDATE " + TABLE_NAME + " SET client_name = ?, client_authentication_methods = ?, authorization_grant_types = ?," - + " redirect_uris = ?, scopes = ?, client_settings = ?, token_settings = ?" + + " redirect_uris = ?, post_logout_redirect_uris = ?, scopes = ?, client_settings = ?, token_settings = ?" + " WHERE " + PK_FILTER; // @formatter:on @@ -241,6 +242,7 @@ public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException { Set clientAuthenticationMethods = StringUtils.commaDelimitedListToSet(rs.getString("client_authentication_methods")); Set authorizationGrantTypes = StringUtils.commaDelimitedListToSet(rs.getString("authorization_grant_types")); Set redirectUris = StringUtils.commaDelimitedListToSet(rs.getString("redirect_uris")); + Set postLogoutRedirectUris = StringUtils.commaDelimitedListToSet(rs.getString("post_logout_redirect_uris")); Set clientScopes = StringUtils.commaDelimitedListToSet(rs.getString("scopes")); // @formatter:off @@ -257,6 +259,7 @@ public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException { authorizationGrantTypes.forEach(grantType -> grantTypes.add(resolveAuthorizationGrantType(grantType)))) .redirectUris((uris) -> uris.addAll(redirectUris)) + .postLogoutRedirectUris((uris) -> uris.addAll(postLogoutRedirectUris)) .scopes((scopes) -> scopes.addAll(clientScopes)); // @formatter:on @@ -354,6 +357,7 @@ public List apply(RegisteredClient registeredClient) { new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethods)), new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(authorizationGrantTypes)), new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())), + new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getPostLogoutRedirectUris())), new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())), new SqlParameterValue(Types.VARCHAR, writeMap(registeredClient.getClientSettings().getSettings())), new SqlParameterValue(Types.VARCHAR, writeMap(registeredClient.getTokenSettings().getSettings()))); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java index 42df4c3cb..3b7bcd09c 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,6 +54,7 @@ public class RegisteredClient implements Serializable { private Set clientAuthenticationMethods; private Set authorizationGrantTypes; private Set redirectUris; + private Set postLogoutRedirectUris; private Set scopes; private ClientSettings clientSettings; private TokenSettings tokenSettings; @@ -145,6 +146,18 @@ public Set getRedirectUris() { return this.redirectUris; } + /** + * Returns the post logout redirect URI(s) that the client may use for logout. + * The {@code post_logout_redirect_uri} parameter is used by the client when requesting + * that the End-User's User Agent be redirected to after a logout has been performed. + * + * @return the {@code Set} of post logout redirect URI(s) + * @since 1.1.0 + */ + public Set getPostLogoutRedirectUris() { + return this.postLogoutRedirectUris; + } + /** * Returns the scope(s) that the client may use. * @@ -190,6 +203,7 @@ public boolean equals(Object obj) { Objects.equals(this.clientAuthenticationMethods, that.clientAuthenticationMethods) && Objects.equals(this.authorizationGrantTypes, that.authorizationGrantTypes) && Objects.equals(this.redirectUris, that.redirectUris) && + Objects.equals(this.postLogoutRedirectUris, that.postLogoutRedirectUris) && Objects.equals(this.scopes, that.scopes) && Objects.equals(this.clientSettings, that.clientSettings) && Objects.equals(this.tokenSettings, that.tokenSettings); @@ -199,7 +213,7 @@ public boolean equals(Object obj) { public int hashCode() { return Objects.hash(this.id, this.clientId, this.clientIdIssuedAt, this.clientSecret, this.clientSecretExpiresAt, this.clientName, this.clientAuthenticationMethods, this.authorizationGrantTypes, this.redirectUris, - this.scopes, this.clientSettings, this.tokenSettings); + this.postLogoutRedirectUris, this.scopes, this.clientSettings, this.tokenSettings); } @Override @@ -211,6 +225,7 @@ public String toString() { ", clientAuthenticationMethods=" + this.clientAuthenticationMethods + ", authorizationGrantTypes=" + this.authorizationGrantTypes + ", redirectUris=" + this.redirectUris + + ", postLogoutRedirectUris=" + this.postLogoutRedirectUris + ", scopes=" + this.scopes + ", clientSettings=" + this.clientSettings + ", tokenSettings=" + this.tokenSettings + @@ -253,6 +268,7 @@ public static class Builder implements Serializable { private final Set clientAuthenticationMethods = new HashSet<>(); private final Set authorizationGrantTypes = new HashSet<>(); private final Set redirectUris = new HashSet<>(); + private final Set postLogoutRedirectUris = new HashSet<>(); private final Set scopes = new HashSet<>(); private ClientSettings clientSettings; private TokenSettings tokenSettings; @@ -277,6 +293,9 @@ protected Builder(RegisteredClient registeredClient) { if (!CollectionUtils.isEmpty(registeredClient.getRedirectUris())) { this.redirectUris.addAll(registeredClient.getRedirectUris()); } + if (!CollectionUtils.isEmpty(registeredClient.getPostLogoutRedirectUris())) { + this.postLogoutRedirectUris.addAll(registeredClient.getPostLogoutRedirectUris()); + } if (!CollectionUtils.isEmpty(registeredClient.getScopes())) { this.scopes.addAll(registeredClient.getScopes()); } @@ -421,6 +440,33 @@ public Builder redirectUris(Consumer> redirectUrisConsumer) { return this; } + /** + * Adds a post logout redirect URI the client may use for logout. + * The {@code post_logout_redirect_uri} parameter is used by the client when requesting + * that the End-User's User Agent be redirected to after a logout has been performed. + * + * @param postLogoutRedirectUri the post logout redirect URI + * @return the {@link Builder} + * @since 1.1.0 + */ + public Builder postLogoutRedirectUri(String postLogoutRedirectUri) { + this.postLogoutRedirectUris.add(postLogoutRedirectUri); + return this; + } + + /** + * A {@code Consumer} of the post logout redirect URI(s) + * allowing the ability to add, replace, or remove. + * + * @param postLogoutRedirectUrisConsumer a {@link Consumer} of the post logout redirect URI(s) + * @return the {@link Builder} + * @since 1.1.0 + */ + public Builder postLogoutRedirectUris(Consumer> postLogoutRedirectUrisConsumer) { + postLogoutRedirectUrisConsumer.accept(this.postLogoutRedirectUris); + return this; + } + /** * Adds a scope the client may use. * @@ -499,6 +545,7 @@ public RegisteredClient build() { } validateScopes(); validateRedirectUris(); + validatePostLogoutRedirectUris(); return create(); } @@ -523,6 +570,8 @@ private RegisteredClient create() { new HashSet<>(this.authorizationGrantTypes)); registeredClient.redirectUris = Collections.unmodifiableSet( new HashSet<>(this.redirectUris)); + registeredClient.postLogoutRedirectUris = Collections.unmodifiableSet( + new HashSet<>(this.postLogoutRedirectUris)); registeredClient.scopes = Collections.unmodifiableSet( new HashSet<>(this.scopes)); registeredClient.clientSettings = this.clientSettings; @@ -557,12 +606,23 @@ private void validateRedirectUris() { return; } - for (String redirectUri : redirectUris) { + for (String redirectUri : this.redirectUris) { Assert.isTrue(validateRedirectUri(redirectUri), "redirect_uri \"" + redirectUri + "\" is not a valid redirect URI or contains fragment"); } } + private void validatePostLogoutRedirectUris() { + if (CollectionUtils.isEmpty(this.postLogoutRedirectUris)) { + return; + } + + for (String postLogoutRedirectUri : this.postLogoutRedirectUris) { + Assert.isTrue(validateRedirectUri(postLogoutRedirectUri), + "post_logout_redirect_uri \"" + postLogoutRedirectUri + "\" is not a valid post logout redirect URI or contains fragment"); + } + } + private static boolean validateRedirectUri(String redirectUri) { try { URI validRedirectUri = new URI(redirectUri); @@ -571,5 +631,6 @@ private static boolean validateRedirectUri(String redirectUri) { return false; } } + } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/AuthorizationServerContextFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/AuthorizationServerContextFilter.java index aa6cee69f..b08faee8d 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/AuthorizationServerContextFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/AuthorizationServerContextFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,8 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import org.springframework.lang.Nullable; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext; import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; @@ -42,21 +44,27 @@ */ final class AuthorizationServerContextFilter extends OncePerRequestFilter { private final AuthorizationServerSettings authorizationServerSettings; + private SessionRegistry sessionRegistry; AuthorizationServerContextFilter(AuthorizationServerSettings authorizationServerSettings) { Assert.notNull(authorizationServerSettings, "authorizationServerSettings cannot be null"); this.authorizationServerSettings = authorizationServerSettings; } + void setSessionRegistry(SessionRegistry sessionRegistry) { + this.sessionRegistry = sessionRegistry; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { try { - AuthorizationServerContext authorizationServerContext = + DefaultAuthorizationServerContext authorizationServerContext = new DefaultAuthorizationServerContext( () -> resolveIssuer(this.authorizationServerSettings, request), this.authorizationServerSettings); + authorizationServerContext.setSessionRegistry(this.sessionRegistry); AuthorizationServerContextHolder.setContext(authorizationServerContext); filterChain.doFilter(request, response); } finally { @@ -84,6 +92,7 @@ private static String getContextPath(HttpServletRequest request) { private static final class DefaultAuthorizationServerContext implements AuthorizationServerContext { private final Supplier issuerSupplier; private final AuthorizationServerSettings authorizationServerSettings; + private SessionRegistry sessionRegistry; private DefaultAuthorizationServerContext(Supplier issuerSupplier, AuthorizationServerSettings authorizationServerSettings) { this.issuerSupplier = issuerSupplier; @@ -100,6 +109,16 @@ public AuthorizationServerSettings getAuthorizationServerSettings() { return this.authorizationServerSettings; } + @Nullable + @Override + public SessionRegistry getSessionRegistry() { + return this.sessionRegistry; + } + + private void setSessionRegistry(SessionRegistry sessionRegistry) { + this.sessionRegistry = sessionRegistry; + } + } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationEndpointConfigurer.java index f2a59a998..e1ae33e90 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationEndpointConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,7 @@ import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter; +import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -68,6 +69,7 @@ public final class OAuth2AuthorizationEndpointConfigurer extends AbstractOAuth2C private AuthenticationFailureHandler errorResponseHandler; private String consentPage; private Consumer authorizationCodeRequestAuthenticationValidator; + private SessionAuthenticationStrategy sessionAuthenticationStrategy; /** * Restrict for internal use only. @@ -200,6 +202,10 @@ void addAuthorizationCodeRequestAuthenticationValidator( this.authorizationCodeRequestAuthenticationValidator.andThen(authenticationValidator); } + void setSessionAuthenticationStrategy(SessionAuthenticationStrategy sessionAuthenticationStrategy) { + this.sessionAuthenticationStrategy = sessionAuthenticationStrategy; + } + @Override void init(HttpSecurity httpSecurity) { AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); @@ -245,6 +251,9 @@ void configure(HttpSecurity httpSecurity) { if (StringUtils.hasText(this.consentPage)) { authorizationEndpointFilter.setConsentPage(this.consentPage); } + if (this.sessionAuthenticationStrategy != null) { + authorizationEndpointFilter.setSessionAuthenticationStrategy(this.sessionAuthenticationStrategy); + } httpSecurity.addFilterBefore(postProcess(authorizationEndpointFilter), AbstractPreAuthenticatedProcessingFilter.class); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationServerConfigurer.java index 1468a3513..5edb80b01 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationServerConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer; import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2Token; @@ -240,8 +242,23 @@ public void init(HttpSecurity httpSecurity) { AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); validateAuthorizationServerSettings(authorizationServerSettings); - OidcConfigurer oidcConfigurer = getConfigurer(OidcConfigurer.class); - if (oidcConfigurer == null) { + if (isOidcEnabled()) { + // Add OpenID Connect session tracking capabilities. + SessionRegistry sessionRegistry = OAuth2ConfigurerUtils.getSessionRegistry(httpSecurity); + OAuth2AuthorizationEndpointConfigurer authorizationEndpointConfigurer = + getConfigurer(OAuth2AuthorizationEndpointConfigurer.class); + authorizationEndpointConfigurer.setSessionAuthenticationStrategy((authentication, request, response) -> { + if (authentication instanceof OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication) { + if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)) { + if (sessionRegistry.getSessionInformation(request.getSession().getId()) == null) { + sessionRegistry.registerNewSession( + request.getSession().getId(), + ((Authentication) authorizationCodeRequestAuthentication.getPrincipal()).getPrincipal()); + } + } + } + }); + } else { // OpenID Connect is disabled. // Add an authentication validator that rejects authentication requests. OAuth2AuthorizationEndpointConfigurer authorizationEndpointConfigurer = @@ -287,6 +304,9 @@ public void configure(HttpSecurity httpSecurity) { AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); AuthorizationServerContextFilter authorizationServerContextFilter = new AuthorizationServerContextFilter(authorizationServerSettings); + if (isOidcEnabled()) { + authorizationServerContextFilter.setSessionRegistry(OAuth2ConfigurerUtils.getSessionRegistry(httpSecurity)); + } httpSecurity.addFilterAfter(postProcess(authorizationServerContextFilter), SecurityContextHolderFilter.class); JWKSource jwkSource = OAuth2ConfigurerUtils.getJwkSource(httpSecurity); @@ -297,6 +317,10 @@ public void configure(HttpSecurity httpSecurity) { } } + private boolean isOidcEnabled() { + return getConfigurer(OidcConfigurer.class) != null; + } + private Map, AbstractOAuth2Configurer> createConfigurers() { Map, AbstractOAuth2Configurer> configurers = new LinkedHashMap<>(); configurers.put(OAuth2ClientAuthenticationConfigurer.class, new OAuth2ClientAuthenticationConfigurer(this::postProcess)); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ConfigurerUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ConfigurerUtils.java index 7b4974c5d..50f0282e8 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ConfigurerUtils.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ConfigurerUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,14 @@ import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationListener; +import org.springframework.context.event.GenericApplicationListenerAdapter; +import org.springframework.context.event.SmartApplicationListener; import org.springframework.core.ResolvableType; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.context.DelegatingApplicationListener; +import org.springframework.security.core.session.SessionRegistry; +import org.springframework.security.core.session.SessionRegistryImpl; import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.NimbusJwtEncoder; @@ -180,6 +186,28 @@ static AuthorizationServerSettings getAuthorizationServerSettings(HttpSecurity h return authorizationServerSettings; } + static SessionRegistry getSessionRegistry(HttpSecurity httpSecurity) { + SessionRegistry sessionRegistry = httpSecurity.getSharedObject(SessionRegistry.class); + if (sessionRegistry == null) { + sessionRegistry = getOptionalBean(httpSecurity, SessionRegistry.class); + if (sessionRegistry == null) { + sessionRegistry = new SessionRegistryImpl(); + registerDelegateApplicationListener(httpSecurity, (SessionRegistryImpl) sessionRegistry); + } + httpSecurity.setSharedObject(SessionRegistry.class, sessionRegistry); + } + return sessionRegistry; + } + + private static void registerDelegateApplicationListener(HttpSecurity httpSecurity, ApplicationListener delegate) { + DelegatingApplicationListener delegatingApplicationListener = getOptionalBean(httpSecurity, DelegatingApplicationListener.class); + if (delegatingApplicationListener == null) { + return; + } + SmartApplicationListener smartListener = new GenericApplicationListenerAdapter(delegate); + delegatingApplicationListener.addListener(smartListener); + } + static T getBean(HttpSecurity httpSecurity, Class type) { return httpSecurity.getSharedObject(ApplicationContext.class).getBean(type); } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcConfigurer.java index 5ce9cbc38..f43bab9db 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ * @since 0.2.0 * @see OAuth2AuthorizationServerConfigurer#oidc * @see OidcProviderConfigurationEndpointConfigurer + * @see OidcLogoutEndpointConfigurer * @see OidcClientRegistrationEndpointConfigurer * @see OidcUserInfoEndpointConfigurer */ @@ -50,6 +51,7 @@ public final class OidcConfigurer extends AbstractOAuth2Configurer { OidcConfigurer(ObjectPostProcessor objectPostProcessor) { super(objectPostProcessor); addConfigurer(OidcProviderConfigurationEndpointConfigurer.class, new OidcProviderConfigurationEndpointConfigurer(objectPostProcessor)); + addConfigurer(OidcLogoutEndpointConfigurer.class, new OidcLogoutEndpointConfigurer(objectPostProcessor)); addConfigurer(OidcUserInfoEndpointConfigurer.class, new OidcUserInfoEndpointConfigurer(objectPostProcessor)); } @@ -65,6 +67,18 @@ public OidcConfigurer providerConfigurationEndpoint(Customizer logoutEndpointCustomizer) { + logoutEndpointCustomizer.customize(getConfigurer(OidcLogoutEndpointConfigurer.class)); + return this; + } + /** * Configures the OpenID Connect Dynamic Client Registration 1.0 Endpoint. * diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcLogoutEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcLogoutEndpointConfigurer.java new file mode 100644 index 000000000..04eb0a256 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcLogoutEndpointConfigurer.java @@ -0,0 +1,218 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; + +import org.springframework.http.HttpMethod; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.config.annotation.ObjectPostProcessor; +import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.oidc.web.OidcLogoutEndpointFilter; +import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcLogoutAuthenticationConverter; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.logout.LogoutFilter; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; + +/** + * Configurer for OpenID Connect 1.0 RP-Initiated Logout Endpoint. + * + * @author Joe Grandja + * @since 1.1.0 + * @see OidcConfigurer#logoutEndpoint + * @see OidcLogoutEndpointFilter + */ +public final class OidcLogoutEndpointConfigurer extends AbstractOAuth2Configurer { + private RequestMatcher requestMatcher; + private final List logoutRequestConverters = new ArrayList<>(); + private Consumer> logoutRequestConvertersConsumer = (logoutRequestConverters) -> {}; + private final List authenticationProviders = new ArrayList<>(); + private Consumer> authenticationProvidersConsumer = (authenticationProviders) -> {}; + private AuthenticationSuccessHandler logoutResponseHandler; + private AuthenticationFailureHandler errorResponseHandler; + + /** + * Restrict for internal use only. + */ + OidcLogoutEndpointConfigurer(ObjectPostProcessor objectPostProcessor) { + super(objectPostProcessor); + } + + /** + * Adds an {@link AuthenticationConverter} used when attempting to extract a Logout Request from {@link HttpServletRequest} + * to an instance of {@link OidcLogoutAuthenticationToken} used for authenticating the request. + * + * @param logoutRequestConverter an {@link AuthenticationConverter} used when attempting to extract a Logout Request from {@link HttpServletRequest} + * @return the {@link OidcLogoutEndpointConfigurer} for further configuration + */ + public OidcLogoutEndpointConfigurer logoutRequestConverter( + AuthenticationConverter logoutRequestConverter) { + Assert.notNull(logoutRequestConverter, "logoutRequestConverter cannot be null"); + this.logoutRequestConverters.add(logoutRequestConverter); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default + * and (optionally) added {@link #logoutRequestConverter(AuthenticationConverter) AuthenticationConverter}'s + * allowing the ability to add, remove, or customize a specific {@link AuthenticationConverter}. + * + * @param logoutRequestConvertersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationConverter}'s + * @return the {@link OidcLogoutEndpointConfigurer} for further configuration + */ + public OidcLogoutEndpointConfigurer logoutRequestConverters( + Consumer> logoutRequestConvertersConsumer) { + Assert.notNull(logoutRequestConvertersConsumer, "logoutRequestConvertersConsumer cannot be null"); + this.logoutRequestConvertersConsumer = logoutRequestConvertersConsumer; + return this; + } + + /** + * Adds an {@link AuthenticationProvider} used for authenticating an {@link OidcLogoutAuthenticationToken}. + * + * @param authenticationProvider an {@link AuthenticationProvider} used for authenticating an {@link OidcLogoutAuthenticationToken} + * @return the {@link OidcLogoutEndpointConfigurer} for further configuration + */ + public OidcLogoutEndpointConfigurer authenticationProvider(AuthenticationProvider authenticationProvider) { + Assert.notNull(authenticationProvider, "authenticationProvider cannot be null"); + this.authenticationProviders.add(authenticationProvider); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default + * and (optionally) added {@link #authenticationProvider(AuthenticationProvider) AuthenticationProvider}'s + * allowing the ability to add, remove, or customize a specific {@link AuthenticationProvider}. + * + * @param authenticationProvidersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationProvider}'s + * @return the {@link OidcLogoutEndpointConfigurer} for further configuration + */ + public OidcLogoutEndpointConfigurer authenticationProviders( + Consumer> authenticationProvidersConsumer) { + Assert.notNull(authenticationProvidersConsumer, "authenticationProvidersConsumer cannot be null"); + this.authenticationProvidersConsumer = authenticationProvidersConsumer; + return this; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcLogoutAuthenticationToken} + * and performing the logout. + * + * @param logoutResponseHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcLogoutAuthenticationToken} + * @return the {@link OidcLogoutEndpointConfigurer} for further configuration + */ + public OidcLogoutEndpointConfigurer logoutResponseHandler(AuthenticationSuccessHandler logoutResponseHandler) { + this.logoutResponseHandler = logoutResponseHandler; + return this; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * and returning the {@link OAuth2Error Error Response}. + * + * @param errorResponseHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * @return the {@link OidcLogoutEndpointConfigurer} for further configuration + */ + public OidcLogoutEndpointConfigurer errorResponseHandler(AuthenticationFailureHandler errorResponseHandler) { + this.errorResponseHandler = errorResponseHandler; + return this; + } + + @Override + void init(HttpSecurity httpSecurity) { + AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); + String logoutEndpointUri = authorizationServerSettings.getOidcLogoutEndpoint(); + this.requestMatcher = new OrRequestMatcher( + new AntPathRequestMatcher(logoutEndpointUri, HttpMethod.GET.name()), + new AntPathRequestMatcher(logoutEndpointUri, HttpMethod.POST.name()) + ); + + List authenticationProviders = createDefaultAuthenticationProviders(httpSecurity); + if (!this.authenticationProviders.isEmpty()) { + authenticationProviders.addAll(0, this.authenticationProviders); + } + this.authenticationProvidersConsumer.accept(authenticationProviders); + authenticationProviders.forEach(authenticationProvider -> + httpSecurity.authenticationProvider(postProcess(authenticationProvider))); + } + + @Override + void configure(HttpSecurity httpSecurity) { + AuthenticationManager authenticationManager = httpSecurity.getSharedObject(AuthenticationManager.class); + AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); + + OidcLogoutEndpointFilter oidcLogoutEndpointFilter = + new OidcLogoutEndpointFilter( + authenticationManager, + authorizationServerSettings.getOidcLogoutEndpoint()); + List authenticationConverters = createDefaultAuthenticationConverters(); + if (!this.logoutRequestConverters.isEmpty()) { + authenticationConverters.addAll(0, this.logoutRequestConverters); + } + this.logoutRequestConvertersConsumer.accept(authenticationConverters); + oidcLogoutEndpointFilter.setAuthenticationConverter( + new DelegatingAuthenticationConverter(authenticationConverters)); + if (this.logoutResponseHandler != null) { + oidcLogoutEndpointFilter.setAuthenticationSuccessHandler(this.logoutResponseHandler); + } + if (this.errorResponseHandler != null) { + oidcLogoutEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler); + } + httpSecurity.addFilterBefore(postProcess(oidcLogoutEndpointFilter), LogoutFilter.class); + } + + @Override + RequestMatcher getRequestMatcher() { + return this.requestMatcher; + } + + private static List createDefaultAuthenticationConverters() { + List authenticationConverters = new ArrayList<>(); + + authenticationConverters.add(new OidcLogoutAuthenticationConverter()); + + return authenticationConverters; + } + + private static List createDefaultAuthenticationProviders(HttpSecurity httpSecurity) { + List authenticationProviders = new ArrayList<>(); + + OidcLogoutAuthenticationProvider oidcLogoutAuthenticationProvider = + new OidcLogoutAuthenticationProvider( + OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), + OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); + authenticationProviders.add(oidcLogoutAuthenticationProvider); + + return authenticationProviders; + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/context/AuthorizationServerContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/context/AuthorizationServerContext.java index a12ef305a..2888b3e8a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/context/AuthorizationServerContext.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/context/AuthorizationServerContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,8 @@ */ package org.springframework.security.oauth2.server.authorization.context; +import org.springframework.lang.Nullable; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; /** @@ -41,4 +43,15 @@ public interface AuthorizationServerContext { */ AuthorizationServerSettings getAuthorizationServerSettings(); + /** + * Returns the {@link SessionRegistry} used to track OpenID Connect sessions or {@code null} if OpenID Connect is disabled. + * + * @return the {@link SessionRegistry} used to track OpenID Connect sessions or {@code null} if OpenID Connect is disabled + * @since 1.1.0 + */ + @Nullable + default SessionRegistry getSessionRegistry() { + return null; + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimAccessor.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimAccessor.java index b52e2f2d6..8d36b9c30 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimAccessor.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimAccessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,6 +37,7 @@ * @see OidcClientMetadataClaimNames * @see OidcClientRegistration * @see 2. Client Metadata + * @see 3.1. Client Registration Metadata */ public interface OidcClientMetadataClaimAccessor extends ClaimAccessor { @@ -94,6 +95,18 @@ default List getRedirectUris() { return getClaimAsStringList(OidcClientMetadataClaimNames.REDIRECT_URIS); } + /** + * Returns the post logout redirection {@code URI} values used by the Client {@code (post_logout_redirect_uris)}. + * The {@code post_logout_redirect_uri} parameter is used by the client when requesting + * that the End-User's User Agent be redirected to after a logout has been performed. + * + * @return the post logout redirection {@code URI} values used by the Client + * @since 1.1.0 + */ + default List getPostLogoutRedirectUris() { + return getClaimAsStringList(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS); + } + /** * Returns the authentication method used by the Client for the Token Endpoint {@code (token_endpoint_auth_method)}. * diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimNames.java index 06cecaf77..e8d755378 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientMetadataClaimNames.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ * @author Joe Grandja * @since 0.1.1 * @see 2. Client Metadata + * @see 3.1. Client Registration Metadata */ public final class OidcClientMetadataClaimNames { @@ -61,6 +62,14 @@ public final class OidcClientMetadataClaimNames { */ public static final String REDIRECT_URIS = "redirect_uris"; + /** + * {@code post_logout_redirect_uris} - the post logout redirection {@code URI} values used by the Client. + * The {@code post_logout_redirect_uri} parameter is used by the client when requesting + * that the End-User's User Agent be redirected to after a logout has been performed. + * @since 1.1.0 + */ + public static final String POST_LOGOUT_REDIRECT_URIS = "post_logout_redirect_uris"; + /** * {@code token_endpoint_auth_method} - the authentication method used by the Client for the Token Endpoint */ diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistration.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistration.java index 21c1b7216..036415ffe 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistration.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistration.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,8 +44,9 @@ * @author Joe Grandja * @since 0.1.1 * @see OidcClientMetadataClaimAccessor - * @see 3.1. Client Registration Request - * @see 3.2. Client Registration Response + * @see 3.1. Client Registration Request + * @see 3.2. Client Registration Response + * @see 3.1. Client Registration Metadata */ public final class OidcClientRegistration implements OidcClientMetadataClaimAccessor, Serializable { private static final long serialVersionUID = SpringAuthorizationServerVersion.SERIAL_VERSION_UID; @@ -168,6 +169,33 @@ public Builder redirectUris(Consumer> redirectUrisConsumer) { return this; } + /** + * Add the post logout redirection {@code URI} used by the Client, OPTIONAL. + * The {@code post_logout_redirect_uri} parameter is used by the client when requesting + * that the End-User's User Agent be redirected to after a logout has been performed. + * + * @param postLogoutRedirectUri the post logout redirection {@code URI} used by the Client + * @return the {@link Builder} for further configuration + * @since 1.1.0 + */ + public Builder postLogoutRedirectUri(String postLogoutRedirectUri) { + addClaimToClaimList(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS, postLogoutRedirectUri); + return this; + } + + /** + * A {@code Consumer} of the post logout redirection {@code URI} values used by the Client, + * allowing the ability to add, replace, or remove, OPTIONAL. + * + * @param postLogoutRedirectUrisConsumer a {@code Consumer} of the post logout redirection {@code URI} values used by the Client + * @return the {@link Builder} for further configuration + * @since 1.1.0 + */ + public Builder postLogoutRedirectUris(Consumer> postLogoutRedirectUrisConsumer) { + acceptClaimValues(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS, postLogoutRedirectUrisConsumer); + return this; + } + /** * Sets the authentication method used by the Client for the Token Endpoint, OPTIONAL. * @@ -358,6 +386,10 @@ private void validate() { Assert.notNull(this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris cannot be null"); Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris must be of type List"); Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris cannot be empty"); + if (this.claims.get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS) != null) { + Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS), "post_logout_redirect_uris must be of type List"); + Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS), "post_logout_redirect_uris cannot be empty"); + } if (this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES) != null) { Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES), "grant_types must be of type List"); Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES), "grant_types cannot be empty"); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderConfiguration.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderConfiguration.java index 96549c3bf..e2b8567d7 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderConfiguration.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ * The claims are defined by the OpenID Connect Discovery 1.0 specification. * * @author Daniel Garnier-Moiroux + * @author Joe Grandja * @since 0.1.0 * @see AbstractOAuth2AuthorizationServerMetadata * @see OidcProviderMetadataClaimAccessor @@ -130,6 +131,17 @@ public Builder userInfoEndpoint(String userInfoEndpoint) { return claim(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT, userInfoEndpoint); } + /** + * Use this {@code end_session_endpoint} in the resulting {@link OidcProviderConfiguration}, OPTIONAL. + * + * @param endSessionEndpoint the {@code URL} of the OpenID Connect 1.0 End Session Endpoint + * @return the {@link Builder} for further configuration + * @since 1.1.0 + */ + public Builder endSessionEndpoint(String endSessionEndpoint) { + return claim(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT, endSessionEndpoint); + } + /** * Validate the claims and build the {@link OidcProviderConfiguration}. *

@@ -159,6 +171,9 @@ protected void validate() { if (getClaims().get(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT) != null) { validateURL(getClaims().get(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT), "userInfoEndpoint must be a valid URL"); } + if (getClaims().get(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT) != null) { + validateURL(getClaims().get(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT), "endSessionEndpoint must be a valid URL"); + } } @SuppressWarnings("unchecked") diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderMetadataClaimAccessor.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderMetadataClaimAccessor.java index 277ade2e7..d67d1a040 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderMetadataClaimAccessor.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderMetadataClaimAccessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ */ package org.springframework.security.oauth2.server.authorization.oidc; - import java.net.URL; import java.util.List; @@ -30,6 +29,7 @@ * in the OpenID Provider Configuration Response. * * @author Daniel Garnier-Moiroux + * @author Joe Grandja * @since 0.1.0 * @see ClaimAccessor * @see OAuth2AuthorizationServerMetadataClaimAccessor @@ -68,4 +68,14 @@ default URL getUserInfoEndpoint() { return getClaimAsURL(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT); } + /** + * Returns the {@code URL} of the OpenID Connect 1.0 End Session Endpoint {@code (end_session_endpoint)}. + * + * @return the {@code URL} of the OpenID Connect 1.0 End Session Endpoint + * @since 1.1.0 + */ + default URL getEndSessionEndpoint() { + return getClaimAsURL(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT); + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderMetadataClaimNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderMetadataClaimNames.java index ed8b64265..b4bd44717 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderMetadataClaimNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderMetadataClaimNames.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ * in the OpenID Provider Configuration Response. * * @author Daniel Garnier-Moiroux + * @author Joe Grandja * @since 0.1.0 * @see OAuth2AuthorizationServerMetadataClaimNames * @see 3. OpenID Provider Metadata @@ -46,6 +47,12 @@ public final class OidcProviderMetadataClaimNames extends OAuth2AuthorizationSer */ public static final String USER_INFO_ENDPOINT = "userinfo_endpoint"; + /** + * {@code end_session_endpoint} - the {@code URL} of the OpenID Connect 1.0 End Session Endpoint + * @since 1.1.0 + */ + public static final String END_SESSION_ENDPOINT = "end_session_endpoint"; + private OidcProviderMetadataClaimNames() { } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java index 97166d0d9..27a1cce23 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -174,6 +174,10 @@ private OidcClientRegistrationAuthenticationToken registerClient(OidcClientRegis throwInvalidClientRegistration(OAuth2ErrorCodes.INVALID_REDIRECT_URI, OidcClientMetadataClaimNames.REDIRECT_URIS); } + if (!isValidRedirectUris(clientRegistrationAuthentication.getClientRegistration().getPostLogoutRedirectUris())) { + throwInvalidClientRegistration("invalid_client_metadata", OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS); + } + if (!isValidTokenEndpointAuthenticationMethod(clientRegistrationAuthentication.getClientRegistration())) { throwInvalidClientRegistration("invalid_client_metadata", OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD); } @@ -371,6 +375,11 @@ public RegisteredClient convert(OidcClientRegistration clientRegistration) { builder.redirectUris(redirectUris -> redirectUris.addAll(clientRegistration.getRedirectUris())); + if (!CollectionUtils.isEmpty(clientRegistration.getPostLogoutRedirectUris())) { + builder.postLogoutRedirectUris(postLogoutRedirectUris -> + postLogoutRedirectUris.addAll(clientRegistration.getPostLogoutRedirectUris())); + } + if (!CollectionUtils.isEmpty(clientRegistration.getGrantTypes())) { builder.authorizationGrantTypes(authorizationGrantTypes -> clientRegistration.getGrantTypes().forEach(grantType -> diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java new file mode 100644 index 000000000..ce0259ed7 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java @@ -0,0 +1,163 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.core.session.SessionRegistry; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * An {@link AuthenticationProvider} implementation for OpenID Connect 1.0 RP-Initiated Logout Endpoint. + * + * @author Joe Grandja + * @since 1.1.0 + * @see RegisteredClientRepository + * @see OAuth2AuthorizationService + * @see 2. RP-Initiated Logout + */ +public final class OidcLogoutAuthenticationProvider implements AuthenticationProvider { + private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = + new OAuth2TokenType(OidcParameterNames.ID_TOKEN); + private final Log logger = LogFactory.getLog(getClass()); + private final RegisteredClientRepository registeredClientRepository; + private final OAuth2AuthorizationService authorizationService; + + /** + * Constructs an {@code OidcLogoutAuthenticationProvider} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + */ + public OidcLogoutAuthenticationProvider(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.authorizationService = authorizationService; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OidcLogoutAuthenticationToken oidcLogoutAuthentication = + (OidcLogoutAuthenticationToken) authentication; + + OAuth2Authorization authorization = this.authorizationService.findByToken( + oidcLogoutAuthentication.getIdToken(), ID_TOKEN_TOKEN_TYPE); + if (authorization == null) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + + RegisteredClient registeredClient = this.registeredClientRepository.findById( + authorization.getRegisteredClientId()); + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Retrieved authorization with ID Token"); + } + + OidcIdToken idToken = authorization.getToken(OidcIdToken.class).getToken(); + + // Validate client identity + List audClaim = idToken.getAudience(); + if (CollectionUtils.isEmpty(audClaim) || + !audClaim.contains(registeredClient.getClientId())) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + if (StringUtils.hasText(oidcLogoutAuthentication.getClientId()) && + !oidcLogoutAuthentication.getClientId().equals(registeredClient.getClientId())) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + if (StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri()) && + !registeredClient.getPostLogoutRedirectUris().contains(oidcLogoutAuthentication.getPostLogoutRedirectUri())) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + } + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Validated logout request parameters"); + } + + // Validate user session + SessionInformation sessionInformation = null; + Authentication userPrincipal = (Authentication) oidcLogoutAuthentication.getPrincipal(); + if (isPrincipalAuthenticated(userPrincipal) && + StringUtils.hasText(oidcLogoutAuthentication.getSessionId())) { + sessionInformation = findSessionInformation( + userPrincipal, oidcLogoutAuthentication.getSessionId()); + if (sessionInformation != null) { + String sidClaim = idToken.getClaim("sid"); + if (!StringUtils.hasText(sidClaim) || + !sidClaim.equals(sessionInformation.getSessionId())) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + } + } + } + + if (this.logger.isTraceEnabled()) { + this.logger.trace("Authenticated logout request"); + } + + return new OidcLogoutAuthenticationToken(oidcLogoutAuthentication.getIdToken(), userPrincipal, + sessionInformation, oidcLogoutAuthentication.getClientId(), + oidcLogoutAuthentication.getPostLogoutRedirectUri(), oidcLogoutAuthentication.getState()); + } + + @Override + public boolean supports(Class authentication) { + return OidcLogoutAuthenticationToken.class.isAssignableFrom(authentication); + } + + private static boolean isPrincipalAuthenticated(Authentication principal) { + return principal != null && + !AnonymousAuthenticationToken.class.isAssignableFrom(principal.getClass()) && + principal.isAuthenticated(); + } + + private static SessionInformation findSessionInformation(Authentication principal, String sessionId) { + SessionRegistry sessionRegistry = AuthorizationServerContextHolder.getContext().getSessionRegistry(); + List sessions = sessionRegistry.getAllSessions(principal.getPrincipal(), true); + SessionInformation sessionInformation = null; + if (!CollectionUtils.isEmpty(sessions)) { + for (SessionInformation session : sessions) { + if (session.getSessionId().equals(sessionId)) { + sessionInformation = session; + break; + } + } + } + return sessionInformation; + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationToken.java new file mode 100644 index 000000000..a9cc28c03 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationToken.java @@ -0,0 +1,170 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.Collections; + +import org.springframework.lang.Nullable; +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.oauth2.server.authorization.util.SpringAuthorizationServerVersion; +import org.springframework.util.Assert; + +/** + * An {@link Authentication} implementation used for OpenID Connect 1.0 RP-Initiated Logout Endpoint. + * + * @author Joe Grandja + * @since 1.1.0 + * @see AbstractAuthenticationToken + * @see OidcLogoutAuthenticationProvider + */ +public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken { + private static final long serialVersionUID = SpringAuthorizationServerVersion.SERIAL_VERSION_UID; + private final String idToken; + private final Authentication principal; + private final String sessionId; + private final SessionInformation sessionInformation; + private final String clientId; + private final String postLogoutRedirectUri; + private final String state; + + /** + * Constructs an {@code OidcLogoutAuthenticationToken} using the provided parameters. + * + * @param idToken the ID Token previously issued by the Provider to the Client and used as a hint about the End-User's current authenticated session with the Client + * @param principal the authenticated principal representing the End-User + * @param sessionId the End-User's current authenticated session identifier with the Client + * @param clientId the client identifier the ID Token was issued to + * @param postLogoutRedirectUri the URI which the Client is requesting that the End-User's User Agent be redirected to after a logout has been performed + * @param state the opaque value used by the Client to maintain state between the logout request and the callback to the {@code postLogoutRedirectUri} + */ + public OidcLogoutAuthenticationToken(String idToken, Authentication principal, @Nullable String sessionId, + @Nullable String clientId, @Nullable String postLogoutRedirectUri, @Nullable String state) { + super(Collections.emptyList()); + Assert.hasText(idToken, "idToken cannot be empty"); + Assert.notNull(principal, "principal cannot be null"); + this.idToken = idToken; + this.principal = principal; + this.sessionId = sessionId; + this.sessionInformation = null; + this.clientId = clientId; + this.postLogoutRedirectUri = postLogoutRedirectUri; + this.state = state; + setAuthenticated(false); + } + + /** + * Constructs an {@code OidcLogoutAuthenticationToken} using the provided parameters. + * + * @param idToken the ID Token previously issued by the Provider to the Client and used as a hint about the End-User's current authenticated session with the Client + * @param principal the authenticated principal representing the End-User + * @param sessionInformation the End-User's current authenticated session information with the Client + * @param clientId the client identifier the ID Token was issued to + * @param postLogoutRedirectUri the URI which the Client is requesting that the End-User's User Agent be redirected to after a logout has been performed + * @param state the opaque value used by the Client to maintain state between the logout request and the callback to the {@code postLogoutRedirectUri} + */ + public OidcLogoutAuthenticationToken(String idToken, Authentication principal, @Nullable SessionInformation sessionInformation, + @Nullable String clientId, @Nullable String postLogoutRedirectUri, @Nullable String state) { + super(Collections.emptyList()); + Assert.hasText(idToken, "idToken cannot be empty"); + Assert.notNull(principal, "principal cannot be null"); + this.idToken = idToken; + this.principal = principal; + this.sessionId = sessionInformation != null ? sessionInformation.getSessionId() : null; + this.sessionInformation = sessionInformation; + this.clientId = clientId; + this.postLogoutRedirectUri = postLogoutRedirectUri; + this.state = state; + setAuthenticated(true); + } + + /** + * Returns the authenticated principal representing the End-User. + * + * @return the authenticated principal + */ + @Override + public Object getPrincipal() { + return this.principal; + } + + @Override + public Object getCredentials() { + return ""; + } + + /** + * Returns the ID Token previously issued by the Provider to the Client and used as a hint + * about the End-User's current authenticated session with the Client. + * + * @return the ID Token previously issued by the Provider to the Client + */ + public String getIdToken() { + return this.idToken; + } + + /** + * Returns the End-User's current authenticated session identifier with the Client. + * + * @return the End-User's current authenticated session identifier + */ + @Nullable + public String getSessionId() { + return this.sessionId; + } + + /** + * Returns the End-User's current authenticated session information with the Client. + * + * @return the End-User's current authenticated session information + */ + @Nullable + public SessionInformation getSessionInformation() { + return this.sessionInformation; + } + + /** + * Returns the client identifier the ID Token was issued to. + * + * @return the client identifier + */ + @Nullable + public String getClientId() { + return this.clientId; + } + + /** + * Returns the URI which the Client is requesting that the End-User's User Agent be redirected to after a logout has been performed. + * + * @return the URI which the Client is requesting that the End-User's User Agent be redirected to after a logout has been performed + */ + @Nullable + public String getPostLogoutRedirectUri() { + return this.postLogoutRedirectUri; + } + + /** + * Returns the opaque value used by the Client to maintain state between the logout request and the callback to the {@link #getPostLogoutRedirectUri()}. + * + * @return the opaque value used by the Client to maintain state between the logout request and the callback to the {@link #getPostLogoutRedirectUri()} + */ + @Nullable + public String getState() { + return this.state; + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java index 75aa17c96..7cd62d223 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -48,6 +48,11 @@ public OidcClientRegistration convert(RegisteredClient registeredClient) { builder.redirectUris(redirectUris -> redirectUris.addAll(registeredClient.getRedirectUris())); + if (!CollectionUtils.isEmpty(registeredClient.getPostLogoutRedirectUris())) { + builder.postLogoutRedirectUris(postLogoutRedirectUris -> + postLogoutRedirectUris.addAll(registeredClient.getPostLogoutRedirectUris())); + } + builder.grantTypes(grantTypes -> registeredClient.getAuthorizationGrantTypes().forEach(authorizationGrantType -> grantTypes.add(authorizationGrantType.getValue()))); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java index 78655975a..3442c2ae6 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -148,6 +148,7 @@ private MapOidcClientRegistrationConverter() { claimConverters.put(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, MapOidcClientRegistrationConverter::convertClientSecretExpiresAt); claimConverters.put(OidcClientMetadataClaimNames.CLIENT_NAME, stringConverter); claimConverters.put(OidcClientMetadataClaimNames.REDIRECT_URIS, collectionStringConverter); + claimConverters.put(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS, collectionStringConverter); claimConverters.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, stringConverter); claimConverters.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, stringConverter); claimConverters.put(OidcClientMetadataClaimNames.GRANT_TYPES, collectionStringConverter); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java new file mode 100644 index 000000000..ff8d9e5ec --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java @@ -0,0 +1,221 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.web; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.core.log.LogMessage; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcLogoutAuthenticationConverter; +import org.springframework.security.web.DefaultRedirectStrategy; +import org.springframework.security.web.RedirectStrategy; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.logout.LogoutHandler; +import org.springframework.security.web.authentication.logout.LogoutSuccessHandler; +import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler; +import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler; +import org.springframework.security.web.util.matcher.AntPathRequestMatcher; +import org.springframework.security.web.util.matcher.OrRequestMatcher; +import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; +import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.util.UriComponentsBuilder; + +/** + * A {@code Filter} that processes OpenID Connect 1.0 RP-Initiated Logout Requests. + * + * @author Joe Grandja + * @since 1.1.0 + * @see OidcLogoutAuthenticationConverter + * @see OidcLogoutAuthenticationProvider + * @see 2. RP-Initiated Logout + */ +public final class OidcLogoutEndpointFilter extends OncePerRequestFilter { + + /** + * The default endpoint {@code URI} for OpenID Connect 1.0 RP-Initiated Logout Requests. + */ + private static final String DEFAULT_OIDC_LOGOUT_ENDPOINT_URI = "/connect/logout"; + + private final AuthenticationManager authenticationManager; + private final RequestMatcher logoutEndpointMatcher; + private final LogoutHandler logoutHandler; + private final LogoutSuccessHandler logoutSuccessHandler; + private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); + private AuthenticationConverter authenticationConverter; + private AuthenticationSuccessHandler authenticationSuccessHandler = this::performLogout; + private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + + /** + * Constructs an {@code OidcLogoutEndpointFilter} using the provided parameters. + * + * @param authenticationManager the authentication manager + */ + public OidcLogoutEndpointFilter(AuthenticationManager authenticationManager) { + this(authenticationManager, DEFAULT_OIDC_LOGOUT_ENDPOINT_URI); + } + + /** + * Constructs an {@code OidcLogoutEndpointFilter} using the provided parameters. + * + * @param authenticationManager the authentication manager + * @param logoutEndpointUri the endpoint {@code URI} for OpenID Connect 1.0 RP-Initiated Logout Requests + */ + public OidcLogoutEndpointFilter(AuthenticationManager authenticationManager, + String logoutEndpointUri) { + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + Assert.hasText(logoutEndpointUri, "logoutEndpointUri cannot be empty"); + this.authenticationManager = authenticationManager; + this.logoutEndpointMatcher = new OrRequestMatcher( + new AntPathRequestMatcher(logoutEndpointUri, HttpMethod.GET.name()), + new AntPathRequestMatcher(logoutEndpointUri, HttpMethod.POST.name())); + this.logoutHandler = new SecurityContextLogoutHandler(); + SimpleUrlLogoutSuccessHandler urlLogoutSuccessHandler = new SimpleUrlLogoutSuccessHandler(); + urlLogoutSuccessHandler.setDefaultTargetUrl("/"); + this.logoutSuccessHandler = urlLogoutSuccessHandler; + this.authenticationConverter = new OidcLogoutAuthenticationConverter(); + } + + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (!this.logoutEndpointMatcher.matches(request)) { + filterChain.doFilter(request, response); + return; + } + + try { + Authentication oidcLogoutAuthentication = this.authenticationConverter.convert(request); + + Authentication oidcLogoutAuthenticationResult = + this.authenticationManager.authenticate(oidcLogoutAuthentication); + + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, oidcLogoutAuthenticationResult); + } catch (OAuth2AuthenticationException ex) { + if (this.logger.isTraceEnabled()) { + this.logger.trace(LogMessage.format("Logout request failed: %s", ex.getError()), ex); + } + this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex); + } catch (Exception ex) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_REQUEST, + "OpenID Connect 1.0 RP-Initiated Logout Error: " + ex.getMessage(), + "https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling"); + if (this.logger.isTraceEnabled()) { + this.logger.trace(error, ex); + } + this.authenticationFailureHandler.onAuthenticationFailure(request, response, + new OAuth2AuthenticationException(error)); + } + } + + /** + * Sets the {@link AuthenticationConverter} used when attempting to extract a Logout Request from {@link HttpServletRequest} + * to an instance of {@link OidcLogoutAuthenticationToken} used for authenticating the request. + * + * @param authenticationConverter the {@link AuthenticationConverter} used when attempting to extract a Logout Request from {@link HttpServletRequest} + */ + public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcLogoutAuthenticationToken} + * and performing the logout. + * + * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcLogoutAuthenticationToken} + */ + public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * and returning the {@link OAuth2Error Error Response}. + * + * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + */ + public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + } + + private void performLogout(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) throws IOException, ServletException { + + OidcLogoutAuthenticationToken oidcLogoutAuthentication = (OidcLogoutAuthenticationToken) authentication; + + // Check for active user session + if (oidcLogoutAuthentication.getSessionInformation() != null) { + // Perform logout + this.logoutHandler.logout(request, response, + (Authentication) oidcLogoutAuthentication.getPrincipal()); + } + + if (oidcLogoutAuthentication.isAuthenticated() && + StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri())) { + // Perform post-logout redirect + UriComponentsBuilder uriBuilder = UriComponentsBuilder + .fromUriString(oidcLogoutAuthentication.getPostLogoutRedirectUri()); + String redirectUri; + if (StringUtils.hasText(oidcLogoutAuthentication.getState())) { + uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}"); + Map queryParams = new HashMap<>(); + queryParams.put(OAuth2ParameterNames.STATE, oidcLogoutAuthentication.getState()); + redirectUri = uriBuilder.build(queryParams).toString(); + } else { + redirectUri = uriBuilder.toUriString(); + } + this.redirectStrategy.sendRedirect(request, response, redirectUri); + } else { + // Perform default redirect + this.logoutSuccessHandler.onLogoutSuccess(request, response, + (Authentication) oidcLogoutAuthentication.getPrincipal()); + } + } + + private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, + AuthenticationException exception) throws IOException { + + OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); + response.sendError(HttpStatus.BAD_REQUEST.value(), + "OpenID Connect 1.0 RP-Initiated Logout Error: " + error.toString()); + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilter.java index 83e6a19fa..608aafd5d 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -98,6 +98,7 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse .tokenEndpointAuthenticationMethods(clientAuthenticationMethods()) .jwkSetUrl(asUrl(issuer, authorizationServerSettings.getJwkSetEndpoint())) .userInfoEndpoint(asUrl(issuer, authorizationServerSettings.getOidcUserInfoEndpoint())) + .endSessionEndpoint(asUrl(issuer, authorizationServerSettings.getOidcLogoutEndpoint())) .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java new file mode 100644 index 000000000..8739fe732 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java @@ -0,0 +1,111 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.web.authentication; + +import java.util.Map; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpSession; + +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.oidc.web.OidcLogoutEndpointFilter; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * Attempts to extract an OpenID Connect 1.0 RP-Initiated Logout Request from {@link HttpServletRequest} + * and then converts to an {@link OidcLogoutAuthenticationToken} used for authenticating the request. + * + * @author Joe Grandja + * @since 1.1.0 + * @see AuthenticationConverter + * @see OidcLogoutAuthenticationToken + * @see OidcLogoutEndpointFilter + */ +public final class OidcLogoutAuthenticationConverter implements AuthenticationConverter { + private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( + "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + + @Override + public Authentication convert(HttpServletRequest request) { + MultiValueMap parameters = getParameters(request); + + // id_token_hint (REQUIRED) // RECOMMENDED as per spec + String idTokenHint = request.getParameter("id_token_hint"); + if (!StringUtils.hasText(idTokenHint) || + request.getParameterValues("id_token_hint").length != 1) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + } + + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + if (principal == null) { + principal = ANONYMOUS_AUTHENTICATION; + } + + String sessionId = null; + HttpSession session = request.getSession(false); + if (session != null) { + sessionId = session.getId(); + } + + // client_id (OPTIONAL) + String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); + if (StringUtils.hasText(clientId) && + parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + } + + // post_logout_redirect_uri (OPTIONAL) + String postLogoutRedirectUri = parameters.getFirst("post_logout_redirect_uri"); + if (StringUtils.hasText(postLogoutRedirectUri) && + parameters.get("post_logout_redirect_uri").size() != 1) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + } + + // state (OPTIONAL) + String state = parameters.getFirst(OAuth2ParameterNames.STATE); + if (StringUtils.hasText(state) && + parameters.get(OAuth2ParameterNames.STATE).size() != 1) { + throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + } + + return new OidcLogoutAuthenticationToken(idTokenHint, principal, + sessionId, clientId, postLogoutRedirectUri, state); + } + + private static MultiValueMap getParameters(HttpServletRequest request) { + Map parameterMap = request.getParameterMap(); + MultiValueMap parameters = new LinkedMultiValueMap<>(parameterMap.size()); + parameterMap.forEach((key, values) -> { + if (values.length > 0) { + for (String value : values) { + parameters.add(key, value); + } + } + }); + return parameters; + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettings.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettings.java index 53484bbcb..a2bb0c80d 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettings.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettings.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ private AuthorizationServerSettings(Map settings) { } /** - * Returns the URL of the Authorization Server's Issuer Identifier + * Returns the URL of the Authorization Server's Issuer Identifier. * * @return the URL of the Authorization Server's Issuer Identifier */ @@ -106,6 +106,16 @@ public String getOidcUserInfoEndpoint() { return getSetting(ConfigurationSettingNames.AuthorizationServer.OIDC_USER_INFO_ENDPOINT); } + /** + * Returns the OpenID Connect 1.0 Logout endpoint. The default is {@code /connect/logout}. + * + * @return the OpenID Connect 1.0 Logout endpoint + * @since 1.1.0 + */ + public String getOidcLogoutEndpoint() { + return getSetting(ConfigurationSettingNames.AuthorizationServer.OIDC_LOGOUT_ENDPOINT); + } + /** * Constructs a new {@link Builder} with the default settings. * @@ -119,7 +129,8 @@ public static Builder builder() { .tokenRevocationEndpoint("/oauth2/revoke") .tokenIntrospectionEndpoint("/oauth2/introspect") .oidcClientRegistrationEndpoint("/connect/register") - .oidcUserInfoEndpoint("/userinfo"); + .oidcUserInfoEndpoint("/userinfo") + .oidcLogoutEndpoint("/connect/logout"); } /** @@ -222,6 +233,17 @@ public Builder oidcUserInfoEndpoint(String oidcUserInfoEndpoint) { return setting(ConfigurationSettingNames.AuthorizationServer.OIDC_USER_INFO_ENDPOINT, oidcUserInfoEndpoint); } + /** + * Sets the OpenID Connect 1.0 Logout endpoint. + * + * @param oidcLogoutEndpoint the OpenID Connect 1.0 Logout endpoint + * @return the {@link Builder} for further configuration + * @since 1.1.0 + */ + public Builder oidcLogoutEndpoint(String oidcLogoutEndpoint) { + return setting(ConfigurationSettingNames.AuthorizationServer.OIDC_LOGOUT_ENDPOINT, oidcLogoutEndpoint); + } + /** * Builds the {@link AuthorizationServerSettings}. * diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/ConfigurationSettingNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/ConfigurationSettingNames.java index b548019b2..e7cc341dd 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/ConfigurationSettingNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/settings/ConfigurationSettingNames.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -116,6 +116,12 @@ public static final class AuthorizationServer { */ public static final String OIDC_USER_INFO_ENDPOINT = AUTHORIZATION_SERVER_SETTINGS_NAMESPACE.concat("oidc-user-info-endpoint"); + /** + * Set the OpenID Connect 1.0 Logout endpoint. + * @since 1.1.0 + */ + public static final String OIDC_LOGOUT_ENDPOINT = AUTHORIZATION_SERVER_SETTINGS_NAMESPACE.concat("oidc-logout-endpoint"); + private AuthorizationServer() { } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java index 3cc52de73..98895c80a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,14 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.Collections; +import java.util.Comparator; +import java.util.List; import org.springframework.lang.Nullable; +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -126,7 +131,11 @@ public Jwt generate(OAuth2TokenContext context) { claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce); } } - // TODO Add 'auth_time' claim + SessionInformation sessionInformation = getSessionInformation(context); + if (sessionInformation != null) { + claimsBuilder.claim("sid", sessionInformation.getSessionId()); + claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, sessionInformation.getLastRequest()); + } } // @formatter:on @@ -161,6 +170,24 @@ public Jwt generate(OAuth2TokenContext context) { return jwt; } + private static SessionInformation getSessionInformation(OAuth2TokenContext context) { + SessionInformation sessionInformation = null; + if (context.getAuthorizationServerContext().getSessionRegistry() != null) { + SessionRegistry sessionRegistry = context.getAuthorizationServerContext().getSessionRegistry(); + List sessions = sessionRegistry.getAllSessions(context.getPrincipal().getPrincipal(), false); + if (!CollectionUtils.isEmpty(sessions)) { + sessionInformation = sessions.get(0); + if (sessions.size() > 1) { + // Get the most recent session + sessions = new ArrayList<>(sessions); + sessions.sort(Comparator.comparing(SessionInformation::getLastRequest)); + sessionInformation = sessions.get(sessions.size() - 1); + } + } + } + return sessionInformation; + } + /** * Sets the {@link OAuth2TokenCustomizer} that customizes the * {@link JwtEncodingContext#getJwsHeader() JWS headers} and/or diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index f30bca93a..a109dfd51 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -35,6 +35,7 @@ import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse; @@ -54,6 +55,7 @@ import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; +import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.security.web.util.RedirectUrlBuilder; import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.AndRequestMatcher; @@ -97,6 +99,7 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAuthorizationResponse; private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; + private SessionAuthenticationStrategy sessionAuthenticationStrategy = (authentication, request, response) -> {}; private String consentPage; /** @@ -182,6 +185,9 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse return; } + this.sessionAuthenticationStrategy.onAuthentication( + authenticationResult, request, response); + this.authenticationSuccessHandler.onAuthenticationSuccess( request, response, authenticationResult); @@ -238,6 +244,19 @@ public void setAuthenticationFailureHandler(AuthenticationFailureHandler authent this.authenticationFailureHandler = authenticationFailureHandler; } + /** + * Sets the {@link SessionAuthenticationStrategy} used for handling an {@link OAuth2AuthorizationCodeRequestAuthenticationToken} + * before calling the {@link AuthenticationSuccessHandler}. + * If OpenID Connect is enabled, the default implementation tracks OpenID Connect sessions using a {@link SessionRegistry}. + * + * @param sessionAuthenticationStrategy the {@link SessionAuthenticationStrategy} used for handling an {@link OAuth2AuthorizationCodeRequestAuthenticationToken} + * @since 1.1.0 + */ + public void setSessionAuthenticationStrategy(SessionAuthenticationStrategy sessionAuthenticationStrategy) { + Assert.notNull(sessionAuthenticationStrategy, "sessionAuthenticationStrategy cannot be null"); + this.sessionAuthenticationStrategy = sessionAuthenticationStrategy; + } + /** * Specify the URI to redirect Resource Owners to if consent is required. A default consent * page will be generated when this attribute is not specified. diff --git a/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql index a12023077..a11ff75c8 100644 --- a/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql +++ b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql @@ -8,6 +8,7 @@ CREATE TABLE oauth2_registered_client ( client_authentication_methods varchar(1000) NOT NULL, authorization_grant_types varchar(1000) NOT NULL, redirect_uris varchar(1000) DEFAULT NULL, + post_logout_redirect_uris varchar(1000) DEFAULT NULL, scopes varchar(1000) NOT NULL, client_settings varchar(2000) NOT NULL, token_settings varchar(2000) NOT NULL, diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java index 225767930..83c3049ae 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,8 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; @@ -47,6 +49,7 @@ public class InMemoryOAuth2AuthorizationServiceTests { "code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES)); private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE); + private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN); private InMemoryOAuth2AuthorizationService authorizationService; @BeforeEach @@ -263,6 +266,29 @@ public void findByTokenWhenAccessTokenExistsThenFound() { assertThat(authorization).isEqualTo(result); } + @Test + public void findByTokenWhenIdTokenExistsThenFound() { + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject("subject") + .issuedAt(Instant.now().minusSeconds(60)) + .expiresAt(Instant.now()) + .build(); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(idToken) + .build(); + this.authorizationService.save(authorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + idToken.getTokenValue(), ID_TOKEN_TOKEN_TYPE); + assertThat(authorization).isEqualTo(result); + result = this.authorizationService.findByToken(idToken.getTokenValue(), null); + assertThat(authorization).isEqualTo(result); + } + @Test public void findByTokenWhenRefreshTokenExistsThenFound() { OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now()); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java index 5eb467f61..5b08c5203 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,6 +49,7 @@ import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; @@ -76,6 +77,7 @@ public class JdbcOAuth2AuthorizationServiceTests { private static final String OAUTH2_AUTHORIZATION_SCHEMA_CLOB_DATA_TYPE_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema-clob-data-type.sql"; private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE); + private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN); private static final String ID = "id"; private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); private static final String PRINCIPAL_NAME = "principal"; @@ -344,6 +346,32 @@ public void findByTokenWhenAccessTokenExistsThenFound() { assertThat(authorization).isEqualTo(result); } + @Test + public void findByTokenWhenIdTokenExistsThenFound() { + when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject("subject") + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().truncatedTo(ChronoUnit.MILLIS)) + .build(); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(idToken, (metadata) -> + metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())) + .build(); + this.authorizationService.save(authorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + idToken.getTokenValue(), ID_TOKEN_TOKEN_TYPE); + assertThat(authorization).isEqualTo(result); + result = this.authorizationService.findByToken(idToken.getTokenValue(), null); + assertThat(authorization).isEqualTo(result); + } + @Test public void findByTokenWhenRefreshTokenExistsThenFound() { when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) @@ -494,7 +522,7 @@ private static final class CustomJdbcOAuth2AuthorizationService extends JdbcOAut private static final String PK_FILTER = "id = ?"; private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorizationCodeValue = ? OR " + - "accessTokenValue = ? OR refreshTokenValue = ?"; + "accessTokenValue = ? OR oidcIdTokenValue = ? OR refreshTokenValue = ?"; // @formatter:off private static final String LOAD_AUTHORIZATION_SQL = "SELECT " + COLUMN_NAMES @@ -539,7 +567,7 @@ public OAuth2Authorization findById(String id) { @Override public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) { - return findBy(UNKNOWN_TOKEN_TYPE_FILTER, token, token, token, token); + return findBy(UNKNOWN_TOKEN_TYPE_FILTER, token, token, token, token, token); } private OAuth2Authorization findBy(String filter, Object... args) { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java index 4600f9f84..a7f39a04f 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -294,6 +294,7 @@ private static final class CustomJdbcRegisteredClientRepository extends JdbcRegi + "clientAuthenticationMethods, " + "authorizationGrantTypes, " + "redirectUris, " + + "postLogoutRedirectUris, " + "scopes, " + "clientSettings," + "tokenSettings"; @@ -305,7 +306,7 @@ private static final class CustomJdbcRegisteredClientRepository extends JdbcRegi // @formatter:off private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME - + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; // @formatter:on private CustomJdbcRegisteredClientRepository(JdbcOperations jdbcOperations) { @@ -353,6 +354,7 @@ public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException { Set clientAuthenticationMethods = StringUtils.commaDelimitedListToSet(rs.getString("clientAuthenticationMethods")); Set authorizationGrantTypes = StringUtils.commaDelimitedListToSet(rs.getString("authorizationGrantTypes")); Set redirectUris = StringUtils.commaDelimitedListToSet(rs.getString("redirectUris")); + Set postLogoutRedirectUris = StringUtils.commaDelimitedListToSet(rs.getString("postLogoutRedirectUris")); Set clientScopes = StringUtils.commaDelimitedListToSet(rs.getString("scopes")); // @formatter:off @@ -369,6 +371,7 @@ public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException { authorizationGrantTypes.forEach(grantType -> grantTypes.add(resolveAuthorizationGrantType(grantType)))) .redirectUris((uris) -> uris.addAll(redirectUris)) + .postLogoutRedirectUris((uris) -> uris.addAll(postLogoutRedirectUris)) .scopes((scopes) -> scopes.addAll(clientScopes)); // @formatter:on diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java index 4e445cd67..a6b73a891 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,7 @@ public class RegisteredClientTests { private static final String CLIENT_ID = "client-1"; private static final String CLIENT_SECRET = "secret"; private static final Set REDIRECT_URIS = Collections.singleton("https://example.com"); + private static final Set POST_LOGOUT_REDIRECT_URIS = Collections.singleton("https://example.com/oidc-post-logout"); private static final Set SCOPES = Collections.unmodifiableSet( Stream.of("openid", "profile", "email").collect(Collectors.toSet())); private static final Set CLIENT_AUTHENTICATION_METHODS = @@ -71,6 +72,7 @@ public void buildWhenAllAttributesProvidedThenAllAttributesAreSet() { .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .postLogoutRedirectUris(postLogoutRedirectUris -> postLogoutRedirectUris.addAll(POST_LOGOUT_REDIRECT_URIS)) .scopes(scopes -> scopes.addAll(SCOPES)) .build(); @@ -84,6 +86,7 @@ public void buildWhenAllAttributesProvidedThenAllAttributesAreSet() { .isEqualTo(Collections.singleton(AuthorizationGrantType.AUTHORIZATION_CODE)); assertThat(registration.getClientAuthenticationMethods()).isEqualTo(CLIENT_AUTHENTICATION_METHODS); assertThat(registration.getRedirectUris()).isEqualTo(REDIRECT_URIS); + assertThat(registration.getPostLogoutRedirectUris()).isEqualTo(POST_LOGOUT_REDIRECT_URIS); assertThat(registration.getScopes()).isEqualTo(SCOPES); } @@ -229,6 +232,35 @@ public void buildWhenRedirectUriContainsFragmentThenThrowIllegalArgumentExceptio ).isInstanceOf(IllegalArgumentException.class); } + @Test + public void buildWhenPostLogoutRedirectUriInvalidThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) + .postLogoutRedirectUri("invalid URI") + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenPostLogoutRedirectUriContainsFragmentThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> + RegisteredClient.withId(ID) + .clientId(CLIENT_ID) + .clientSecret(CLIENT_SECRET) + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .redirectUri("https://example.com") + .postLogoutRedirectUri("https://example.com/index#fragment") + .scopes(scopes -> scopes.addAll(SCOPES)) + .build() + ).isInstanceOf(IllegalArgumentException.class); + } + @Test public void buildWhenTwoAuthorizationGrantTypesAreProvidedThenBothAreRegistered() { RegisteredClient registration = RegisteredClient.withId(ID) @@ -345,6 +377,8 @@ public void buildWhenRegisteredClientProvidedThenMakesACopy() { assertThat(registration.getAuthorizationGrantTypes()).isNotSameAs(updated.getAuthorizationGrantTypes()); assertThat(registration.getRedirectUris()).isEqualTo(updated.getRedirectUris()); assertThat(registration.getRedirectUris()).isNotSameAs(updated.getRedirectUris()); + assertThat(registration.getPostLogoutRedirectUris()).isEqualTo(updated.getPostLogoutRedirectUris()); + assertThat(registration.getPostLogoutRedirectUris()).isNotSameAs(updated.getPostLogoutRedirectUris()); assertThat(registration.getScopes()).isEqualTo(updated.getScopes()); assertThat(registration.getScopes()).isNotSameAs(updated.getScopes()); assertThat(registration.getClientSettings()).isEqualTo(updated.getClientSettings()); @@ -360,6 +394,7 @@ public void buildWhenRegisteredClientValuesOverriddenThenPropagated() { String newSecret = "new-secret"; String newScope = "new-scope"; String newRedirectUri = "https://another-redirect-uri.com"; + String newPostLogoutRedirectUri = "https://another-post-logout-redirect-uri.com"; RegisteredClient updated = RegisteredClient.from(registration) .clientName(newName) .clientSecret(newSecret) @@ -371,6 +406,10 @@ public void buildWhenRegisteredClientValuesOverriddenThenPropagated() { redirectUris.clear(); redirectUris.add(newRedirectUri); }) + .postLogoutRedirectUris(postLogoutRedirectUris -> { + postLogoutRedirectUris.clear(); + postLogoutRedirectUris.add(newPostLogoutRedirectUri); + }) .build(); assertThat(registration.getClientName()).isNotEqualTo(newName); @@ -381,6 +420,8 @@ public void buildWhenRegisteredClientValuesOverriddenThenPropagated() { assertThat(updated.getScopes()).containsExactly(newScope); assertThat(registration.getRedirectUris()).doesNotContain(newRedirectUri); assertThat(updated.getRedirectUris()).containsExactly(newRedirectUri); + assertThat(registration.getPostLogoutRedirectUris()).doesNotContain(newPostLogoutRedirectUri); + assertThat(updated.getPostLogoutRedirectUris()).containsExactly(newPostLogoutRedirectUri); } @Test diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java index 6036f0bae..90151201e 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/TestRegisteredClients.java @@ -38,6 +38,7 @@ public static RegisteredClient.Builder registeredClient() { .redirectUri("https://example.com/callback-1") .redirectUri("https://example.com/callback-2") .redirectUri("https://example.com/callback-3") + .postLogoutRedirectUri("https://example.com/oidc-post-logout") .scope("scope1"); } @@ -52,6 +53,7 @@ public static RegisteredClient.Builder registeredClient2() { .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .redirectUri("https://example.com") + .postLogoutRedirectUri("https://example.com/oidc-post-logout") .scope("scope1") .scope("scope2"); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcProviderConfigurationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcProviderConfigurationTests.java index 5b304f7fb..f0f4da518 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcProviderConfigurationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcProviderConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -126,6 +126,7 @@ private ResultMatcher[] defaultConfigurationMatchers() { jsonPath("$.token_endpoint_auth_methods_supported[3]").value(ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue()), jsonPath("jwks_uri").value(ISSUER_URL.concat(this.authorizationServerSettings.getJwkSetEndpoint())), jsonPath("userinfo_endpoint").value(ISSUER_URL.concat(this.authorizationServerSettings.getOidcUserInfoEndpoint())), + jsonPath("end_session_endpoint").value(ISSUER_URL.concat(this.authorizationServerSettings.getOidcLogoutEndpoint())), jsonPath("response_types_supported").value(OAuth2AuthorizationResponseType.CODE.getValue()), jsonPath("$.grant_types_supported[0]").value(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()), jsonPath("$.grant_types_supported[1]").value(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()), diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/context/TestAuthorizationServerContext.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/context/TestAuthorizationServerContext.java index 19e8299c7..142bba5ca 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/context/TestAuthorizationServerContext.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/context/TestAuthorizationServerContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import java.util.function.Supplier; import org.springframework.lang.Nullable; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; /** @@ -26,6 +27,7 @@ public class TestAuthorizationServerContext implements AuthorizationServerContext { private final AuthorizationServerSettings authorizationServerSettings; private final Supplier issuerSupplier; + private SessionRegistry sessionRegistry; public TestAuthorizationServerContext(AuthorizationServerSettings authorizationServerSettings, @Nullable Supplier issuerSupplier) { this.authorizationServerSettings = authorizationServerSettings; @@ -44,4 +46,14 @@ public AuthorizationServerSettings getAuthorizationServerSettings() { return this.authorizationServerSettings; } + @Nullable + @Override + public SessionRegistry getSessionRegistry() { + return this.sessionRegistry; + } + + public void setSessionRegistry(SessionRegistry sessionRegistry) { + this.sessionRegistry = sessionRegistry; + } + } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistrationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistrationTests.java index 5985ceb6e..de126fdaa 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistrationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/OidcClientRegistrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,6 +58,7 @@ public void buildWhenAllClaimsProvidedThenCreated() throws Exception { .clientSecretExpiresAt(clientSecretExpiresAt) .clientName("client-name") .redirectUri("https://client.example.com") + .postLogoutRedirectUri("https://client.example.com/oidc-post-logout") .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()) .tokenEndpointAuthenticationSigningAlgorithm(MacAlgorithm.HS256.getName()) .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) @@ -79,6 +80,7 @@ public void buildWhenAllClaimsProvidedThenCreated() throws Exception { assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt); assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getPostLogoutRedirectUris()).containsOnly("https://client.example.com/oidc-post-logout"); assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()); assertThat(clientRegistration.getTokenEndpointAuthenticationSigningAlgorithm()).isEqualTo(MacAlgorithm.HS256.getName()); assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", "client_credentials"); @@ -108,6 +110,7 @@ public void withClaimsWhenClaimsProvidedThenCreated() throws Exception { claims.put(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); claims.put(OidcClientMetadataClaimNames.CLIENT_NAME, "client-name"); claims.put(OidcClientMetadataClaimNames.REDIRECT_URIS, Collections.singletonList("https://client.example.com")); + claims.put(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS, Collections.singletonList("https://client.example.com/oidc-post-logout")); claims.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()); claims.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, MacAlgorithm.HS256.getName()); claims.put(OidcClientMetadataClaimNames.GRANT_TYPES, Arrays.asList( @@ -128,6 +131,7 @@ public void withClaimsWhenClaimsProvidedThenCreated() throws Exception { assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt); assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getPostLogoutRedirectUris()).containsOnly("https://client.example.com/oidc-post-logout"); assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()); assertThat(clientRegistration.getTokenEndpointAuthenticationSigningAlgorithm()).isEqualTo(MacAlgorithm.HS256.getName()); assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", "client_credentials"); @@ -261,6 +265,41 @@ public void buildWhenRedirectUrisAddingOrRemovingThenCorrectValues() { assertThat(clientRegistration.getRedirectUris()).containsExactly("https://client2.example.com"); } + @Test + public void buildWhenPostLogoutRedirectUrisNotListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .claim(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS, "postLogoutRedirectUris"); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessageStartingWith("post_logout_redirect_uris must be of type List"); + } + + @Test + public void buildWhenPostLogoutRedirectUrisEmptyListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .claim(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS, Collections.emptyList()); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("post_logout_redirect_uris cannot be empty"); + } + + @Test + public void buildWhenPostLogoutRedirectUrisAddingOrRemovingThenCorrectValues() { + // @formatter:off + OidcClientRegistration clientRegistration = this.minimalBuilder + .postLogoutRedirectUri("https://client1.example.com/oidc-post-logout") + .postLogoutRedirectUris(postLogoutRedirectUris -> { + postLogoutRedirectUris.clear(); + postLogoutRedirectUris.add("https://client2.example.com/oidc-post-logout"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getPostLogoutRedirectUris()).containsExactly("https://client2.example.com/oidc-post-logout"); + } + @Test public void buildWhenGrantTypesNotListThenThrowIllegalArgumentException() { OidcClientRegistration.Builder builder = this.minimalBuilder diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderConfigurationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderConfigurationTests.java index 567e0f63a..5ca357998 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderConfigurationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/OidcProviderConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -62,6 +62,7 @@ public void buildWhenAllRequiredClaimsAndAdditionalClaimsThenCreated() { .userInfoEndpoint("https://example.com/issuer1/userinfo") .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) .clientRegistrationEndpoint("https://example.com/issuer1/connect/register") + .endSessionEndpoint("https://example.com/issuer1/connect/logout") .claim("a-claim", "a-value") .build(); @@ -77,6 +78,7 @@ public void buildWhenAllRequiredClaimsAndAdditionalClaimsThenCreated() { assertThat(providerConfiguration.getUserInfoEndpoint()).isEqualTo(url("https://example.com/issuer1/userinfo")); assertThat(providerConfiguration.getTokenEndpointAuthenticationMethods()).containsExactly(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()); assertThat(providerConfiguration.getClientRegistrationEndpoint()).isEqualTo(url("https://example.com/issuer1/connect/register")); + assertThat(providerConfiguration.getEndSessionEndpoint()).isEqualTo(url("https://example.com/issuer1/connect/logout")); assertThat(providerConfiguration.getClaim("a-claim")).isEqualTo("a-value"); } @@ -118,6 +120,7 @@ public void buildWhenClaimsProvidedThenCreated() { claims.put(OidcProviderMetadataClaimNames.ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED, Collections.singletonList("RS256")); claims.put(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT, "https://example.com/issuer1/userinfo"); claims.put(OidcProviderMetadataClaimNames.REGISTRATION_ENDPOINT, "https://example.com/issuer1/connect/register"); + claims.put(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT, "https://example.com/issuer1/connect/logout"); claims.put("some-claim", "some-value"); OidcProviderConfiguration providerConfiguration = OidcProviderConfiguration.withClaims(claims).build(); @@ -134,6 +137,7 @@ public void buildWhenClaimsProvidedThenCreated() { assertThat(providerConfiguration.getUserInfoEndpoint()).isEqualTo(url("https://example.com/issuer1/userinfo")); assertThat(providerConfiguration.getTokenEndpointAuthenticationMethods()).isNull(); assertThat(providerConfiguration.getClientRegistrationEndpoint()).isEqualTo(url("https://example.com/issuer1/connect/register")); + assertThat(providerConfiguration.getEndSessionEndpoint()).isEqualTo(url("https://example.com/issuer1/connect/logout")); assertThat(providerConfiguration.getClaim("some-claim")).isEqualTo("some-value"); } @@ -150,6 +154,7 @@ public void buildWhenClaimsProvidedWithUrlsThenCreated() { claims.put(OidcProviderMetadataClaimNames.ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED, Collections.singletonList("RS256")); claims.put(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT, url("https://example.com/issuer1/userinfo")); claims.put(OidcProviderMetadataClaimNames.REGISTRATION_ENDPOINT, url("https://example.com/issuer1/connect/register")); + claims.put(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT, url("https://example.com/issuer1/connect/logout")); claims.put("some-claim", "some-value"); OidcProviderConfiguration providerConfiguration = OidcProviderConfiguration.withClaims(claims).build(); @@ -166,6 +171,7 @@ public void buildWhenClaimsProvidedWithUrlsThenCreated() { assertThat(providerConfiguration.getUserInfoEndpoint()).isEqualTo(url("https://example.com/issuer1/userinfo")); assertThat(providerConfiguration.getTokenEndpointAuthenticationMethods()).isNull(); assertThat(providerConfiguration.getClientRegistrationEndpoint()).isEqualTo(url("https://example.com/issuer1/connect/register")); + assertThat(providerConfiguration.getEndSessionEndpoint()).isEqualTo(url("https://example.com/issuer1/connect/logout")); assertThat(providerConfiguration.getClaim("some-claim")).isEqualTo("some-value"); } @@ -412,6 +418,16 @@ public void buildWhenClientRegistrationEndpointNotUrlThenThrowIllegalArgumentExc .withMessage("clientRegistrationEndpoint must be a valid URL"); } + @Test + public void buildWhenEndSessionEndpointNotUrlThenThrowIllegalArgumentException() { + OidcProviderConfiguration.Builder builder = this.minimalConfigurationBuilder + .claims((claims) -> claims.put(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT, "not an url")); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("endSessionEndpoint must be a valid URL"); + } + @Test public void responseTypesWhenAddingOrRemovingThenCorrectValues() { OidcProviderConfiguration configuration = this.minimalConfigurationBuilder diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java index 3c51c8a67..3ebde3116 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -359,6 +359,78 @@ public void authenticateWhenRedirectUriContainsFragmentThenThrowOAuth2Authentica eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)); } + @Test + public void authenticateWhenInvalidPostLogoutRedirectUriThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwtClientRegistration(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), + jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + registeredClient, jwtAccessToken, jwt.getClaims()).build(); + when(this.authorizationService.findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create")); + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com") + .postLogoutRedirectUri("invalid uri") + .build(); + // @formatter:on + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo("invalid_client_metadata"); + assertThat(error.getDescription()).contains(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS); + }); + verify(this.authorizationService).findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenPostLogoutRedirectUriContainsFragmentThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwtClientRegistration(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), + jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + registeredClient, jwtAccessToken, jwt.getClaims()).build(); + when(this.authorizationService.findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create")); + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com") + .postLogoutRedirectUri("https://client.example.com/oidc-post-logout#fragment") + .build(); + // @formatter:on + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo("invalid_client_metadata"); + assertThat(error.getDescription()).contains(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS); + }); + verify(this.authorizationService).findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + @Test public void authenticateWhenInvalidTokenEndpointAuthenticationMethodThenThrowOAuth2AuthenticationException() { Jwt jwt = createJwtClientRegistration(); @@ -545,6 +617,7 @@ public void authenticateWhenValidAccessTokenThenReturnClientRegistration() { OidcClientRegistration clientRegistration = OidcClientRegistration.builder() .clientName("client-name") .redirectUri("https://client.example.com") + .postLogoutRedirectUri("https://client.example.com/oidc-post-logout") .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) .scope("scope1") @@ -588,6 +661,7 @@ public void authenticateWhenValidAccessTokenThenReturnClientRegistration() { assertThat(registeredClientResult.getClientName()).isEqualTo(clientRegistration.getClientName()); assertThat(registeredClientResult.getClientAuthenticationMethods()).containsExactly(ClientAuthenticationMethod.CLIENT_SECRET_BASIC); assertThat(registeredClientResult.getRedirectUris()).containsExactly("https://client.example.com"); + assertThat(registeredClientResult.getPostLogoutRedirectUris()).containsExactly("https://client.example.com/oidc-post-logout"); assertThat(registeredClientResult.getAuthorizationGrantTypes()) .containsExactlyInAnyOrder(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS); assertThat(registeredClientResult.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); @@ -603,6 +677,8 @@ public void authenticateWhenValidAccessTokenThenReturnClientRegistration() { assertThat(clientRegistrationResult.getClientName()).isEqualTo(registeredClientResult.getClientName()); assertThat(clientRegistrationResult.getRedirectUris()) .containsExactlyInAnyOrderElementsOf(registeredClientResult.getRedirectUris()); + assertThat(clientRegistrationResult.getPostLogoutRedirectUris()) + .containsExactlyInAnyOrderElementsOf(registeredClientResult.getPostLogoutRedirectUris()); List grantTypes = new ArrayList<>(); registeredClientResult.getAuthorizationGrantTypes().forEach(authorizationGrantType -> diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java index 65f2ada78..3799ca84e 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,6 +99,9 @@ public void readInternalWhenValidParametersThenSuccess() throws Exception { +" \"redirect_uris\": [\n" + " \"https://client.example.com\"\n" + " ],\n" + +" \"post_logout_redirect_uris\": [\n" + + " \"https://client.example.com/oidc-post-logout\"\n" + + " ],\n" +" \"token_endpoint_auth_method\": \"client_secret_jwt\",\n" +" \"token_endpoint_auth_signing_alg\": \"HS256\",\n" +" \"grant_types\": [\n" @@ -125,6 +128,7 @@ public void readInternalWhenValidParametersThenSuccess() throws Exception { assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(Instant.ofEpochSecond(1607637467L)); assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getPostLogoutRedirectUris()).containsOnly("https://client.example.com/oidc-post-logout"); assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()); assertThat(clientRegistration.getTokenEndpointAuthenticationSigningAlgorithm()).isEqualTo(MacAlgorithm.HS256.getName()); assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", "client_credentials"); @@ -183,6 +187,7 @@ public void writeInternalWhenClientRegistrationThenSuccess() { .clientSecretExpiresAt(Instant.ofEpochSecond(1607637467)) .clientName("client-name") .redirectUri("https://client.example.com") + .postLogoutRedirectUri("https://client.example.com/oidc-post-logout") .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()) .tokenEndpointAuthenticationSigningAlgorithm(MacAlgorithm.HS256.getName()) .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) @@ -208,6 +213,7 @@ public void writeInternalWhenClientRegistrationThenSuccess() { assertThat(clientRegistrationResponse).contains("\"client_secret_expires_at\":1607637467"); assertThat(clientRegistrationResponse).contains("\"client_name\":\"client-name\""); assertThat(clientRegistrationResponse).contains("\"redirect_uris\":[\"https://client.example.com\"]"); + assertThat(clientRegistrationResponse).contains("\"post_logout_redirect_uris\":[\"https://client.example.com/oidc-post-logout\"]"); assertThat(clientRegistrationResponse).contains("\"token_endpoint_auth_method\":\"client_secret_jwt\""); assertThat(clientRegistrationResponse).contains("\"token_endpoint_auth_signing_alg\":\"HS256\""); assertThat(clientRegistrationResponse).contains("\"grant_types\":[\"authorization_code\",\"client_credentials\"]"); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilterTests.java index aeca75386..3d175e673 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -92,6 +92,7 @@ public void doFilterWhenConfigurationRequestThenConfigurationResponse() throws E String tokenEndpoint = "/oauth2/v1/token"; String jwkSetEndpoint = "/oauth2/v1/jwks"; String userInfoEndpoint = "/userinfo"; + String logoutEndpoint = "/connect/logout"; String tokenRevocationEndpoint = "/oauth2/v1/revoke"; String tokenIntrospectionEndpoint = "/oauth2/v1/introspect"; @@ -101,6 +102,7 @@ public void doFilterWhenConfigurationRequestThenConfigurationResponse() throws E .tokenEndpoint(tokenEndpoint) .jwkSetEndpoint(jwkSetEndpoint) .oidcUserInfoEndpoint(userInfoEndpoint) + .oidcLogoutEndpoint(logoutEndpoint) .tokenRevocationEndpoint(tokenRevocationEndpoint) .tokenIntrospectionEndpoint(tokenIntrospectionEndpoint) .build(); @@ -132,6 +134,7 @@ public void doFilterWhenConfigurationRequestThenConfigurationResponse() throws E assertThat(providerConfigurationResponse).contains("\"subject_types_supported\":[\"public\"]"); assertThat(providerConfigurationResponse).contains("\"id_token_signing_alg_values_supported\":[\"RS256\"]"); assertThat(providerConfigurationResponse).contains("\"userinfo_endpoint\":\"https://example.com/issuer1/userinfo\""); + assertThat(providerConfigurationResponse).contains("\"end_session_endpoint\":\"https://example.com/issuer1/connect/logout\""); assertThat(providerConfigurationResponse).contains("\"token_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\",\"client_secret_jwt\",\"private_key_jwt\"]"); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettingsTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettingsTests.java index 3d27e919f..ccf3884fb 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettingsTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/settings/AuthorizationServerSettingsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ * Tests for {@link AuthorizationServerSettings}. * * @author Daniel Garnier-Moiroux + * @author Joe Grandja */ public class AuthorizationServerSettingsTests { @@ -39,6 +40,7 @@ public void buildWhenDefaultThenDefaultsAreSet() { assertThat(authorizationServerSettings.getTokenIntrospectionEndpoint()).isEqualTo("/oauth2/introspect"); assertThat(authorizationServerSettings.getOidcClientRegistrationEndpoint()).isEqualTo("/connect/register"); assertThat(authorizationServerSettings.getOidcUserInfoEndpoint()).isEqualTo("/userinfo"); + assertThat(authorizationServerSettings.getOidcLogoutEndpoint()).isEqualTo("/connect/logout"); } @Test @@ -50,6 +52,7 @@ public void buildWhenSettingsProvidedThenSet() { String tokenIntrospectionEndpoint = "/oauth2/v1/introspect"; String oidcClientRegistrationEndpoint = "/connect/v1/register"; String oidcUserInfoEndpoint = "/connect/v1/userinfo"; + String oidcLogoutEndpoint = "/connect/v1/logout"; String issuer = "https://example.com:9000"; AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder() @@ -62,6 +65,7 @@ public void buildWhenSettingsProvidedThenSet() { .tokenRevocationEndpoint(tokenRevocationEndpoint) .oidcClientRegistrationEndpoint(oidcClientRegistrationEndpoint) .oidcUserInfoEndpoint(oidcUserInfoEndpoint) + .oidcLogoutEndpoint(oidcLogoutEndpoint) .build(); assertThat(authorizationServerSettings.getIssuer()).isEqualTo(issuer); @@ -72,6 +76,7 @@ public void buildWhenSettingsProvidedThenSet() { assertThat(authorizationServerSettings.getTokenIntrospectionEndpoint()).isEqualTo(tokenIntrospectionEndpoint); assertThat(authorizationServerSettings.getOidcClientRegistrationEndpoint()).isEqualTo(oidcClientRegistrationEndpoint); assertThat(authorizationServerSettings.getOidcUserInfoEndpoint()).isEqualTo(oidcUserInfoEndpoint); + assertThat(authorizationServerSettings.getOidcLogoutEndpoint()).isEqualTo(oidcLogoutEndpoint); } @Test @@ -81,7 +86,7 @@ public void settingWhenCustomThenSet() { .settings(settings -> settings.put("name2", "value2")) .build(); - assertThat(authorizationServerSettings.getSettings()).hasSize(9); + assertThat(authorizationServerSettings.getSettings()).hasSize(10); assertThat(authorizationServerSettings.getSetting("name1")).isEqualTo("value1"); assertThat(authorizationServerSettings.getSetting("name2")).isEqualTo("value2"); } @@ -142,4 +147,11 @@ public void jwksEndpointWhenNullThenThrowIllegalArgumentException() { .withMessage("value cannot be null"); } + @Test + public void oidcLogoutEndpointWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> AuthorizationServerSettings.builder().oidcLogoutEndpoint(null)) + .withMessage("value cannot be null"); + } + } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java index c22a5492f..22a937b5c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,12 @@ package org.springframework.security.oauth2.server.authorization.token; import java.security.Principal; +import java.sql.Date; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -27,6 +30,8 @@ import org.mockito.ArgumentCaptor; import org.springframework.security.core.Authentication; +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -46,7 +51,6 @@ import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext; import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat; @@ -54,8 +58,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * Tests for {@link JwtGenerator}. @@ -67,7 +73,8 @@ public class JwtGeneratorTests { private JwtEncoder jwtEncoder; private OAuth2TokenCustomizer jwtCustomizer; private JwtGenerator jwtGenerator; - private AuthorizationServerContext authorizationServerContext; + private TestAuthorizationServerContext authorizationServerContext; + private SessionRegistry sessionRegistry; @BeforeEach public void setUp() { @@ -77,6 +84,8 @@ public void setUp() { this.jwtGenerator.setJwtCustomizer(this.jwtCustomizer); AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build(); this.authorizationServerContext = new TestAuthorizationServerContext(authorizationServerSettings, null); + this.sessionRegistry = mock(SessionRegistry.class); + this.authorizationServerContext.setSessionRegistry(this.sessionRegistry); } @Test @@ -185,6 +194,20 @@ public void generateWhenIdTokenTypeThenReturnJwt() { } private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) { + SessionInformation expectedSession = null; + if (OidcParameterNames.ID_TOKEN.equals(tokenContext.getTokenType().getValue())) { + List sessions = new ArrayList<>(); + sessions.add(new SessionInformation(tokenContext.getPrincipal().getPrincipal(), + "session3", Date.from(Instant.now()))); + sessions.add(new SessionInformation(tokenContext.getPrincipal().getPrincipal(), + "session2", Date.from(Instant.now().minus(1, ChronoUnit.HOURS)))); + sessions.add(new SessionInformation(tokenContext.getPrincipal().getPrincipal(), + "session1", Date.from(Instant.now().minus(2, ChronoUnit.HOURS)))); + expectedSession = sessions.get(0); // Most recent + when(this.sessionRegistry.getAllSessions(eq(tokenContext.getPrincipal().getPrincipal()), eq(false))) + .thenReturn(sessions); + } + this.jwtGenerator.generate(tokenContext); ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); @@ -238,6 +261,8 @@ private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) { OAuth2AuthorizationRequest.class.getName()); String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE); assertThat(jwtClaimsSet.getClaim(IdTokenClaimNames.NONCE)).isEqualTo(nonce); + assertThat(jwtClaimsSet.getClaim("sid")).isEqualTo(expectedSession.getSessionId()); + assertThat(jwtClaimsSet.getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(expectedSession.getLastRequest()); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index b740460db..09fe2afae 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -58,6 +58,7 @@ import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.WebAuthenticationDetails; +import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy; import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -151,6 +152,13 @@ public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentExcep .hasMessage("authenticationFailureHandler cannot be null"); } + @Test + public void setSessionAuthenticationStrategyWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setSessionAuthenticationStrategy(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("sessionAuthenticationStrategy cannot be null"); + } + @Test public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception { String requestUri = "/path"; @@ -383,6 +391,31 @@ public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exce verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException)); } + @Test + public void doFilterWhenCustomSessionAuthenticationStrategyThenUsed() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = + new OAuth2AuthorizationCodeRequestAuthenticationToken( + AUTHORIZATION_URI, registeredClient.getClientId(), principal, this.authorizationCode, + registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes()); + authorizationCodeRequestAuthenticationResult.setAuthenticated(true); + when(this.authenticationManager.authenticate(any())) + .thenReturn(authorizationCodeRequestAuthenticationResult); + + SessionAuthenticationStrategy sessionAuthenticationStrategy = mock(SessionAuthenticationStrategy.class); + this.filter.setSessionAuthenticationStrategy(sessionAuthenticationStrategy); + + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authenticationManager).authenticate(any()); + verifyNoInteractions(filterChain); + verify(sessionAuthenticationStrategy).onAuthentication(same(authorizationCodeRequestAuthenticationResult), any(), any()); + } + @Test public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); diff --git a/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql b/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql index 64d3a4872..5eb82db39 100644 --- a/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql +++ b/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql @@ -8,6 +8,7 @@ CREATE TABLE oauth2RegisteredClient ( clientAuthenticationMethods varchar(1000) NOT NULL, authorizationGrantTypes varchar(1000) NOT NULL, redirectUris varchar(1000) DEFAULT NULL, + postLogoutRedirectUris varchar(1000) DEFAULT NULL, scopes varchar(1000) NOT NULL, clientSettings varchar(2000) NOT NULL, tokenSettings varchar(2000) NOT NULL, diff --git a/samples/custom-consent-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java b/samples/custom-consent-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java index a2cf209b6..c22151d21 100644 --- a/samples/custom-consent-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java +++ b/samples/custom-consent-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -94,6 +94,7 @@ public RegisteredClientRepository registeredClientRepository() { .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .redirectUri("http://127.0.0.1:8080/login/oauth2/code/messaging-client-oidc") .redirectUri("http://127.0.0.1:8080/authorized") + .postLogoutRedirectUri("http://127.0.0.1:8080/index") .scope(OidcScopes.OPENID) .scope(OidcScopes.PROFILE) .scope("message.read") diff --git a/samples/custom-consent-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java b/samples/custom-consent-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java index 7266901af..b4370c024 100644 --- a/samples/custom-consent-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java +++ b/samples/custom-consent-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,11 +19,14 @@ import org.springframework.context.annotation.Configuration; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.core.session.SessionRegistry; +import org.springframework.security.core.session.SessionRegistryImpl; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.session.HttpSessionEventPublisher; import static org.springframework.security.config.Customizer.withDefaults; @@ -58,4 +61,14 @@ UserDetailsService users() { } // @formatter:on + @Bean + SessionRegistry sessionRegistry() { + return new SessionRegistryImpl(); + } + + @Bean + HttpSessionEventPublisher httpSessionEventPublisher() { + return new HttpSessionEventPublisher(); + } + } diff --git a/samples/default-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java b/samples/default-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java index f1d9b7e4f..9abaa8e7b 100644 --- a/samples/default-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java +++ b/samples/default-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -88,6 +88,7 @@ public RegisteredClientRepository registeredClientRepository(JdbcTemplate jdbcTe .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .redirectUri("http://127.0.0.1:8080/login/oauth2/code/messaging-client-oidc") .redirectUri("http://127.0.0.1:8080/authorized") + .postLogoutRedirectUri("http://127.0.0.1:8080/index") .scope(OidcScopes.OPENID) .scope(OidcScopes.PROFILE) .scope("message.read") diff --git a/samples/default-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java b/samples/default-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java index 36c5f7f67..cb7bd1b5b 100644 --- a/samples/default-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java +++ b/samples/default-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,11 +19,14 @@ import org.springframework.context.annotation.Configuration; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.core.session.SessionRegistry; +import org.springframework.security.core.session.SessionRegistryImpl; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.session.HttpSessionEventPublisher; import static org.springframework.security.config.Customizer.withDefaults; @@ -59,4 +62,14 @@ UserDetailsService users() { } // @formatter:on + @Bean + SessionRegistry sessionRegistry() { + return new SessionRegistryImpl(); + } + + @Bean + HttpSessionEventPublisher httpSessionEventPublisher() { + return new HttpSessionEventPublisher(); + } + } diff --git a/samples/federated-identity-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java b/samples/federated-identity-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java index f8df51481..2dd895edd 100644 --- a/samples/federated-identity-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java +++ b/samples/federated-identity-authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -90,6 +90,7 @@ public RegisteredClientRepository registeredClientRepository(JdbcTemplate jdbcTe .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .redirectUri("http://127.0.0.1:8080/login/oauth2/code/messaging-client-oidc") .redirectUri("http://127.0.0.1:8080/authorized") + .postLogoutRedirectUri("http://127.0.0.1:8080/index") .scope(OidcScopes.OPENID) .scope(OidcScopes.PROFILE) .scope("message.read") diff --git a/samples/federated-identity-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java b/samples/federated-identity-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java index 99706e81d..f4ba86745 100644 --- a/samples/federated-identity-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java +++ b/samples/federated-identity-authorizationserver/src/main/java/sample/config/DefaultSecurityConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,11 +23,14 @@ import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; +import org.springframework.security.core.session.SessionRegistry; +import org.springframework.security.core.session.SessionRegistryImpl; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.provisioning.InMemoryUserDetailsManager; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.session.HttpSessionEventPublisher; /** * @author Steve Riesenberg @@ -66,4 +69,14 @@ public UserDetailsService users() { } // @formatter:on + @Bean + public SessionRegistry sessionRegistry() { + return new SessionRegistryImpl(); + } + + @Bean + public HttpSessionEventPublisher httpSessionEventPublisher() { + return new HttpSessionEventPublisher(); + } + } diff --git a/samples/messages-client/src/main/java/sample/config/SecurityConfig.java b/samples/messages-client/src/main/java/sample/config/SecurityConfig.java index 6532782fb..29ec525fc 100644 --- a/samples/messages-client/src/main/java/sample/config/SecurityConfig.java +++ b/samples/messages-client/src/main/java/sample/config/SecurityConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,12 +15,16 @@ */ package sample.config; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityCustomizer; +import org.springframework.security.oauth2.client.oidc.web.logout.OidcClientInitiatedLogoutSuccessHandler; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.authentication.logout.LogoutSuccessHandler; import static org.springframework.security.config.Customizer.withDefaults; @@ -32,6 +36,9 @@ @Configuration(proxyBeanMethods = false) public class SecurityConfig { + @Autowired + private ClientRegistrationRepository clientRegistrationRepository; + @Bean WebSecurityCustomizer webSecurityCustomizer() { return (web) -> web.ignoring().requestMatchers("/webjars/**"); @@ -46,9 +53,22 @@ SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { ) .oauth2Login(oauth2Login -> oauth2Login.loginPage("/oauth2/authorization/messaging-client-oidc")) - .oauth2Client(withDefaults()); + .oauth2Client(withDefaults()) + .logout(logout -> + logout.logoutSuccessHandler(oidcLogoutSuccessHandler())); return http.build(); } // @formatter:on + private LogoutSuccessHandler oidcLogoutSuccessHandler() { + OidcClientInitiatedLogoutSuccessHandler oidcLogoutSuccessHandler = + new OidcClientInitiatedLogoutSuccessHandler(this.clientRegistrationRepository); + + // Set the location that the End-User's User Agent will be redirected to + // after the logout has been performed at the Provider + oidcLogoutSuccessHandler.setPostLogoutRedirectUri("{baseUrl}/index"); + + return oidcLogoutSuccessHandler; + } + } diff --git a/samples/messages-client/src/main/resources/templates/index.html b/samples/messages-client/src/main/resources/templates/index.html index edf9d32be..23085c1d4 100644 --- a/samples/messages-client/src/main/resources/templates/index.html +++ b/samples/messages-client/src/main/resources/templates/index.html @@ -1,5 +1,5 @@ - + Spring Security OAuth 2.0 Sample @@ -10,10 +10,11 @@

From 25f4e4e891d52f171878f372166dc18e74700719 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 16 Feb 2023 05:10:15 -0500 Subject: [PATCH 2/5] Add tests for Oidc Logout implementations --- .../OidcLogoutAuthenticationProvider.java | 21 +- .../oidc/web/OidcLogoutEndpointFilter.java | 3 +- .../OidcLogoutAuthenticationConverter.java | 17 +- .../annotation/web/configurers/OidcTests.java | 58 ++- ...OidcLogoutAuthenticationProviderTests.java | 420 ++++++++++++++++++ .../OidcLogoutAuthenticationTokenTests.java | 109 +++++ .../web/OidcLogoutEndpointFilterTests.java | 363 +++++++++++++++ 7 files changed, 979 insertions(+), 12 deletions(-) create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationTokenTests.java create mode 100644 oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilterTests.java diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java index ce0259ed7..b921f0feb 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java @@ -27,7 +27,10 @@ import org.springframework.security.core.session.SessionInformation; import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; @@ -78,7 +81,7 @@ public Authentication authenticate(Authentication authentication) throws Authent OAuth2Authorization authorization = this.authorizationService.findByToken( oidcLogoutAuthentication.getIdToken(), ID_TOKEN_TOKEN_TYPE); if (authorization == null) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + throwError(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint"); } RegisteredClient registeredClient = this.registeredClientRepository.findById( @@ -94,15 +97,15 @@ public Authentication authenticate(Authentication authentication) throws Authent List audClaim = idToken.getAudience(); if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD); } if (StringUtils.hasText(oidcLogoutAuthentication.getClientId()) && !oidcLogoutAuthentication.getClientId().equals(registeredClient.getClientId())) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + throwError(OAuth2ErrorCodes.INVALID_TOKEN, OAuth2ParameterNames.CLIENT_ID); } if (StringUtils.hasText(oidcLogoutAuthentication.getPostLogoutRedirectUri()) && !registeredClient.getPostLogoutRedirectUris().contains(oidcLogoutAuthentication.getPostLogoutRedirectUri())) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "post_logout_redirect_uri"); } if (this.logger.isTraceEnabled()) { @@ -120,7 +123,7 @@ public Authentication authenticate(Authentication authentication) throws Authent String sidClaim = idToken.getClaim("sid"); if (!StringUtils.hasText(sidClaim) || !sidClaim.equals(sessionInformation.getSessionId())) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + throwError(OAuth2ErrorCodes.INVALID_TOKEN, "sid"); } } } @@ -160,4 +163,12 @@ private static SessionInformation findSessionInformation(Authentication principa return sessionInformation; } + private static void throwError(String errorCode, String parameterName) { + OAuth2Error error = new OAuth2Error( + errorCode, + "OpenID Connect 1.0 Logout Request Parameter: " + parameterName, + "https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling"); + throw new OAuth2AuthenticationException(error); + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java index ff8d9e5ec..6f77357bf 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java @@ -214,8 +214,7 @@ private void sendErrorResponse(HttpServletRequest request, HttpServletResponse r AuthenticationException exception) throws IOException { OAuth2Error error = ((OAuth2AuthenticationException) exception).getError(); - response.sendError(HttpStatus.BAD_REQUEST.value(), - "OpenID Connect 1.0 RP-Initiated Logout Error: " + error.toString()); + response.sendError(HttpStatus.BAD_REQUEST.value(), error.toString()); } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java index 8739fe732..481f8a5a4 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcLogoutAuthenticationConverter.java @@ -25,6 +25,7 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; @@ -56,7 +57,7 @@ public Authentication convert(HttpServletRequest request) { String idTokenHint = request.getParameter("id_token_hint"); if (!StringUtils.hasText(idTokenHint) || request.getParameterValues("id_token_hint").length != 1) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "id_token_hint"); } Authentication principal = SecurityContextHolder.getContext().getAuthentication(); @@ -74,21 +75,21 @@ public Authentication convert(HttpServletRequest request) { String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID); if (StringUtils.hasText(clientId) && parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); } // post_logout_redirect_uri (OPTIONAL) String postLogoutRedirectUri = parameters.getFirst("post_logout_redirect_uri"); if (StringUtils.hasText(postLogoutRedirectUri) && parameters.get("post_logout_redirect_uri").size() != 1) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + throwError(OAuth2ErrorCodes.INVALID_REQUEST, "post_logout_redirect_uri"); } // state (OPTIONAL) String state = parameters.getFirst(OAuth2ParameterNames.STATE); if (StringUtils.hasText(state) && parameters.get(OAuth2ParameterNames.STATE).size() != 1) { - throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE); } return new OidcLogoutAuthenticationToken(idTokenHint, principal, @@ -108,4 +109,12 @@ private static MultiValueMap getParameters(HttpServletRequest re return parameters; } + private static void throwError(String errorCode, String parameterName) { + OAuth2Error error = new OAuth2Error( + errorCode, + "OpenID Connect 1.0 Logout Request Parameter: " + parameterName, + "https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling"); + throw new OAuth2AuthenticationException(error); + } + } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java index 00a519df8..80e9f4a52 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java @@ -47,6 +47,7 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; @@ -123,6 +124,7 @@ public class OidcTests { private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize"; private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token"; + private static final String DEFAULT_OIDC_LOGOUT_ENDPOINT_URI = "/connect/logout"; private static final String AUTHORITIES_CLAIM = "authorities"; private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private static EmbeddedDatabase db; @@ -216,8 +218,9 @@ public void requestWhenAuthenticationRequestThenTokenResponseIncludesIdToken() t servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus())); OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); - // Assert user authorities was propagated as claim in ID Token Jwt idToken = this.jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN)); + + // Assert user authorities was propagated as claim in ID Token List authoritiesClaim = idToken.getClaim(AUTHORITIES_CLAIM); Authentication principal = authorization.getAttribute(Principal.class.getName()); Set userAuthorities = new HashSet<>(); @@ -225,6 +228,59 @@ public void requestWhenAuthenticationRequestThenTokenResponseIncludesIdToken() t userAuthorities.add(authority.getAuthority()); } assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities); + + // Assert sid claim was added in ID Token + assertThat(idToken.getClaim("sid")).isNotNull(); + } + + @Test + public void requestWhenLogoutRequestThenLogout() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); + this.registeredClientRepository.save(registeredClient); + + // Login + MultiValueMap authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient); + MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI) + .params(authorizationRequestParameters) + .with(user("user"))) + .andExpect(status().is3xxRedirection()) + .andReturn(); + + MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession(); + assertThat(session.isNew()).isTrue(); + + String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); + String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); + OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); + + // Get ID Token + mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI) + .params(getTokenRequestParameters(registeredClient, authorization)) + .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( + registeredClient.getClientId(), registeredClient.getClientSecret())) + .session(session)) + .andExpect(status().isOk()) + .andReturn(); + + MockHttpServletResponse servletResponse = mvcResult.getResponse(); + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus())); + OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); + + String idToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN); + + // Logout + mvcResult = this.mvc.perform(post(DEFAULT_OIDC_LOGOUT_ENDPOINT_URI) + .param("id_token_hint", idToken) + .session(session)) + .andExpect(status().is3xxRedirection()) + .andReturn(); + redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); + + assertThat(redirectedUrl).matches("/"); + assertThat(session.isInvalid()).isTrue(); } @Test diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java new file mode 100644 index 000000000..f1bfc6a03 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java @@ -0,0 +1,420 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.Date; +import java.util.List; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.core.session.SessionRegistry; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OidcLogoutAuthenticationProvider}. + * + * @author Joe Grandja + */ +public class OidcLogoutAuthenticationProviderTests { + private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN); + private RegisteredClientRepository registeredClientRepository; + private OAuth2AuthorizationService authorizationService; + private SessionRegistry sessionRegistry; + private AuthorizationServerSettings authorizationServerSettings; + private OidcLogoutAuthenticationProvider authenticationProvider; + + @BeforeEach + public void setUp() { + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.sessionRegistry = mock(SessionRegistry.class); + this.authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build(); + TestAuthorizationServerContext authorizationServerContext = + new TestAuthorizationServerContext(this.authorizationServerSettings, null); + authorizationServerContext.setSessionRegistry(this.sessionRegistry); + AuthorizationServerContextHolder.setContext(authorizationServerContext); + this.authenticationProvider = new OidcLogoutAuthenticationProvider( + this.registeredClientRepository, this.authorizationService); + } + + @AfterEach + public void cleanup() { + AuthorizationServerContextHolder.resetContext(); + } + + @Test + public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationProvider(null, this.authorizationService)) + .withMessage("registeredClientRepository cannot be null"); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationProvider(this.registeredClientRepository, null)) + .withMessage("authorizationService cannot be null"); + } + + @Test + public void supportsWhenTypeOidcLogoutAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OidcLogoutAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenIdTokenNotFoundThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + "id-token", principal, "session-1", null, null, null); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + assertThat(error.getDescription()).contains("id_token_hint"); + }); + + verify(this.authorizationService).findByToken( + eq(authentication.getIdToken()), eq(ID_TOKEN_TOKEN_TYPE)); + } + + @Test + public void authenticateWhenMissingAudienceThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject(principal.getName()) + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(principal.getName()) + .token(idToken, + (metadata) -> metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())) + .build(); + when(this.authorizationService.findByToken(eq(idToken.getTokenValue()), eq(ID_TOKEN_TOKEN_TYPE))) + .thenReturn(authorization); + when(this.registeredClientRepository.findById(eq(authorization.getRegisteredClientId()))) + .thenReturn(registeredClient); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + idToken.getTokenValue(), principal, "session-1", null, null, null); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + assertThat(error.getDescription()).contains(IdTokenClaimNames.AUD); + }); + verify(this.authorizationService).findByToken( + eq(authentication.getIdToken()), eq(ID_TOKEN_TOKEN_TYPE)); + verify(this.registeredClientRepository).findById( + eq(authorization.getRegisteredClientId())); + } + + @Test + public void authenticateWhenInvalidAudienceThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject(principal.getName()) + .audience(Collections.singleton(registeredClient.getClientId() + "-invalid")) + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(principal.getName()) + .token(idToken, + (metadata) -> metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())) + .build(); + when(this.authorizationService.findByToken(eq(idToken.getTokenValue()), eq(ID_TOKEN_TOKEN_TYPE))) + .thenReturn(authorization); + when(this.registeredClientRepository.findById(eq(authorization.getRegisteredClientId()))) + .thenReturn(registeredClient); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + idToken.getTokenValue(), principal, "session-1", null, null, null); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + assertThat(error.getDescription()).contains(IdTokenClaimNames.AUD); + }); + verify(this.authorizationService).findByToken( + eq(authentication.getIdToken()), eq(ID_TOKEN_TOKEN_TYPE)); + verify(this.registeredClientRepository).findById( + eq(authorization.getRegisteredClientId())); + } + + @Test + public void authenticateWhenInvalidClientIdThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject(principal.getName()) + .audience(Collections.singleton(registeredClient.getClientId())) + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(principal.getName()) + .token(idToken, + (metadata) -> metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())) + .build(); + when(this.authorizationService.findByToken(eq(idToken.getTokenValue()), eq(ID_TOKEN_TOKEN_TYPE))) + .thenReturn(authorization); + when(this.registeredClientRepository.findById(eq(authorization.getRegisteredClientId()))) + .thenReturn(registeredClient); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + idToken.getTokenValue(), principal, "session-1", registeredClient.getClientId() + "-invalid", null, null); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + assertThat(error.getDescription()).contains(OAuth2ParameterNames.CLIENT_ID); + }); + verify(this.authorizationService).findByToken( + eq(authentication.getIdToken()), eq(ID_TOKEN_TOKEN_TYPE)); + verify(this.registeredClientRepository).findById( + eq(authorization.getRegisteredClientId())); + } + + @Test + public void authenticateWhenInvalidPostLogoutRedirectUriThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject(principal.getName()) + .audience(Collections.singleton(registeredClient.getClientId())) + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(principal.getName()) + .token(idToken, + (metadata) -> metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())) + .build(); + when(this.authorizationService.findByToken(eq(idToken.getTokenValue()), eq(ID_TOKEN_TOKEN_TYPE))) + .thenReturn(authorization); + when(this.registeredClientRepository.findById(eq(authorization.getRegisteredClientId()))) + .thenReturn(registeredClient); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + idToken.getTokenValue(), principal, "session-1", registeredClient.getClientId(), + "https://example.com/callback-1-invalid", null); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + assertThat(error.getDescription()).contains("post_logout_redirect_uri"); + }); + verify(this.authorizationService).findByToken( + eq(authentication.getIdToken()), eq(ID_TOKEN_TOKEN_TYPE)); + verify(this.registeredClientRepository).findById( + eq(authorization.getRegisteredClientId())); + } + + @Test + public void authenticateWhenMissingSidThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject(principal.getName()) + .audience(Collections.singleton(registeredClient.getClientId())) + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(principal.getName()) + .token(idToken, + (metadata) -> metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())) + .build(); + when(this.authorizationService.findByToken(eq(idToken.getTokenValue()), eq(ID_TOKEN_TOKEN_TYPE))) + .thenReturn(authorization); + when(this.registeredClientRepository.findById(eq(authorization.getRegisteredClientId()))) + .thenReturn(registeredClient); + + String sessionId = "session-1"; + List sessions = Collections.singletonList( + new SessionInformation(principal.getPrincipal(), sessionId, Date.from(Instant.now()))); + when(this.sessionRegistry.getAllSessions(eq(principal.getPrincipal()), eq(true))) + .thenReturn(sessions); + + principal.setAuthenticated(true); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + idToken.getTokenValue(), principal, sessionId, null, null, null); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + assertThat(error.getDescription()).contains("sid"); + }); + verify(this.authorizationService).findByToken( + eq(authentication.getIdToken()), eq(ID_TOKEN_TOKEN_TYPE)); + verify(this.registeredClientRepository).findById( + eq(authorization.getRegisteredClientId())); + } + + @Test + public void authenticateWhenInvalidSidThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject(principal.getName()) + .audience(Collections.singleton(registeredClient.getClientId())) + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .claim("sid", "other-session") + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(principal.getName()) + .token(idToken, + (metadata) -> metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())) + .build(); + when(this.authorizationService.findByToken(eq(idToken.getTokenValue()), eq(ID_TOKEN_TOKEN_TYPE))) + .thenReturn(authorization); + when(this.registeredClientRepository.findById(eq(authorization.getRegisteredClientId()))) + .thenReturn(registeredClient); + + String sessionId = "session-1"; + List sessions = Collections.singletonList( + new SessionInformation(principal.getPrincipal(), sessionId, Date.from(Instant.now()))); + when(this.sessionRegistry.getAllSessions(eq(principal.getPrincipal()), eq(true))) + .thenReturn(sessions); + + principal.setAuthenticated(true); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + idToken.getTokenValue(), principal, sessionId, null, null, null); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + assertThat(error.getDescription()).contains("sid"); + }); + verify(this.authorizationService).findByToken( + eq(authentication.getIdToken()), eq(ID_TOKEN_TOKEN_TYPE)); + verify(this.registeredClientRepository).findById( + eq(authorization.getRegisteredClientId())); + } + + @Test + public void authenticateWhenValidIdTokenThenAuthenticated() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + String sessionId = "session-1"; + OidcIdToken idToken = OidcIdToken.withTokenValue("id-token") + .issuer("https://provider.com") + .subject(principal.getName()) + .audience(Collections.singleton(registeredClient.getClientId())) + .issuedAt(Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .expiresAt(Instant.now().plusSeconds(60).truncatedTo(ChronoUnit.MILLIS)) + .claim("sid", sessionId) + .build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) + .principalName(principal.getName()) + .token(idToken, + (metadata) -> metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())) + .build(); + when(this.authorizationService.findByToken(eq(idToken.getTokenValue()), eq(ID_TOKEN_TOKEN_TYPE))) + .thenReturn(authorization); + when(this.registeredClientRepository.findById(eq(authorization.getRegisteredClientId()))) + .thenReturn(registeredClient); + + SessionInformation sessionInformation = new SessionInformation( + principal.getPrincipal(), sessionId, Date.from(Instant.now())); + List sessions = Collections.singletonList(sessionInformation); + when(this.sessionRegistry.getAllSessions(eq(principal.getPrincipal()), eq(true))) + .thenReturn(sessions); + + principal.setAuthenticated(true); + String postLogoutRedirectUri = registeredClient.getPostLogoutRedirectUris().toArray(new String[0])[0]; + String state = "state"; + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + idToken.getTokenValue(), principal, sessionId, registeredClient.getClientId(), postLogoutRedirectUri, state); + + OidcLogoutAuthenticationToken authenticationResult = + (OidcLogoutAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + verify(this.authorizationService).findByToken( + eq(authentication.getIdToken()), eq(ID_TOKEN_TOKEN_TYPE)); + verify(this.registeredClientRepository).findById( + eq(authorization.getRegisteredClientId())); + + assertThat(authenticationResult.getPrincipal()).isEqualTo(principal); + assertThat(authenticationResult.getCredentials().toString()).isEmpty(); + assertThat(authenticationResult.getIdToken()).isEqualTo(idToken.getTokenValue()); + assertThat(authenticationResult.getSessionId()).isEqualTo(sessionInformation.getSessionId()); + assertThat(authenticationResult.getSessionInformation()).isEqualTo(sessionInformation); + assertThat(authenticationResult.getClientId()).isEqualTo(registeredClient.getClientId()); + assertThat(authenticationResult.getPostLogoutRedirectUri()).isEqualTo(postLogoutRedirectUri); + assertThat(authenticationResult.getState()).isEqualTo(state); + assertThat(authenticationResult.isAuthenticated()).isTrue(); + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationTokenTests.java new file mode 100644 index 000000000..f6af09d2e --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationTokenTests.java @@ -0,0 +1,109 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.sql.Date; +import java.time.Instant; + +import org.junit.jupiter.api.Test; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.session.SessionInformation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link OidcLogoutAuthenticationToken}. + * + * @author Joe Grandja + */ +public class OidcLogoutAuthenticationTokenTests { + private final String idToken = "id-token"; + private final TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + private final String sessionId = "session-1"; + private final SessionInformation sessionInformation = new SessionInformation(this.principal, "session-2", Date.from(Instant.now())); + private final String clientId = "client-1"; + private final String postLogoutRedirectUri = "https://example.com/oidc-post-logout"; + private final String state = "state-1"; + + @Test + public void constructorWhenIdTokenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationToken( + null, this.principal, this.sessionId, this.clientId, this.postLogoutRedirectUri, this.state)) + .withMessage("idToken cannot be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationToken( + null, this.principal, this.sessionInformation, this.clientId, this.postLogoutRedirectUri, this.state)) + .withMessage("idToken cannot be empty"); + } + + @Test + public void constructorWhenIdTokenEmptyThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationToken( + "", this.principal, this.sessionId, this.clientId, this.postLogoutRedirectUri, this.state)) + .withMessage("idToken cannot be empty"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationToken( + "", this.principal, this.sessionInformation, this.clientId, this.postLogoutRedirectUri, this.state)) + .withMessage("idToken cannot be empty"); + } + + @Test + public void constructorWhenPrincipalNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationToken( + this.idToken, null, this.sessionId, this.clientId, this.postLogoutRedirectUri, this.state)) + .withMessage("principal cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationToken( + this.idToken, null, this.sessionInformation, this.clientId, this.postLogoutRedirectUri, this.state)) + .withMessage("principal cannot be null"); + } + + @Test + public void constructorWhenSessionIdProvidedThenCreated() { + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + this.idToken, this.principal, this.sessionId, this.clientId, this.postLogoutRedirectUri, this.state); + assertThat(authentication.getPrincipal()).isEqualTo(this.principal); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getIdToken()).isEqualTo(this.idToken); + assertThat(authentication.getSessionId()).isEqualTo(this.sessionId); + assertThat(authentication.getSessionInformation()).isNull(); + assertThat(authentication.getClientId()).isEqualTo(this.clientId); + assertThat(authentication.getPostLogoutRedirectUri()).isEqualTo(this.postLogoutRedirectUri); + assertThat(authentication.getState()).isEqualTo(this.state); + assertThat(authentication.isAuthenticated()).isFalse(); + } + + @Test + public void constructorWhenSessionInformationProvidedThenCreated() { + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + this.idToken, this.principal, this.sessionInformation, this.clientId, this.postLogoutRedirectUri, this.state); + assertThat(authentication.getPrincipal()).isEqualTo(this.principal); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getIdToken()).isEqualTo(this.idToken); + assertThat(authentication.getSessionId()).isEqualTo(this.sessionInformation.getSessionId()); + assertThat(authentication.getSessionInformation()).isEqualTo(this.sessionInformation); + assertThat(authentication.getClientId()).isEqualTo(this.clientId); + assertThat(authentication.getPostLogoutRedirectUri()).isEqualTo(this.postLogoutRedirectUri); + assertThat(authentication.getState()).isEqualTo(this.state); + assertThat(authentication.isAuthenticated()).isTrue(); + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilterTests.java new file mode 100644 index 000000000..e973b5efe --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilterTests.java @@ -0,0 +1,363 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.web; + +import java.time.Instant; +import java.util.Date; +import java.util.function.Consumer; + +import jakarta.servlet.FilterChain; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.http.HttpStatus; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.mock.web.MockHttpSession; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationServiceException; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcLogoutAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OidcLogoutEndpointFilter}. + * + * @author Joe Grandja + */ +public class OidcLogoutEndpointFilterTests { + private static final String DEFAULT_OIDC_LOGOUT_ENDPOINT_URI = "/connect/logout"; + private AuthenticationManager authenticationManager; + private OidcLogoutEndpointFilter filter; + private TestingAuthenticationToken principal; + + @BeforeEach + public void setUp() { + this.authenticationManager = mock(AuthenticationManager.class); + this.filter = new OidcLogoutEndpointFilter(this.authenticationManager); + this.principal = new TestingAuthenticationToken("principal", "credentials"); + this.principal.setAuthenticated(true); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(this.principal); + SecurityContextHolder.setContext(securityContext); + } + + @AfterEach + public void cleanup() { + SecurityContextHolder.clearContext(); + } + + @Test + public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OidcLogoutEndpointFilter(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationManager cannot be null"); + } + + @Test + public void constructorWhenLogoutEndpointUriNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OidcLogoutEndpointFilter(this.authenticationManager, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("logoutEndpointUri cannot be empty"); + } + + @Test + public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationConverter cannot be null"); + } + + @Test + public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationSuccessHandler cannot be null"); + } + + @Test + public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationFailureHandler cannot be null"); + } + + @Test + public void doFilterWhenNotLogoutRequestThenNotProcessed() throws Exception { + String requestUri = "/path"; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void doFilterWhenLogoutRequestMissingIdTokenHintThenInvalidRequestError() throws Exception { + doFilterWhenRequestInvalidParameterThenError( + createLogoutRequest(TestRegisteredClients.registeredClient().build()), + "id_token_hint", + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.removeParameter("id_token_hint")); + } + + @Test + public void doFilterWhenLogoutRequestMultipleIdTokenHintThenInvalidRequestError() throws Exception { + doFilterWhenRequestInvalidParameterThenError( + createLogoutRequest(TestRegisteredClients.registeredClient().build()), + "id_token_hint", + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter("id_token_hint", "id-token-2")); + } + + @Test + public void doFilterWhenLogoutRequestMultipleClientIdThenInvalidRequestError() throws Exception { + doFilterWhenRequestInvalidParameterThenError( + createLogoutRequest(TestRegisteredClients.registeredClient().build()), + OAuth2ParameterNames.CLIENT_ID, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2")); + } + + @Test + public void doFilterWhenLogoutRequestMultiplePostLogoutRedirectUriThenInvalidRequestError() throws Exception { + doFilterWhenRequestInvalidParameterThenError( + createLogoutRequest(TestRegisteredClients.registeredClient().build()), + "post_logout_redirect_uri", + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter("post_logout_redirect_uri", "https://example.com/callback-4")); + } + + @Test + public void doFilterWhenLogoutRequestMultipleStateThenInvalidRequestError() throws Exception { + doFilterWhenRequestInvalidParameterThenError( + createLogoutRequest(TestRegisteredClients.registeredClient().build()), + OAuth2ParameterNames.STATE, + OAuth2ErrorCodes.INVALID_REQUEST, + request -> request.addParameter(OAuth2ParameterNames.STATE, "state-2")); + } + + private void doFilterWhenRequestInvalidParameterThenError(MockHttpServletRequest request, + String parameterName, String errorCode, Consumer requestConsumer) throws Exception { + + requestConsumer.accept(request); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo("[" + errorCode + "] OpenID Connect 1.0 Logout Request Parameter: " + parameterName); + } + + @Test + public void doFilterWhenLogoutRequestAuthenticationExceptionThenErrorResponse() throws Exception { + OAuth2Error error = new OAuth2Error("errorCode", "errorDescription", "errorUri"); + when(this.authenticationManager.authenticate(any())) + .thenThrow(new OAuth2AuthenticationException(error)); + + MockHttpServletRequest request = createLogoutRequest(TestRegisteredClients.registeredClient().build()); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authenticationManager).authenticate(any()); + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); + assertThat(response.getErrorMessage()).isEqualTo(error.toString()); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.principal); + } + + @Test + public void doFilterWhenCustomAuthenticationConverterThenUsed() throws Exception { + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + "id-token", this.principal, (SessionInformation) null, null, null, null); + + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + when(authenticationConverter.convert(any())).thenReturn(authentication); + this.filter.setAuthenticationConverter(authenticationConverter); + + when(this.authenticationManager.authenticate(any())) + .thenReturn(authentication); + + MockHttpServletRequest request = createLogoutRequest(TestRegisteredClients.registeredClient().build()); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationConverter).convert(any()); + verify(this.authenticationManager).authenticate(any()); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception { + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + "id-token", this.principal, (SessionInformation) null, null, null, null); + + AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler); + + when(this.authenticationManager.authenticate(any())) + .thenReturn(authentication); + + MockHttpServletRequest request = createLogoutRequest(TestRegisteredClients.registeredClient().build()); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authenticationManager).authenticate(any()); + verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), same(authentication)); + verifyNoInteractions(filterChain); + } + + @Test + public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception { + AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class); + this.filter.setAuthenticationFailureHandler(authenticationFailureHandler); + + when(this.authenticationManager.authenticate(any())) + .thenThrow(new AuthenticationServiceException("AuthenticationServiceException")); + + MockHttpServletRequest request = createLogoutRequest(TestRegisteredClients.registeredClient().build()); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + ArgumentCaptor authenticationExceptionCaptor = ArgumentCaptor.forClass(AuthenticationException.class); + verify(this.authenticationManager).authenticate(any()); + verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), authenticationExceptionCaptor.capture()); + verifyNoInteractions(filterChain); + + assertThat(authenticationExceptionCaptor.getValue()) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) + .satisfies(error -> { + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + assertThat(error.getDescription()).contains("AuthenticationServiceException"); + }); + } + + @Test + public void doFilterWhenLogoutRequestAuthenticatedThenLogout() throws Exception { + MockHttpServletRequest request = createLogoutRequest(TestRegisteredClients.registeredClient().build()); + MockHttpSession session = (MockHttpSession) request.getSession(true); + + SessionInformation sessionInformation = new SessionInformation( + this.principal, session.getId(), Date.from(Instant.now())); + + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + "id-token", this.principal, sessionInformation, null, null, null); + + when(this.authenticationManager.authenticate(any())) + .thenReturn(authentication); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authenticationManager).authenticate(any()); + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).isEqualTo("/"); + assertThat(session.isInvalid()).isTrue(); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + @Test + public void doFilterWhenLogoutRequestAuthenticatedWithPostLogoutRedirectUriThenPostLogoutRedirect() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + MockHttpServletRequest request = createLogoutRequest(registeredClient); + MockHttpSession session = (MockHttpSession) request.getSession(true); + + SessionInformation sessionInformation = new SessionInformation( + this.principal, session.getId(), Date.from(Instant.now())); + + String postLogoutRedirectUri = registeredClient.getPostLogoutRedirectUris().iterator().next(); + String state = "state-1"; + OidcLogoutAuthenticationToken authentication = new OidcLogoutAuthenticationToken( + "id-token", this.principal, sessionInformation, + registeredClient.getClientId(), postLogoutRedirectUri, state); + + when(this.authenticationManager.authenticate(any())) + .thenReturn(authentication); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(this.authenticationManager).authenticate(any()); + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); + assertThat(response.getRedirectedUrl()).isEqualTo(postLogoutRedirectUri + "?state=" + state); + assertThat(session.isInvalid()).isTrue(); + assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull(); + } + + private static MockHttpServletRequest createLogoutRequest(RegisteredClient registeredClient) { + String requestUri = DEFAULT_OIDC_LOGOUT_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + + request.addParameter("id_token_hint", "id-token"); + request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); + request.addParameter("post_logout_redirect_uri", registeredClient.getPostLogoutRedirectUris().iterator().next()); + request.addParameter(OAuth2ParameterNames.STATE, "state"); + + return request; + } + +} From 3e66287b97367501cfceac9e75c9b397c696ae45 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Feb 2023 12:16:46 -0500 Subject: [PATCH 3/5] Polish state encoding --- .../oidc/web/OidcLogoutEndpointFilter.java | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java index 6f77357bf..0292d6173 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcLogoutEndpointFilter.java @@ -16,8 +16,7 @@ package org.springframework.security.oauth2.server.authorization.oidc.web; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; +import java.nio.charset.StandardCharsets; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -53,6 +52,7 @@ import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.util.UriComponentsBuilder; +import org.springframework.web.util.UriUtils; /** * A {@code Filter} that processes OpenID Connect 1.0 RP-Initiated Logout Requests. @@ -195,13 +195,11 @@ private void performLogout(HttpServletRequest request, HttpServletResponse respo .fromUriString(oidcLogoutAuthentication.getPostLogoutRedirectUri()); String redirectUri; if (StringUtils.hasText(oidcLogoutAuthentication.getState())) { - uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}"); - Map queryParams = new HashMap<>(); - queryParams.put(OAuth2ParameterNames.STATE, oidcLogoutAuthentication.getState()); - redirectUri = uriBuilder.build(queryParams).toString(); - } else { - redirectUri = uriBuilder.toUriString(); + uriBuilder.queryParam( + OAuth2ParameterNames.STATE, + UriUtils.encode(oidcLogoutAuthentication.getState(), StandardCharsets.UTF_8)); } + redirectUri = uriBuilder.build(true).toUriString(); // build(true) -> Components are explicitly encoded this.redirectStrategy.sendRedirect(request, response, redirectUri); } else { // Perform default redirect From a423607b384ae061a5270c8f4faf0be995178bb9 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Feb 2023 12:42:43 -0500 Subject: [PATCH 4/5] Polish logging in OidcLogoutAuthenticationProvider --- .../authentication/OidcLogoutAuthenticationProvider.java | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java index b921f0feb..38980252e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java @@ -84,11 +84,15 @@ public Authentication authenticate(Authentication authentication) throws Authent throwError(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint"); } + if (this.logger.isTraceEnabled()) { + this.logger.trace("Retrieved authorization with ID Token"); + } + RegisteredClient registeredClient = this.registeredClientRepository.findById( authorization.getRegisteredClientId()); if (this.logger.isTraceEnabled()) { - this.logger.trace("Retrieved authorization with ID Token"); + this.logger.trace("Retrieved registered client"); } OidcIdToken idToken = authorization.getToken(OidcIdToken.class).getToken(); From 3312fefd92f66df5ff76652029a479adbe4a88af Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 17 Feb 2023 16:36:14 -0500 Subject: [PATCH 5/5] Improved usages of SessionRegistry --- ...thorizationCodeAuthenticationProvider.java | 45 ++++++++++++++++++- .../AuthorizationServerContextFilter.java | 23 +--------- .../OAuth2AuthorizationServerConfigurer.java | 3 -- .../OAuth2TokenEndpointConfigurer.java | 3 ++ .../OidcLogoutEndpointConfigurer.java | 3 +- .../context/AuthorizationServerContext.java | 15 +------ .../OidcLogoutAuthenticationProvider.java | 13 +++--- .../authorization/token/JwtGenerator.java | 30 +++---------- ...zationCodeAuthenticationProviderTests.java | 36 +++++++++++++-- .../TestAuthorizationServerContext.java | 14 +----- ...OidcLogoutAuthenticationProviderTests.java | 14 ++++-- .../token/JwtGeneratorTests.java | 37 +++++---------- 12 files changed, 121 insertions(+), 115 deletions(-) diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index b795674c8..c9d9de0c6 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,11 @@ package org.springframework.security.oauth2.server.authorization.authentication; import java.security.Principal; +import java.util.ArrayList; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.apache.commons.logging.Log; @@ -27,6 +30,8 @@ import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClaimAccessor; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -52,6 +57,7 @@ import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; @@ -79,6 +85,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth private final Log logger = LogFactory.getLog(getClass()); private final OAuth2AuthorizationService authorizationService; private final OAuth2TokenGenerator tokenGenerator; + private SessionRegistry sessionRegistry; /** * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters. @@ -149,10 +156,12 @@ public Authentication authenticate(Authentication authentication) throws Authent this.logger.trace("Validated token request parameters"); } + Authentication principal = authorization.getAttribute(Principal.class.getName()); + // @formatter:off DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder() .registeredClient(registeredClient) - .principal(authorization.getAttribute(Principal.class.getName())) + .principal(principal) .authorizationServerContext(AuthorizationServerContextHolder.getContext()) .authorization(authorization) .authorizedScopes(authorization.getAuthorizedScopes()) @@ -210,6 +219,10 @@ public Authentication authenticate(Authentication authentication) throws Authent // ----- ID token ----- OidcIdToken idToken; if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) { + SessionInformation sessionInformation = getSessionInformation(principal); + if (sessionInformation != null) { + tokenContextBuilder.put(SessionInformation.class, sessionInformation); + } // @formatter:off tokenContext = tokenContextBuilder .tokenType(ID_TOKEN_TOKEN_TYPE) @@ -265,4 +278,32 @@ public boolean supports(Class authentication) { return OAuth2AuthorizationCodeAuthenticationToken.class.isAssignableFrom(authentication); } + /** + * Sets the {@link SessionRegistry} used to track OpenID Connect sessions. + * + * @param sessionRegistry the {@link SessionRegistry} used to track OpenID Connect sessions + * @since 1.1.0 + */ + public void setSessionRegistry(SessionRegistry sessionRegistry) { + Assert.notNull(sessionRegistry, "sessionRegistry cannot be null"); + this.sessionRegistry = sessionRegistry; + } + + private SessionInformation getSessionInformation(Authentication principal) { + SessionInformation sessionInformation = null; + if (this.sessionRegistry != null) { + List sessions = this.sessionRegistry.getAllSessions(principal.getPrincipal(), false); + if (!CollectionUtils.isEmpty(sessions)) { + sessionInformation = sessions.get(0); + if (sessions.size() > 1) { + // Get the most recent session + sessions = new ArrayList<>(sessions); + sessions.sort(Comparator.comparing(SessionInformation::getLastRequest)); + sessionInformation = sessions.get(sessions.size() - 1); + } + } + } + return sessionInformation; + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/AuthorizationServerContextFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/AuthorizationServerContextFilter.java index b08faee8d..aa6cee69f 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/AuthorizationServerContextFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/AuthorizationServerContextFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2023 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,8 +23,6 @@ import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import org.springframework.lang.Nullable; -import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext; import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; @@ -44,27 +42,21 @@ */ final class AuthorizationServerContextFilter extends OncePerRequestFilter { private final AuthorizationServerSettings authorizationServerSettings; - private SessionRegistry sessionRegistry; AuthorizationServerContextFilter(AuthorizationServerSettings authorizationServerSettings) { Assert.notNull(authorizationServerSettings, "authorizationServerSettings cannot be null"); this.authorizationServerSettings = authorizationServerSettings; } - void setSessionRegistry(SessionRegistry sessionRegistry) { - this.sessionRegistry = sessionRegistry; - } - @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { try { - DefaultAuthorizationServerContext authorizationServerContext = + AuthorizationServerContext authorizationServerContext = new DefaultAuthorizationServerContext( () -> resolveIssuer(this.authorizationServerSettings, request), this.authorizationServerSettings); - authorizationServerContext.setSessionRegistry(this.sessionRegistry); AuthorizationServerContextHolder.setContext(authorizationServerContext); filterChain.doFilter(request, response); } finally { @@ -92,7 +84,6 @@ private static String getContextPath(HttpServletRequest request) { private static final class DefaultAuthorizationServerContext implements AuthorizationServerContext { private final Supplier issuerSupplier; private final AuthorizationServerSettings authorizationServerSettings; - private SessionRegistry sessionRegistry; private DefaultAuthorizationServerContext(Supplier issuerSupplier, AuthorizationServerSettings authorizationServerSettings) { this.issuerSupplier = issuerSupplier; @@ -109,16 +100,6 @@ public AuthorizationServerSettings getAuthorizationServerSettings() { return this.authorizationServerSettings; } - @Nullable - @Override - public SessionRegistry getSessionRegistry() { - return this.sessionRegistry; - } - - private void setSessionRegistry(SessionRegistry sessionRegistry) { - this.sessionRegistry = sessionRegistry; - } - } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationServerConfigurer.java index 5edb80b01..19c2fd9eb 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationServerConfigurer.java @@ -304,9 +304,6 @@ public void configure(HttpSecurity httpSecurity) { AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); AuthorizationServerContextFilter authorizationServerContextFilter = new AuthorizationServerContextFilter(authorizationServerSettings); - if (isOidcEnabled()) { - authorizationServerContextFilter.setSessionRegistry(OAuth2ConfigurerUtils.getSessionRegistry(httpSecurity)); - } httpSecurity.addFilterAfter(postProcess(authorizationServerContextFilter), SecurityContextHolderFilter.class); JWKSource jwkSource = OAuth2ConfigurerUtils.getJwkSource(httpSecurity); diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenEndpointConfigurer.java index c9a9cfa92..1fb5813e7 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2TokenEndpointConfigurer.java @@ -26,6 +26,7 @@ import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Token; @@ -216,9 +217,11 @@ private static List createDefaultAuthenticationProviders OAuth2AuthorizationService authorizationService = OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity); OAuth2TokenGenerator tokenGenerator = OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity); + SessionRegistry sessionRegistry = OAuth2ConfigurerUtils.getSessionRegistry(httpSecurity); OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(authorizationService, tokenGenerator); + authorizationCodeAuthenticationProvider.setSessionRegistry(sessionRegistry); authenticationProviders.add(authorizationCodeAuthenticationProvider); OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider = diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcLogoutEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcLogoutEndpointConfigurer.java index 04eb0a256..92580b6ec 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcLogoutEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcLogoutEndpointConfigurer.java @@ -209,7 +209,8 @@ private static List createDefaultAuthenticationProviders OidcLogoutAuthenticationProvider oidcLogoutAuthenticationProvider = new OidcLogoutAuthenticationProvider( OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), - OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); + OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity), + OAuth2ConfigurerUtils.getSessionRegistry(httpSecurity)); authenticationProviders.add(oidcLogoutAuthenticationProvider); return authenticationProviders; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/context/AuthorizationServerContext.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/context/AuthorizationServerContext.java index 2888b3e8a..a12ef305a 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/context/AuthorizationServerContext.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/context/AuthorizationServerContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2023 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,6 @@ */ package org.springframework.security.oauth2.server.authorization.context; -import org.springframework.lang.Nullable; -import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; /** @@ -43,15 +41,4 @@ public interface AuthorizationServerContext { */ AuthorizationServerSettings getAuthorizationServerSettings(); - /** - * Returns the {@link SessionRegistry} used to track OpenID Connect sessions or {@code null} if OpenID Connect is disabled. - * - * @return the {@link SessionRegistry} used to track OpenID Connect sessions or {@code null} if OpenID Connect is disabled - * @since 1.1.0 - */ - @Nullable - default SessionRegistry getSessionRegistry() { - return null; - } - } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java index 38980252e..7de465dd7 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProvider.java @@ -38,7 +38,6 @@ import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; -import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -50,6 +49,7 @@ * @since 1.1.0 * @see RegisteredClientRepository * @see OAuth2AuthorizationService + * @see SessionRegistry * @see 2. RP-Initiated Logout */ public final class OidcLogoutAuthenticationProvider implements AuthenticationProvider { @@ -58,19 +58,23 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro private final Log logger = LogFactory.getLog(getClass()); private final RegisteredClientRepository registeredClientRepository; private final OAuth2AuthorizationService authorizationService; + private final SessionRegistry sessionRegistry; /** * Constructs an {@code OidcLogoutAuthenticationProvider} using the provided parameters. * * @param registeredClientRepository the repository of registered clients * @param authorizationService the authorization service + * @param sessionRegistry the {@link SessionRegistry} used to track OpenID Connect sessions */ public OidcLogoutAuthenticationProvider(RegisteredClientRepository registeredClientRepository, - OAuth2AuthorizationService authorizationService) { + OAuth2AuthorizationService authorizationService, SessionRegistry sessionRegistry) { Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); Assert.notNull(authorizationService, "authorizationService cannot be null"); + Assert.notNull(sessionRegistry, "sessionRegistry cannot be null"); this.registeredClientRepository = registeredClientRepository; this.authorizationService = authorizationService; + this.sessionRegistry = sessionRegistry; } @Override @@ -152,9 +156,8 @@ private static boolean isPrincipalAuthenticated(Authentication principal) { principal.isAuthenticated(); } - private static SessionInformation findSessionInformation(Authentication principal, String sessionId) { - SessionRegistry sessionRegistry = AuthorizationServerContextHolder.getContext().getSessionRegistry(); - List sessions = sessionRegistry.getAllSessions(principal.getPrincipal(), true); + private SessionInformation findSessionInformation(Authentication principal, String sessionId) { + List sessions = this.sessionRegistry.getAllSessions(principal.getPrincipal(), true); SessionInformation sessionInformation = null; if (!CollectionUtils.isEmpty(sessions)) { for (SessionInformation session : sessions) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java index 98895c80a..66ed255b4 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java @@ -17,14 +17,10 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.ArrayList; import java.util.Collections; -import java.util.Comparator; -import java.util.List; import org.springframework.lang.Nullable; import org.springframework.security.core.session.SessionInformation; -import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -131,7 +127,7 @@ public Jwt generate(OAuth2TokenContext context) { claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce); } } - SessionInformation sessionInformation = getSessionInformation(context); + SessionInformation sessionInformation = context.get(SessionInformation.class); if (sessionInformation != null) { claimsBuilder.claim("sid", sessionInformation.getSessionId()); claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, sessionInformation.getLastRequest()); @@ -156,6 +152,12 @@ public Jwt generate(OAuth2TokenContext context) { if (context.getAuthorizationGrant() != null) { jwtContextBuilder.authorizationGrant(context.getAuthorizationGrant()); } + if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) { + SessionInformation sessionInformation = context.get(SessionInformation.class); + if (sessionInformation != null) { + jwtContextBuilder.put(SessionInformation.class, sessionInformation); + } + } // @formatter:on JwtEncodingContext jwtContext = jwtContextBuilder.build(); @@ -170,24 +172,6 @@ public Jwt generate(OAuth2TokenContext context) { return jwt; } - private static SessionInformation getSessionInformation(OAuth2TokenContext context) { - SessionInformation sessionInformation = null; - if (context.getAuthorizationServerContext().getSessionRegistry() != null) { - SessionRegistry sessionRegistry = context.getAuthorizationServerContext().getSessionRegistry(); - List sessions = sessionRegistry.getAllSessions(context.getPrincipal().getPrincipal(), false); - if (!CollectionUtils.isEmpty(sessions)) { - sessionInformation = sessions.get(0); - if (sessions.size() > 1) { - // Get the most recent session - sessions = new ArrayList<>(sessions); - sessions.sort(Comparator.comparing(SessionInformation::getLastRequest)); - sessionInformation = sessions.get(sessions.size() - 1); - } - } - } - return sessionInformation; - } - /** * Sets the {@link OAuth2TokenCustomizer} that customizes the * {@link JwtEncodingContext#getJwsHeader() JWS headers} and/or diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index 0f010f1a1..d0bc3613e 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,11 @@ import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Date; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; @@ -31,6 +34,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.session.SessionInformation; +import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -95,6 +100,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests { private OAuth2TokenCustomizer jwtCustomizer; private OAuth2TokenCustomizer accessTokenCustomizer; private OAuth2TokenGenerator tokenGenerator; + private SessionRegistry sessionRegistry; private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider; @BeforeEach @@ -116,8 +122,10 @@ public OAuth2Token generate(OAuth2TokenContext context) { return delegatingTokenGenerator.generate(context); } }); + this.sessionRegistry = mock(SessionRegistry.class); this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider( this.authorizationService, this.tokenGenerator); + this.authenticationProvider.setSessionRegistry(this.sessionRegistry); AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build(); AuthorizationServerContextHolder.setContext(new TestAuthorizationServerContext(authorizationServerSettings, null)); } @@ -146,6 +154,13 @@ public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturn assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue(); } + @Test + public void setSessionRegistryWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authenticationProvider.setSessionRegistry(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("sessionRegistry cannot be null"); + } + @Test public void authenticateWhenClientPrincipalNotOAuth2ClientAuthenticationTokenThenThrowOAuth2AuthenticationException() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); @@ -456,6 +471,19 @@ public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken() when(this.jwtEncoder.encode(any())).thenReturn(createJwt()); + Authentication principal = authorization.getAttribute(Principal.class.getName()); + + List sessions = new ArrayList<>(); + sessions.add(new SessionInformation(principal.getPrincipal(), + "session3", Date.from(Instant.now()))); + sessions.add(new SessionInformation(principal.getPrincipal(), + "session2", Date.from(Instant.now().minus(1, ChronoUnit.HOURS)))); + sessions.add(new SessionInformation(principal.getPrincipal(), + "session1", Date.from(Instant.now().minus(2, ChronoUnit.HOURS)))); + SessionInformation expectedSession = sessions.get(0); // Most recent + when(this.sessionRegistry.getAllSessions(eq(principal.getPrincipal()), eq(false))) + .thenReturn(sessions); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); @@ -464,7 +492,7 @@ public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken() // Access Token context JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0); assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient); - assertThat(accessTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); + assertThat(accessTokenContext.getPrincipal()).isEqualTo(principal); assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization); assertThat(accessTokenContext.getAuthorization().getAccessToken()).isNull(); assertThat(accessTokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes()); @@ -480,13 +508,15 @@ public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken() // ID Token context JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1); assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient); - assertThat(idTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); + assertThat(idTokenContext.getPrincipal()).isEqualTo(principal); assertThat(idTokenContext.getAuthorization()).isNotEqualTo(authorization); assertThat(idTokenContext.getAuthorization().getAccessToken()).isNotNull(); assertThat(idTokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes()); assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN); assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); assertThat(idTokenContext.getAuthorizationGrant()).isEqualTo(authentication); + SessionInformation sessionInformation = idTokenContext.get(SessionInformation.class); + assertThat(sessionInformation).isNotNull().isSameAs(expectedSession); assertThat(idTokenContext.getJwsHeader()).isNotNull(); assertThat(idTokenContext.getClaims()).isNotNull(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/context/TestAuthorizationServerContext.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/context/TestAuthorizationServerContext.java index 142bba5ca..19e8299c7 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/context/TestAuthorizationServerContext.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/context/TestAuthorizationServerContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2023 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,6 @@ import java.util.function.Supplier; import org.springframework.lang.Nullable; -import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; /** @@ -27,7 +26,6 @@ public class TestAuthorizationServerContext implements AuthorizationServerContext { private final AuthorizationServerSettings authorizationServerSettings; private final Supplier issuerSupplier; - private SessionRegistry sessionRegistry; public TestAuthorizationServerContext(AuthorizationServerSettings authorizationServerSettings, @Nullable Supplier issuerSupplier) { this.authorizationServerSettings = authorizationServerSettings; @@ -46,14 +44,4 @@ public AuthorizationServerSettings getAuthorizationServerSettings() { return this.authorizationServerSettings; } - @Nullable - @Override - public SessionRegistry getSessionRegistry() { - return this.sessionRegistry; - } - - public void setSessionRegistry(SessionRegistry sessionRegistry) { - this.sessionRegistry = sessionRegistry; - } - } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java index f1bfc6a03..1bb0db18b 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcLogoutAuthenticationProviderTests.java @@ -74,10 +74,9 @@ public void setUp() { this.authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build(); TestAuthorizationServerContext authorizationServerContext = new TestAuthorizationServerContext(this.authorizationServerSettings, null); - authorizationServerContext.setSessionRegistry(this.sessionRegistry); AuthorizationServerContextHolder.setContext(authorizationServerContext); this.authenticationProvider = new OidcLogoutAuthenticationProvider( - this.registeredClientRepository, this.authorizationService); + this.registeredClientRepository, this.authorizationService, this.sessionRegistry); } @AfterEach @@ -88,17 +87,24 @@ public void cleanup() { @Test public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException() - .isThrownBy(() -> new OidcLogoutAuthenticationProvider(null, this.authorizationService)) + .isThrownBy(() -> new OidcLogoutAuthenticationProvider(null, this.authorizationService, this.sessionRegistry)) .withMessage("registeredClientRepository cannot be null"); } @Test public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException() - .isThrownBy(() -> new OidcLogoutAuthenticationProvider(this.registeredClientRepository, null)) + .isThrownBy(() -> new OidcLogoutAuthenticationProvider(this.registeredClientRepository, null, this.sessionRegistry)) .withMessage("authorizationService cannot be null"); } + @Test + public void constructorWhenSessionRegistryNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcLogoutAuthenticationProvider(this.registeredClientRepository, this.authorizationService, null)) + .withMessage("sessionRegistry cannot be null"); + } + @Test public void supportsWhenTypeOidcLogoutAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OidcLogoutAuthenticationToken.class)).isTrue(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java index 22a937b5c..5d934686a 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java @@ -16,12 +16,10 @@ package org.springframework.security.oauth2.server.authorization.token; import java.security.Principal; -import java.sql.Date; import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.ArrayList; +import java.util.Date; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Set; @@ -31,7 +29,6 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.session.SessionInformation; -import org.springframework.security.core.session.SessionRegistry; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; @@ -58,10 +55,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; /** * Tests for {@link JwtGenerator}. @@ -74,7 +69,6 @@ public class JwtGeneratorTests { private OAuth2TokenCustomizer jwtCustomizer; private JwtGenerator jwtGenerator; private TestAuthorizationServerContext authorizationServerContext; - private SessionRegistry sessionRegistry; @BeforeEach public void setUp() { @@ -84,8 +78,6 @@ public void setUp() { this.jwtGenerator.setJwtCustomizer(this.jwtCustomizer); AuthorizationServerSettings authorizationServerSettings = AuthorizationServerSettings.builder().issuer("https://provider.com").build(); this.authorizationServerContext = new TestAuthorizationServerContext(authorizationServerSettings, null); - this.sessionRegistry = mock(SessionRegistry.class); - this.authorizationServerContext.setSessionRegistry(this.sessionRegistry); } @Test @@ -177,16 +169,21 @@ public void generateWhenIdTokenTypeThenReturnJwt() { OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null); + Authentication principal = authorization.getAttribute(Principal.class.getName()); + SessionInformation sessionInformation = new SessionInformation( + principal.getPrincipal(), "session1", Date.from(Instant.now().minus(2, ChronoUnit.HOURS))); + // @formatter:off OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder() .registeredClient(registeredClient) - .principal(authorization.getAttribute(Principal.class.getName())) + .principal(principal) .authorizationServerContext(this.authorizationServerContext) .authorization(authorization) .authorizedScopes(authorization.getAuthorizedScopes()) .tokenType(ID_TOKEN_TOKEN_TYPE) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .authorizationGrant(authentication) + .put(SessionInformation.class, sessionInformation) .build(); // @formatter:on @@ -194,20 +191,6 @@ public void generateWhenIdTokenTypeThenReturnJwt() { } private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) { - SessionInformation expectedSession = null; - if (OidcParameterNames.ID_TOKEN.equals(tokenContext.getTokenType().getValue())) { - List sessions = new ArrayList<>(); - sessions.add(new SessionInformation(tokenContext.getPrincipal().getPrincipal(), - "session3", Date.from(Instant.now()))); - sessions.add(new SessionInformation(tokenContext.getPrincipal().getPrincipal(), - "session2", Date.from(Instant.now().minus(1, ChronoUnit.HOURS)))); - sessions.add(new SessionInformation(tokenContext.getPrincipal().getPrincipal(), - "session1", Date.from(Instant.now().minus(2, ChronoUnit.HOURS)))); - expectedSession = sessions.get(0); // Most recent - when(this.sessionRegistry.getAllSessions(eq(tokenContext.getPrincipal().getPrincipal()), eq(false))) - .thenReturn(sessions); - } - this.jwtGenerator.generate(tokenContext); ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); @@ -261,8 +244,10 @@ private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) { OAuth2AuthorizationRequest.class.getName()); String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE); assertThat(jwtClaimsSet.getClaim(IdTokenClaimNames.NONCE)).isEqualTo(nonce); - assertThat(jwtClaimsSet.getClaim("sid")).isEqualTo(expectedSession.getSessionId()); - assertThat(jwtClaimsSet.getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(expectedSession.getLastRequest()); + + SessionInformation sessionInformation = tokenContext.get(SessionInformation.class); + assertThat(jwtClaimsSet.getClaim("sid")).isEqualTo(sessionInformation.getSessionId()); + assertThat(jwtClaimsSet.getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(sessionInformation.getLastRequest()); } }