Skip to content

Add argument resolver for SecurityContext #14449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,21 @@ public void setAdapterRegistry(ReactiveAdapterRegistry adapterRegistry) {

@Override
public boolean supportsParameter(MethodParameter parameter) {
return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
return isMonoSecurityContext(parameter)
|| findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
}

private boolean isMonoSecurityContext(MethodParameter parameter) {
boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType());
if (isParameterPublisher) {
ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter);
Class<?> genericType = resolvableType.resolveGeneric(0);
if (genericType == null) {
return false;
}
return SecurityContext.class.isAssignableFrom(genericType);
}
return false;
}

@Override
Expand All @@ -136,6 +150,14 @@ public Mono<Object> resolveArgument(MethodParameter parameter, Message<?> messag

private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) {
CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter);
if (contextAnno != null) {
return resolveSecurityContextFromAnnotation(contextAnno, parameter, securityContext);
}
return securityContext;
}

private Object resolveSecurityContextFromAnnotation(CurrentSecurityContext contextAnno, MethodParameter parameter,
Object securityContext) {
String expressionToParse = contextAnno.expression();
if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,24 @@ public void supportsParameterWhenAuthenticationPrincipalThenTrue() {
assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContext"))).isTrue();
}

@Test
public void supportsParameterWhenMonoSecurityContextNoAnnotationThenTrue() {
assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContextNoAnnotation")))
.isTrue();
}

@Test
public void supportsParameterWhenMonoCustomSecurityContextNoAnnotationThenTrue() {
assertThat(
this.resolver.supportsParameter(arg0("currentCustomSecurityContextOnMonoSecurityContextNoAnnotation")))
.isTrue();
}

@Test
public void supportsParameterWhenNoSecurityContextNoAnnotationThenFalse() {
assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoStringNoAnnotation"))).isFalse();
}

@Test
public void resolveArgumentWhenAuthenticationPrincipalAndEmptyContextThenNull() {
Object result = this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null)
Expand All @@ -67,6 +85,18 @@ public void resolveArgumentWhenAuthenticationPrincipalThenFound() {
private void currentSecurityContextOnMonoSecurityContext(@CurrentSecurityContext Mono<SecurityContext> context) {
}

@SuppressWarnings("unused")
private void currentSecurityContextOnMonoSecurityContextNoAnnotation(Mono<SecurityContext> context) {
}

@SuppressWarnings("unused")
private void currentCustomSecurityContextOnMonoSecurityContextNoAnnotation(Mono<CustomSecurityContext> context) {
}

@SuppressWarnings("unused")
private void currentSecurityContextOnMonoStringNoAnnotation(Mono<String> context) {
}

@Test
public void supportsParameterWhenCurrentUserThenTrue() {
assertThat(this.resolver.supportsParameter(arg0("currentUserOnMonoUserDetails"))).isTrue();
Expand Down Expand Up @@ -110,6 +140,41 @@ public void supportsParameterWhenNotAnnotatedThenFalse() {
private void monoUserDetails(Mono<UserDetails> user) {
}

@Test
public void supportsParameterWhenSecurityContextNotAnnotatedThenTrue() {
assertThat(this.resolver.supportsParameter(arg0("monoSecurityContext"))).isTrue();
}

@Test
public void resolveArgumentWhenMonoSecurityContextNoAnnotationThenFound() {
Authentication authentication = TestAuthentication.authenticatedUser();
Mono<SecurityContext> result = (Mono<SecurityContext>) this.resolver
.resolveArgument(arg0("monoSecurityContext"), null)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();
assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal());
}

@SuppressWarnings("unused")
private void monoSecurityContext(Mono<SecurityContext> securityContext) {
}

@Test
public void resolveArgumentWhenMonoCustomSecurityContextNoAnnotationThenFound() {
Authentication authentication = TestAuthentication.authenticatedUser();
CustomSecurityContext securityContext = new CustomSecurityContext();
securityContext.setAuthentication(authentication);
Mono<CustomSecurityContext> result = (Mono<CustomSecurityContext>) this.resolver
.resolveArgument(arg0("monoCustomSecurityContext"), null)
.contextWrite(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
.block();
assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal());
}

@SuppressWarnings("unused")
private void monoCustomSecurityContext(Mono<CustomSecurityContext> securityContext) {
}

private MethodParameter arg0(String methodName) {
ResolvableMethod method = ResolvableMethod.on(getClass()).named(methodName).method();
return new SynthesizingMethodParameter(method.method(), 0);
Expand All @@ -121,4 +186,20 @@ private MethodParameter arg0(String methodName) {

}

static class CustomSecurityContext implements SecurityContext {

private Authentication authentication;

@Override
public Authentication getAuthentication() {
return this.authentication;
}

@Override
public void setAuthentication(Authentication authentication) {
this.authentication = authentication;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ public final class CurrentSecurityContextArgumentResolver implements HandlerMeth

@Override
public boolean supportsParameter(MethodParameter parameter) {
return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
return SecurityContext.class.isAssignableFrom(parameter.getParameterType())
|| findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
}

@Override
Expand All @@ -95,26 +96,12 @@ public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer m
if (securityContext == null) {
return null;
}
Object securityContextResult = securityContext;
CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter);
String expressionToParse = annotation.expression();
if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(securityContext);
context.setVariable("this", securityContext);
context.setBeanResolver(this.beanResolver);
Expression expression = this.parser.parseExpression(expressionToParse);
securityContextResult = expression.getValue(context);
}
if (securityContextResult != null
&& !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) {
if (annotation.errorOnInvalidType()) {
throw new ClassCastException(
securityContextResult + " is not assignable to " + parameter.getParameterType());
}
return null;
if (annotation != null) {
return resolveSecurityContextFromAnnotation(parameter, annotation, securityContext);
}
return securityContextResult;

return securityContext;
}

/**
Expand All @@ -137,6 +124,29 @@ public void setBeanResolver(BeanResolver beanResolver) {
this.beanResolver = beanResolver;
}

private Object resolveSecurityContextFromAnnotation(MethodParameter parameter, CurrentSecurityContext annotation,
SecurityContext securityContext) {
Object securityContextResult = securityContext;
String expressionToParse = annotation.expression();
if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(securityContext);
context.setVariable("this", securityContext);
context.setBeanResolver(this.beanResolver);
Expression expression = this.parser.parseExpression(expressionToParse);
securityContextResult = expression.getValue(context);
}
if (securityContextResult != null
&& !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) {
if (annotation.errorOnInvalidType()) {
throw new ClassCastException(
securityContextResult + " is not assignable to " + parameter.getParameterType());
}
return null;
}
return securityContextResult;
}

/**
* Obtain the specified {@link Annotation} on the specified {@link MethodParameter}.
* @param annotationClass the class of the {@link Annotation} to find on the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,21 @@ public void setBeanResolver(BeanResolver beanResolver) {

@Override
public boolean supportsParameter(MethodParameter parameter) {
return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
return isMonoSecurityContext(parameter)
|| findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
}

private boolean isMonoSecurityContext(MethodParameter parameter) {
boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType());
if (isParameterPublisher) {
ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter);
Class<?> genericType = resolvableType.resolveGeneric(0);
if (genericType == null) {
return false;
}
return SecurityContext.class.isAssignableFrom(genericType);
}
return false;
}

@Override
Expand Down Expand Up @@ -95,6 +109,14 @@ public Mono<Object> resolveArgument(MethodParameter parameter, BindingContext bi
*/
private Object resolveSecurityContext(MethodParameter parameter, SecurityContext securityContext) {
CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter);
if (annotation != null) {
return resolveSecurityContextFromAnnotation(annotation, parameter, securityContext);
}
return securityContext;
}

private Object resolveSecurityContextFromAnnotation(CurrentSecurityContext annotation, MethodParameter parameter,
Object securityContext) {
Object securityContextResult = securityContext;
String expressionToParse = annotation.expression();
if (StringUtils.hasLength(expressionToParse)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,26 @@ public void cleanup() {
SecurityContextHolder.clearContext();
}

@Test
public void supportsParameterNoAnnotationWrongType() {
assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotationTypeMismatch())).isFalse();
}

@Test
public void supportsParameterNoAnnotation() {
assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotation())).isFalse();
assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotation())).isTrue();
}

