From b84689860d9384db832784a81260fa4b6f70f5ea Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Wed, 4 Sep 2019 04:34:44 -0600 Subject: [PATCH 1/2] Grant Individual Authorities From Claims Fixes gh-7339 --- .../client/oidc/userinfo/OidcUserService.java | 27 +++-- .../userinfo/DefaultOAuth2UserService.java | 52 ++++++++- .../oidc/userinfo/OidcUserServiceTests.java | 100 ++++++++++++++++-- .../DefaultOAuth2UserServiceTests.java | 74 +++++++++++++ .../oauth2/core/oidc/TestOidcIdTokens.java | 39 +++++++ 5 files changed, 268 insertions(+), 24 deletions(-) create mode 100644 oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index 9644f6620cf..ec374fbffd1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -15,6 +15,17 @@ */ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; @@ -38,15 +49,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.function.Function; - /** * An implementation of an {@link OAuth2UserService} that supports OpenID Connect 1.0 Provider's. * @@ -94,6 +96,7 @@ public class OidcUserService implements OAuth2UserService oauth2UserAuthorities = Collections.emptyList(); if (this.shouldRetrieveUserInfo(userRequest)) { OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest); @@ -106,6 +109,7 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio claims = DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes()); } userInfo = new OidcUserInfo(claims); + oauth2UserAuthorities = oauth2User.getAuthorities(); // https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse @@ -127,8 +131,9 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio } } - Set authorities = Collections.singleton( - new OidcUserAuthority(userRequest.getIdToken(), userInfo)); + Set authorities = new LinkedHashSet<>(); + authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo)); + authorities.addAll(oauth2UserAuthorities); OidcUser user; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java index 8033b93127b..c274ebd816e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java @@ -15,13 +15,22 @@ */ package org.springframework.security.oauth2.client.userinfo; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClaimAccessor; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -35,10 +44,6 @@ import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import java.util.Collections; -import java.util.Map; -import java.util.Set; - /** * An implementation of an {@link OAuth2UserService} that supports standard OAuth 2.0 Provider's. *

@@ -66,6 +71,9 @@ public class DefaultOAuth2UserService implements OAuth2UserService> PARAMETERIZED_RESPONSE_TYPE = new ParameterizedTypeReference>() {}; + private static final Collection WELL_KNOWN_AUTHORITIES_CLAIM_NAMES = + Arrays.asList("scope", "scp"); + private Converter> requestEntityConverter = new OAuth2UserRequestEntityConverter(); private RestOperations restOperations; @@ -127,7 +135,11 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic } Map userAttributes = response.getBody(); - Set authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes)); + Set authorities = new LinkedHashSet<>(); + authorities.add(new OAuth2UserAuthority(userAttributes)); + for (String authority : getAuthorities(() -> userAttributes)) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName); } @@ -160,4 +172,34 @@ public final void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } + + private String getAuthoritiesClaimName(ClaimAccessor claims) { + for (String claimName : WELL_KNOWN_AUTHORITIES_CLAIM_NAMES) { + if (claims.containsClaim(claimName)) { + return claimName; + } + } + return null; + } + + private Collection getAuthorities(ClaimAccessor claims) { + String claimName = getAuthoritiesClaimName(claims); + + if (claimName == null) { + return Collections.emptyList(); + } + + Object authorities = claims.getClaim(claimName); + if (authorities instanceof String) { + if (StringUtils.hasText((String) authorities)) { + return Arrays.asList(((String) authorities).split(" ")); + } else { + return Collections.emptyList(); + } + } else if (authorities instanceof Collection) { + return (Collection) authorities; + } + + return Collections.emptyList(); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index aad464c0e0a..473666e70f1 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -15,6 +15,15 @@ */ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -23,12 +32,20 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; + +import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; +import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -39,20 +56,20 @@ import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; - -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.function.Function; +import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.nullable; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; +import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; /** * Tests for {@link OidcUserService}. @@ -481,6 +498,73 @@ public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + body.put("scope", "message:read message:write"); + OidcUserService userService = new OidcUserService(); + userService.setOauth2UserService(withMockResponse(body)); + OidcUserRequest request = new OidcUserRequest(clientRegistration(). + userInfoUri("uri").build(), scopes("profile"), idToken(body)); + OidcUser user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + body.put("scp", Arrays.asList("message:read", "message:write")); + OidcUserService userService = new OidcUserService(); + userService.setOauth2UserService(withMockResponse(body)); + OidcUserRequest request = new OidcUserRequest(clientRegistration(). + userInfoUri("uri").build(), scopes("profile"), idToken(body)); + OidcUser user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + body.put("authorities", Arrays.asList("message:read", "message:write")); + OidcUserService userService = new OidcUserService(); + userService.setOauth2UserService(withMockResponse(body)); + OidcUserRequest request = new OidcUserRequest(clientRegistration(). + userInfoUri("uri").build(), scopes("profile"), idToken(body)); + OidcUser user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); + } + + private DefaultOAuth2UserService withMockResponse(Map response) { + ResponseEntity> responseEntity = new ResponseEntity<>(response, HttpStatus.OK); + Converter> requestEntityConverter = mock(Converter.class); + RestOperations rest = mock(RestOperations.class); + when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class))) + .thenReturn(responseEntity); + DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + userService.setRequestEntityConverter(requestEntityConverter); + userService.setRestOperations(rest); + return userService; + } + private MockResponse jsonResponse(String json) { return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index 99a6718e075..346ac21dfce 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -15,6 +15,10 @@ */ package org.springframework.security.oauth2.client.userinfo; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; import java.util.concurrent.TimeUnit; import okhttp3.mockwebserver.MockResponse; @@ -26,18 +30,30 @@ import org.junit.Test; import org.junit.rules.ExpectedException; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; +import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.nullable; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; @@ -325,6 +341,64 @@ public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPos assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue()); } + @Test + public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("scope", "message:read message:write"); + DefaultOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); + OAuth2User user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("scp", Arrays.asList("message:read", "message:write")); + DefaultOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); + OAuth2User user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("authorities", Arrays.asList("message:read", "message:write")); + DefaultOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); + OAuth2User user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + } + + private DefaultOAuth2UserService withMockResponse(Map response) { + ResponseEntity> responseEntity = new ResponseEntity<>(response, HttpStatus.OK); + Converter> requestEntityConverter = mock(Converter.class); + RestOperations rest = mock(RestOperations.class); + when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class))) + .thenReturn(responseEntity); + DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + userService.setRequestEntityConverter(requestEntityConverter); + userService.setRestOperations(rest); + return userService; + } + private MockResponse jsonResponse(String json) { return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java new file mode 100644 index 00000000000..a99020ed267 --- /dev/null +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.core.oidc; + +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +/** + * Test {@link OidcIdToken}s + * + * @author Josh Cummings + */ +public class TestOidcIdTokens { + public static OidcIdToken idToken() { + return idToken(Collections.singletonMap("id", "id")); + } + + public static OidcIdToken idToken(Map claims) { + return new OidcIdToken("token", + Instant.now(), + Instant.now().plusSeconds(86400), + claims); + } +} From a916115a8bfcbc3ab90808afa3d5334a5db733ef Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Wed, 4 Sep 2019 11:59:35 -0600 Subject: [PATCH 2/2] Add Authorities from Access Token --- .../OidcReactiveOAuth2UserService.java | 6 ++ .../client/oidc/userinfo/OidcUserService.java | 11 +-- .../userinfo/DefaultOAuth2UserService.java | 41 +--------- .../DefaultReactiveOAuth2UserService.java | 6 ++ .../OidcReactiveOAuth2UserServiceTests.java | 69 +++++++++++++--- .../oidc/userinfo/OidcUserServiceTests.java | 59 ++------------ .../DefaultOAuth2UserServiceTests.java | 30 ++----- ...DefaultReactiveOAuth2UserServiceTests.java | 81 +++++++++++++++++-- 8 files changed, 167 insertions(+), 136 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index fadc5f8b922..5893d799d84 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -18,10 +18,12 @@ import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.converter.ClaimConversionService; @@ -99,6 +101,10 @@ public Mono loadUser(OidcUserRequest userRequest) throws OAuth2Authent OidcUserInfo userInfo = authority.getUserInfo(); Set authorities = new HashSet<>(); authorities.add(authority); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String scope : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); + } String userNameAttributeName = userRequest.getClientRegistration() .getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName(); if (StringUtils.hasText(userNameAttributeName)) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index ec374fbffd1..3fb35f5ada7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -17,8 +17,6 @@ import java.time.Instant; import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashSet; @@ -29,11 +27,13 @@ import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.converter.ClaimConversionService; @@ -96,7 +96,6 @@ public class OidcUserService implements OAuth2UserService oauth2UserAuthorities = Collections.emptyList(); if (this.shouldRetrieveUserInfo(userRequest)) { OAuth2User oauth2User = this.oauth2UserService.loadUser(userRequest); @@ -109,7 +108,6 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio claims = DEFAULT_CLAIM_TYPE_CONVERTER.convert(oauth2User.getAttributes()); } userInfo = new OidcUserInfo(claims); - oauth2UserAuthorities = oauth2User.getAuthorities(); // https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse @@ -133,7 +131,10 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio Set authorities = new LinkedHashSet<>(); authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo)); - authorities.addAll(oauth2UserAuthorities); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String authority : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } OidcUser user; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java index c274ebd816e..0ca84ca6d44 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java @@ -15,9 +15,6 @@ */ package org.springframework.security.oauth2.client.userinfo; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; @@ -30,7 +27,7 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.core.ClaimAccessor; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -71,9 +68,6 @@ public class DefaultOAuth2UserService implements OAuth2UserService> PARAMETERIZED_RESPONSE_TYPE = new ParameterizedTypeReference>() {}; - private static final Collection WELL_KNOWN_AUTHORITIES_CLAIM_NAMES = - Arrays.asList("scope", "scp"); - private Converter> requestEntityConverter = new OAuth2UserRequestEntityConverter(); private RestOperations restOperations; @@ -137,7 +131,8 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic Map userAttributes = response.getBody(); Set authorities = new LinkedHashSet<>(); authorities.add(new OAuth2UserAuthority(userAttributes)); - for (String authority : getAuthorities(() -> userAttributes)) { + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String authority : token.getScopes()) { authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); } @@ -172,34 +167,4 @@ public final void setRestOperations(RestOperations restOperations) { Assert.notNull(restOperations, "restOperations cannot be null"); this.restOperations = restOperations; } - - private String getAuthoritiesClaimName(ClaimAccessor claims) { - for (String claimName : WELL_KNOWN_AUTHORITIES_CLAIM_NAMES) { - if (claims.containsClaim(claimName)) { - return claimName; - } - } - return null; - } - - private Collection getAuthorities(ClaimAccessor claims) { - String claimName = getAuthoritiesClaimName(claims); - - if (claimName == null) { - return Collections.emptyList(); - } - - Object authorities = claims.getClaim(claimName); - if (authorities instanceof String) { - if (StringUtils.hasText((String) authorities)) { - return Arrays.asList(((String) authorities).split(" ")); - } else { - return Collections.emptyList(); - } - } else if (authorities instanceof Collection) { - return (Collection) authorities; - } - - return Collections.emptyList(); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java index b808cb80688..108d837240b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java @@ -28,7 +28,9 @@ import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.core.AuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; @@ -131,6 +133,10 @@ public Mono loadUser(OAuth2UserRequest userRequest) GrantedAuthority authority = new OAuth2UserAuthority(attrs); Set authorities = new HashSet<>(); authorities.add(authority); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String scope : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); + } return new DefaultOAuth2User(authorities, attrs, userNameAttributeName); }) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index a713515b693..0876c482477 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -16,13 +16,25 @@ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.core.convert.converter.Converter; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; @@ -36,17 +48,20 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; -import reactor.core.publisher.Mono; - -import java.time.Duration; -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; +import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; /** * @author Rob Winch @@ -178,6 +193,38 @@ public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration())); } + @Test + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + OidcUserRequest request = new OidcUserRequest( + clientRegistration().build(), scopes("message:read", "message:write"), idToken(body)); + OidcUser user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + OidcUserRequest request = new OidcUserRequest( + clientRegistration().build(), noScopes(), idToken(body)); + OidcUser user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + } + private OidcUserRequest userRequest() { return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 473666e70f1..6b414375ec9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -16,7 +16,6 @@ package org.springframework.security.oauth2.client.oidc.userinfo; import java.time.Instant; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; @@ -33,19 +32,14 @@ import org.junit.Test; import org.junit.rules.ExpectedException; -import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; -import org.springframework.http.RequestEntity; -import org.springframework.http.ResponseEntity; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; -import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; @@ -56,18 +50,16 @@ import org.springframework.security.oauth2.core.oidc.StandardClaimNames; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; -import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.nullable; import static org.mockito.Mockito.same; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; @@ -272,7 +264,7 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { assertThat(user.getUserInfo().getPreferredUsername()).isEqualTo("user1"); assertThat(user.getUserInfo().getEmail()).isEqualTo("user1@example.com"); - assertThat(user.getAuthorities().size()).isEqualTo(1); + assertThat(user.getAuthorities().size()).isEqualTo(3); assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OidcUserAuthority.class); OidcUserAuthority userAuthority = (OidcUserAuthority) user.getAuthorities().iterator().next(); assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER"); @@ -499,15 +491,13 @@ public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() { } @Test - public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() { + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { Map body = new HashMap<>(); body.put("id", "id"); body.put("sub", "test-subject"); - body.put("scope", "message:read message:write"); OidcUserService userService = new OidcUserService(); - userService.setOauth2UserService(withMockResponse(body)); - OidcUserRequest request = new OidcUserRequest(clientRegistration(). - userInfoUri("uri").build(), scopes("profile"), idToken(body)); + OidcUserRequest request = new OidcUserRequest(clientRegistration().build(), + scopes("message:read", "message:write"), idToken(body)); OidcUser user = userService.loadUser(request); assertThat(user.getAuthorities()).hasSize(3); @@ -518,34 +508,13 @@ public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() { } @Test - public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() { + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { Map body = new HashMap<>(); body.put("id", "id"); body.put("sub", "test-subject"); - body.put("scp", Arrays.asList("message:read", "message:write")); OidcUserService userService = new OidcUserService(); - userService.setOauth2UserService(withMockResponse(body)); - OidcUserRequest request = new OidcUserRequest(clientRegistration(). - userInfoUri("uri").build(), scopes("profile"), idToken(body)); - OidcUser user = userService.loadUser(request); - - assertThat(user.getAuthorities()).hasSize(3); - Iterator authorities = user.getAuthorities().iterator(); - assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); - assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); - assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); - } - - @Test - public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() { - Map body = new HashMap<>(); - body.put("id", "id"); - body.put("sub", "test-subject"); - body.put("authorities", Arrays.asList("message:read", "message:write")); - OidcUserService userService = new OidcUserService(); - userService.setOauth2UserService(withMockResponse(body)); - OidcUserRequest request = new OidcUserRequest(clientRegistration(). - userInfoUri("uri").build(), scopes("profile"), idToken(body)); + OidcUserRequest request = new OidcUserRequest(clientRegistration().build(), + noScopes(), idToken(body)); OidcUser user = userService.loadUser(request); assertThat(user.getAuthorities()).hasSize(1); @@ -553,18 +522,6 @@ public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() { assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); } - private DefaultOAuth2UserService withMockResponse(Map response) { - ResponseEntity> responseEntity = new ResponseEntity<>(response, HttpStatus.OK); - Converter> requestEntityConverter = mock(Converter.class); - RestOperations rest = mock(RestOperations.class); - when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class))) - .thenReturn(responseEntity); - DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); - userService.setRequestEntityConverter(requestEntityConverter); - userService.setRestOperations(rest); - return userService; - } - private MockResponse jsonResponse(String json) { return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index 346ac21dfce..f42e167b83e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -15,7 +15,6 @@ */ package org.springframework.security.oauth2.client.userinfo; -import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.Map; @@ -56,6 +55,7 @@ import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * Tests for {@link DefaultOAuth2UserService}. @@ -342,12 +342,12 @@ public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPos } @Test - public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() { + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { Map body = new HashMap<>(); body.put("id", "id"); - body.put("scope", "message:read message:write"); DefaultOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), scopes("message:read", "message:write")); OAuth2User user = userService.loadUser(request); assertThat(user.getAuthorities()).hasSize(3); @@ -358,28 +358,12 @@ public void loadUserWhenAttributesContainScopeThenIndividualScopeAuthorities() { } @Test - public void loadUserWhenAttributesContainScpThenIndividualScopeAuthorities() { + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { Map body = new HashMap<>(); body.put("id", "id"); - body.put("scp", Arrays.asList("message:read", "message:write")); DefaultOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); - OAuth2User user = userService.loadUser(request); - - assertThat(user.getAuthorities()).hasSize(3); - Iterator authorities = user.getAuthorities().iterator(); - assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); - assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); - assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); - } - - @Test - public void loadUserWhenAttributesDoesNotContainScopesThenNoScopeAuthorities() { - Map body = new HashMap<>(); - body.put("id", "id"); - body.put("authorities", Arrays.asList("message:read", "message:write")); - DefaultOAuth2UserService userService = withMockResponse(body); - OAuth2UserRequest request = new OAuth2UserRequest(clientRegistration().build(), noScopes()); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), noScopes()); OAuth2User user = userService.loadUser(request); assertThat(user.getAuthorities()).hasSize(1); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java index b856c80252d..8d1725b37ce 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java @@ -16,15 +16,30 @@ package org.springframework.security.oauth2.client.userinfo; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Predicate; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationServiceException; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthenticationMethod; @@ -32,14 +47,17 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; - -import okhttp3.mockwebserver.RecordedRequest; -import reactor.test.StepVerifier; - -import java.time.Duration; -import java.time.Instant; - -import static org.assertj.core.api.Assertions.*; +import org.springframework.web.reactive.function.client.WebClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * @author Rob Winch @@ -211,6 +229,53 @@ public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceExceptio .isInstanceOf(AuthenticationServiceException.class); } + @Test + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + DefaultReactiveOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), scopes("message:read", "message:write")); + OAuth2User user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + DefaultReactiveOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), noScopes()); + OAuth2User user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + } + + private DefaultReactiveOAuth2UserService withMockResponse(Map body) { + WebClient real = WebClient.builder().build(); + WebClient.RequestHeadersUriSpec spec = spy(real.post()); + WebClient rest = spy(WebClient.class); + WebClient.ResponseSpec clientResponse = mock(WebClient.ResponseSpec.class); + when(rest.get()).thenReturn(spec); + when(spec.retrieve()).thenReturn(clientResponse); + when(clientResponse.onStatus(any(Predicate.class), any(Function.class))) + .thenReturn(clientResponse); + when(clientResponse.bodyToMono(any(ParameterizedTypeReference.class))) + .thenReturn(Mono.just(body)); + + DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService(); + userService.setWebClient(rest); + return userService; + } + private OAuth2UserRequest oauth2UserRequest() { return new OAuth2UserRequest(this.clientRegistration.build(), this.accessToken); }