Skip to content

Allowing for a @Bean of type OAuth2AccessTokenResponseClient<OAuth2Cl… #6606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -15,20 +15,22 @@
*/
package org.springframework.security.config.annotation.web.configuration;

import java.util.List;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.util.ClassUtils;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

import java.util.List;

/**
* {@link Configuration} for OAuth 2.0 Client support.
*
Expand Down Expand Up @@ -60,13 +62,17 @@ public String[] selectImports(AnnotationMetadata importingClassMetadata) {
static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer {
private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient;

@Override
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) {
OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
new OAuth2AuthorizedClientArgumentResolver(
this.clientRegistrationRepository, this.authorizedClientRepository);
if (this.accessTokenResponseClient != null) {
authorizedClientArgumentResolver.setClientCredentialsTokenResponseClient(accessTokenResponseClient);
}
argumentResolvers.add(authorizedClientArgumentResolver);
}
}
Expand All @@ -84,5 +90,11 @@ public void setAuthorizedClientRepository(List<OAuth2AuthorizedClientRepository>
this.authorizedClientRepository = authorizedClientRepositories.get(0);
}
}

@Autowired
public void setAccessTokenResponseClient(
Optional<OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest>> accessTokenResponseClient) {
accessTokenResponseClient.ifPresent(client -> this.accessTokenResponseClient = client);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
Expand All @@ -15,6 +15,21 @@
*/
package org.springframework.security.config.annotation.web.configuration;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

import javax.servlet.http.HttpServletRequest;
import org.junit.Rule;
import org.junit.Test;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
Expand All @@ -26,26 +41,18 @@
import org.springframework.security.config.test.SpringTestRule;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;

import javax.servlet.http.HttpServletRequest;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

/**
* Tests for {@link OAuth2ClientConfiguration}.
*
Expand All @@ -64,26 +71,66 @@ public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws
String principalName = "user1";
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");

ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class);
when(authorizedClientRepository.loadAuthorizedClient(
eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class))).thenReturn(authorizedClient);
eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class)))
.thenReturn(authorizedClient);

OAuth2AccessToken accessToken = mock(OAuth2AccessToken.class);
when(authorizedClient.getAccessToken()).thenReturn(accessToken);

OAuth2AccessTokenResponseClient accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);

OAuth2AuthorizedClientArgumentResolverConfig.CLIENT_REGISTRATION_REPOSITORY = clientRegistrationRepository;
OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository;
OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient;
this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire();

this.mockMvc.perform(get("/authorized-client").with(authentication(authentication)))
.andExpect(status().isOk())
.andExpect(content().string("resolved"));
verifyZeroInteractions(accessTokenResponseClient);
}

@Test
public void requestWhenAuthorizedClientNotFoundAndClientCredentialsThenTokenResponseClientIsUsed() throws Exception {
String clientRegistrationId = "client1";
String principalName = "user1";
TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password");

ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
OAuth2AccessTokenResponseClient accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);

ClientRegistration clientRegistration = clientCredentials().registrationId(clientRegistrationId).build();
when(clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).thenReturn(clientRegistration);

OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(300)
.build();
when(accessTokenResponseClient.getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class)))
.thenReturn(accessTokenResponse);

OAuth2AuthorizedClientArgumentResolverConfig.CLIENT_REGISTRATION_REPOSITORY = clientRegistrationRepository;
OAuth2AuthorizedClientArgumentResolverConfig.AUTHORIZED_CLIENT_REPOSITORY = authorizedClientRepository;
OAuth2AuthorizedClientArgumentResolverConfig.ACCESS_TOKEN_RESPONSE_CLIENT = accessTokenResponseClient;
this.spring.register(OAuth2AuthorizedClientArgumentResolverConfig.class).autowire();

this.mockMvc.perform(get("/authorized-client").with(authentication(authentication)))
.andExpect(status().isOk())
.andExpect(content().string("resolved"));
.andExpect(status().isOk())
.andExpect(content().string("resolved"));
verify(accessTokenResponseClient, times(1)).getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class));
}

@EnableWebMvc
@EnableWebSecurity
static class OAuth2AuthorizedClientArgumentResolverConfig extends WebSecurityConfigurerAdapter {
static ClientRegistrationRepository CLIENT_REGISTRATION_REPOSITORY;
static OAuth2AuthorizedClientRepository AUTHORIZED_CLIENT_REPOSITORY;
static OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> ACCESS_TOKEN_RESPONSE_CLIENT;

@Override
protected void configure(HttpSecurity http) throws Exception {
Expand All @@ -100,13 +147,18 @@ public String authorizedClient(@RegisteredOAuth2AuthorizedClient("client1") OAut

@Bean
public ClientRegistrationRepository clientRegistrationRepository() {
return mock(ClientRegistrationRepository.class);
return CLIENT_REGISTRATION_REPOSITORY;
}

@Bean
public OAuth2AuthorizedClientRepository authorizedClientRepository() {
return AUTHORIZED_CLIENT_REPOSITORY;
}

@Bean
public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient() {
return ACCESS_TOKEN_RESPONSE_CLIENT;
}
}

// gh-5321
Expand Down Expand Up @@ -147,6 +199,11 @@ public OAuth2AuthorizedClientRepository authorizedClientRepository1() {
public OAuth2AuthorizedClientRepository authorizedClientRepository2() {
return mock(OAuth2AuthorizedClientRepository.class);
}

@Bean
public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient() {
return mock(OAuth2AccessTokenResponseClient.class);
}
}

@Test
Expand Down Expand Up @@ -208,5 +265,53 @@ public ClientRegistrationRepository clientRegistrationRepository2() {
public OAuth2AuthorizedClientRepository authorizedClientRepository() {
return mock(OAuth2AuthorizedClientRepository.class);
}

@Bean
public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient() {
return mock(OAuth2AccessTokenResponseClient.class);
}
}

@Test
public void loadContextWhenAccessTokenResponseClientRegisteredTwiceThenThrowNoUniqueBeanDefinitionException() {
assertThatThrownBy(() -> this.spring.register(AccessTokenResponseClientRegisteredTwiceConfig.class).autowire())
.hasRootCauseInstanceOf(NoUniqueBeanDefinitionException.class)
.hasMessageContaining("expected single matching bean but found 2: accessTokenResponseClient1,accessTokenResponseClient2");
}

@EnableWebMvc
@EnableWebSecurity
static class AccessTokenResponseClientRegisteredTwiceConfig extends WebSecurityConfigurerAdapter {

@Override
protected void configure(HttpSecurity http) throws Exception {
// @formatter:off
http
.authorizeRequests()
.anyRequest().authenticated()
.and()
.oauth2Login();
// @formatter:on
}

@Bean
public ClientRegistrationRepository clientRegistrationRepository() {
return mock(ClientRegistrationRepository.class);
}

@Bean
public OAuth2AuthorizedClientRepository authorizedClientRepository() {
return mock(OAuth2AuthorizedClientRepository.class);
}

@Bean
public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient1() {
return mock(OAuth2AccessTokenResponseClient.class);
}

@Bean
public OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient2() {
return mock(OAuth2AccessTokenResponseClient.class);
}
}
}