diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java index fadc5f8b922..5893d799d84 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserService.java @@ -18,10 +18,12 @@ import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.converter.ClaimConversionService; @@ -99,6 +101,10 @@ public Mono loadUser(OidcUserRequest userRequest) throws OAuth2Authent OidcUserInfo userInfo = authority.getUserInfo(); Set authorities = new HashSet<>(); authorities.add(authority); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String scope : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); + } String userNameAttributeName = userRequest.getClientRegistration() .getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName(); if (StringUtils.hasText(userNameAttributeName)) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java index 9644f6620cf..3fb35f5ada7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserService.java @@ -15,14 +15,25 @@ */ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.OAuth2UserService; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.converter.ClaimConversionService; @@ -38,15 +49,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; -import java.time.Instant; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.function.Function; - /** * An implementation of an {@link OAuth2UserService} that supports OpenID Connect 1.0 Provider's. * @@ -127,8 +129,12 @@ public OidcUser loadUser(OidcUserRequest userRequest) throws OAuth2Authenticatio } } - Set authorities = Collections.singleton( - new OidcUserAuthority(userRequest.getIdToken(), userInfo)); + Set authorities = new LinkedHashSet<>(); + authorities.add(new OidcUserAuthority(userRequest.getIdToken(), userInfo)); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String authority : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } OidcUser user; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java index 8033b93127b..0ca84ca6d44 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java @@ -15,13 +15,19 @@ */ package org.springframework.security.oauth2.client.userinfo; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; + import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; import org.springframework.http.RequestEntity; import org.springframework.http.ResponseEntity; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -35,10 +41,6 @@ import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestTemplate; -import java.util.Collections; -import java.util.Map; -import java.util.Set; - /** * An implementation of an {@link OAuth2UserService} that supports standard OAuth 2.0 Provider's. *

@@ -127,7 +129,12 @@ public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2Authentic } Map userAttributes = response.getBody(); - Set authorities = Collections.singleton(new OAuth2UserAuthority(userAttributes)); + Set authorities = new LinkedHashSet<>(); + authorities.add(new OAuth2UserAuthority(userAttributes)); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String authority : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java index b808cb80688..108d837240b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java @@ -28,7 +28,9 @@ import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationServiceException; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.core.AuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; @@ -131,6 +133,10 @@ public Mono loadUser(OAuth2UserRequest userRequest) GrantedAuthority authority = new OAuth2UserAuthority(attrs); Set authorities = new HashSet<>(); authorities.add(authority); + OAuth2AccessToken token = userRequest.getAccessToken(); + for (String scope : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + scope)); + } return new DefaultOAuth2User(authorities, attrs, userNameAttributeName); }) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index a713515b693..0876c482477 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -16,13 +16,25 @@ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import reactor.core.publisher.Mono; + import org.springframework.core.convert.converter.Converter; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; @@ -36,17 +48,20 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.user.DefaultOAuth2User; import org.springframework.security.oauth2.core.user.OAuth2User; -import reactor.core.publisher.Mono; - -import java.time.Duration; -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; +import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; /** * @author Rob Winch @@ -178,6 +193,38 @@ public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(userRequest.getClientRegistration())); } + @Test + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + OidcUserRequest request = new OidcUserRequest( + clientRegistration().build(), scopes("message:read", "message:write"), idToken(body)); + OidcUser user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + OidcUserRequest request = new OidcUserRequest( + clientRegistration().build(), noScopes(), idToken(body)); + OidcUser user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + } + private OidcUserRequest userRequest() { return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index aad464c0e0a..6b414375ec9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -15,6 +15,14 @@ */ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -23,10 +31,13 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; + import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService; import org.springframework.security.oauth2.core.AuthenticationMethod; @@ -40,19 +51,17 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; -import java.time.Instant; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.function.Function; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.hamcrest.CoreMatchers.containsString; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; +import static org.springframework.security.oauth2.core.oidc.TestOidcIdTokens.idToken; /** * Tests for {@link OidcUserService}. @@ -255,7 +264,7 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { assertThat(user.getUserInfo().getPreferredUsername()).isEqualTo("user1"); assertThat(user.getUserInfo().getEmail()).isEqualTo("user1@example.com"); - assertThat(user.getAuthorities().size()).isEqualTo(1); + assertThat(user.getAuthorities().size()).isEqualTo(3); assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OidcUserAuthority.class); OidcUserAuthority userAuthority = (OidcUserAuthority) user.getAuthorities().iterator().next(); assertThat(userAuthority.getAuthority()).isEqualTo("ROLE_USER"); @@ -481,6 +490,38 @@ public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + OidcUserService userService = new OidcUserService(); + OidcUserRequest request = new OidcUserRequest(clientRegistration().build(), + scopes("message:read", "message:write"), idToken(body)); + OidcUser user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + body.put("sub", "test-subject"); + OidcUserService userService = new OidcUserService(); + OidcUserRequest request = new OidcUserRequest(clientRegistration().build(), + noScopes(), idToken(body)); + OidcUser user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OidcUserAuthority.class); + } + private MockResponse jsonResponse(String json) { return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index 99a6718e075..f42e167b83e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -15,6 +15,9 @@ */ package org.springframework.security.oauth2.client.userinfo; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; import java.util.concurrent.TimeUnit; import okhttp3.mockwebserver.MockResponse; @@ -26,20 +29,33 @@ import org.junit.Test; import org.junit.rules.ExpectedException; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; +import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.nullable; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * Tests for {@link DefaultOAuth2UserService}. @@ -325,6 +341,48 @@ public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPos assertThat(request.getBody().readUtf8()).isEqualTo("access_token=" + this.accessToken.getTokenValue()); } + @Test + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + DefaultOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), scopes("message:read", "message:write")); + OAuth2User user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + DefaultOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), noScopes()); + OAuth2User user = userService.loadUser(request); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + } + + private DefaultOAuth2UserService withMockResponse(Map response) { + ResponseEntity> responseEntity = new ResponseEntity<>(response, HttpStatus.OK); + Converter> requestEntityConverter = mock(Converter.class); + RestOperations rest = mock(RestOperations.class); + when(rest.exchange(nullable(RequestEntity.class), any(ParameterizedTypeReference.class))) + .thenReturn(responseEntity); + DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + userService.setRequestEntityConverter(requestEntityConverter); + userService.setRestOperations(rest); + return userService; + } + private MockResponse jsonResponse(String json) { return new MockResponse() .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java index b856c80252d..8d1725b37ce 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java @@ -16,15 +16,30 @@ package org.springframework.security.oauth2.client.userinfo; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Predicate; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; import org.junit.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.authentication.AuthenticationServiceException; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthenticationMethod; @@ -32,14 +47,17 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.user.OAuth2User; import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; - -import okhttp3.mockwebserver.RecordedRequest; -import reactor.test.StepVerifier; - -import java.time.Duration; -import java.time.Instant; - -import static org.assertj.core.api.Assertions.*; +import org.springframework.web.reactive.function.client.WebClient; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.noScopes; +import static org.springframework.security.oauth2.core.TestOAuth2AccessTokens.scopes; /** * @author Rob Winch @@ -211,6 +229,53 @@ public void loadUserWhenUserInfoUriInvalidThenThrowAuthenticationServiceExceptio .isInstanceOf(AuthenticationServiceException.class); } + @Test + public void loadUserWhenTokenContainsScopesThenIndividualScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + DefaultReactiveOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), scopes("message:read", "message:write")); + OAuth2User user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(3); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:read")); + assertThat(authorities.next()).isEqualTo(new SimpleGrantedAuthority("SCOPE_message:write")); + } + + @Test + public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { + Map body = new HashMap<>(); + body.put("id", "id"); + DefaultReactiveOAuth2UserService userService = withMockResponse(body); + OAuth2UserRequest request = new OAuth2UserRequest( + clientRegistration().build(), noScopes()); + OAuth2User user = userService.loadUser(request).block(); + + assertThat(user.getAuthorities()).hasSize(1); + Iterator authorities = user.getAuthorities().iterator(); + assertThat(authorities.next()).isInstanceOf(OAuth2UserAuthority.class); + } + + private DefaultReactiveOAuth2UserService withMockResponse(Map body) { + WebClient real = WebClient.builder().build(); + WebClient.RequestHeadersUriSpec spec = spy(real.post()); + WebClient rest = spy(WebClient.class); + WebClient.ResponseSpec clientResponse = mock(WebClient.ResponseSpec.class); + when(rest.get()).thenReturn(spec); + when(spec.retrieve()).thenReturn(clientResponse); + when(clientResponse.onStatus(any(Predicate.class), any(Function.class))) + .thenReturn(clientResponse); + when(clientResponse.bodyToMono(any(ParameterizedTypeReference.class))) + .thenReturn(Mono.just(body)); + + DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService(); + userService.setWebClient(rest); + return userService; + } + private OAuth2UserRequest oauth2UserRequest() { return new OAuth2UserRequest(this.clientRegistration.build(), this.accessToken); } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java new file mode 100644 index 00000000000..a99020ed267 --- /dev/null +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/oidc/TestOidcIdTokens.java @@ -0,0 +1,39 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.core.oidc; + +import java.time.Instant; +import java.util.Collections; +import java.util.Map; + +/** + * Test {@link OidcIdToken}s + * + * @author Josh Cummings + */ +public class TestOidcIdTokens { + public static OidcIdToken idToken() { + return idToken(Collections.singletonMap("id", "id")); + } + + public static OidcIdToken idToken(Map claims) { + return new OidcIdToken("token", + Instant.now(), + Instant.now().plusSeconds(86400), + claims); + } +}