@Test
public void supportsParameterCustomSecurityContextNoAnnotation() {
assertThat(this.resolver.supportsParameter(showSecurityContextWithCustomSecurityContextNoAnnotation()))
.isTrue();
}

@Test
public void supportsParameterNoAnnotationCustomType() {
assertThat(this.resolver.supportsParameter(showSecurityContextWithCustomSecurityContextNoAnnotation()))
.isTrue();
}

@Test
Expand All @@ -88,6 +105,24 @@ public void resolveArgumentWithCustomSecurityContext() {
assertThat(customSecurityContext.getAuthentication().getPrincipal()).isEqualTo(principal);
}

@Test
public void resolveArgumentWithCustomSecurityContextNoAnnotation() {
String principal = "custom_security_context";
setAuthenticationPrincipalWithCustomSecurityContext(principal);
CustomSecurityContext customSecurityContext = (CustomSecurityContext) this.resolver
.resolveArgument(showSecurityContextWithCustomSecurityContextNoAnnotation(), null, null, null);
assertThat(customSecurityContext.getAuthentication().getPrincipal()).isEqualTo(principal);
}

@Test
public void resolveArgumentWithNoAnnotation() {
String principal = "custom_security_context";
setAuthenticationPrincipal(principal);
SecurityContext securityContext = (SecurityContext) this.resolver
.resolveArgument(showSecurityContextNoAnnotation(), null, null, null);
assertThat(securityContext.getAuthentication().getPrincipal()).isEqualTo(principal);
}

@Test
public void resolveArgumentWithCustomSecurityContextTypeMatch() {
String principal = "custom_security_context_type_match";
Expand Down Expand Up @@ -212,10 +247,14 @@ public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenMisMatch(
.resolveArgument(showCurrentSecurityWithErrorOnInvalidTypeMisMatch(), null, null, null));
}

private MethodParameter showSecurityContextNoAnnotation() {
private MethodParameter showSecurityContextNoAnnotationTypeMismatch() {
return getMethodParameter("showSecurityContextNoAnnotation", String.class);
}

private MethodParameter showSecurityContextNoAnnotation() {
return getMethodParameter("showSecurityContextNoAnnotation", SecurityContext.class);
}

private MethodParameter showSecurityContextAnnotation() {
return getMethodParameter("showSecurityContextAnnotation", SecurityContext.class);
}
Expand Down Expand Up @@ -276,6 +315,11 @@ public MethodParameter showCurrentSecurityWithErrorOnInvalidTypeMisMatch() {
return getMethodParameter("showCurrentSecurityWithErrorOnInvalidTypeMisMatch", String.class);
}

public MethodParameter showSecurityContextWithCustomSecurityContextNoAnnotation() {
return getMethodParameter("showSecurityContextWithCustomSecurityContextNoAnnotation",
CustomSecurityContext.class);
}

private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes);
return new MethodParameter(method, 0);
Expand Down Expand Up @@ -358,6 +402,12 @@ public void showCurrentSecurityWithErrorOnInvalidTypeMisMatch(
@CurrentSecurityWithErrorOnInvalidType String typeMisMatch) {
}

public void showSecurityContextNoAnnotation(SecurityContext context) {
}

public void showSecurityContextWithCustomSecurityContextNoAnnotation(CustomSecurityContext context) {
}

}

static class CustomSecurityContext implements SecurityContext {
Expand Down
Loading