From f942dbeb1cf8181f6d58e8bc28ccd1331494aabf Mon Sep 17 00:00:00 2001 From: Nermin Karapandzic Date: Sat, 13 Jan 2024 19:33:21 +0100 Subject: [PATCH] Add argument resolver for SecurityContext Closes gh-13425 --- ...urrentSecurityContextArgumentResolver.java | 24 +++++- ...tSecurityContextArgumentResolverTests.java | 81 +++++++++++++++++++ ...urrentSecurityContextArgumentResolver.java | 48 ++++++----- ...urrentSecurityContextArgumentResolver.java | 24 +++++- ...tSecurityContextArgumentResolverTests.java | 54 ++++++++++++- ...tSecurityContextArgumentResolverTests.java | 61 ++++++++++++++ 6 files changed, 269 insertions(+), 23 deletions(-) diff --git a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java index 3b67e85a437..ec69e1d389f 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java @@ -118,7 +118,21 @@ public void setAdapterRegistry(ReactiveAdapterRegistry adapterRegistry) { @Override public boolean supportsParameter(MethodParameter parameter) { - return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + return isMonoSecurityContext(parameter) + || findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + } + + private boolean isMonoSecurityContext(MethodParameter parameter) { + boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType()); + if (isParameterPublisher) { + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); + Class genericType = resolvableType.resolveGeneric(0); + if (genericType == null) { + return false; + } + return SecurityContext.class.isAssignableFrom(genericType); + } + return false; } @Override @@ -136,6 +150,14 @@ public Mono resolveArgument(MethodParameter parameter, Message messag private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) { CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter); + if (contextAnno != null) { + return resolveSecurityContextFromAnnotation(contextAnno, parameter, securityContext); + } + return securityContext; + } + + private Object resolveSecurityContextFromAnnotation(CurrentSecurityContext contextAnno, MethodParameter parameter, + Object securityContext) { String expressionToParse = contextAnno.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java index 9b715b65451..22876bde63e 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java @@ -46,6 +46,24 @@ public void supportsParameterWhenAuthenticationPrincipalThenTrue() { assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContext"))).isTrue(); } + @Test + public void supportsParameterWhenMonoSecurityContextNoAnnotationThenTrue() { + assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContextNoAnnotation"))) + .isTrue(); + } + + @Test + public void supportsParameterWhenMonoCustomSecurityContextNoAnnotationThenTrue() { + assertThat( + this.resolver.supportsParameter(arg0("currentCustomSecurityContextOnMonoSecurityContextNoAnnotation"))) + .isTrue(); + } + + @Test + public void supportsParameterWhenNoSecurityContextNoAnnotationThenFalse() { + assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoStringNoAnnotation"))).isFalse(); + } + @Test public void resolveArgumentWhenAuthenticationPrincipalAndEmptyContextThenNull() { Object result = this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null) @@ -67,6 +85,18 @@ public void resolveArgumentWhenAuthenticationPrincipalThenFound() { private void currentSecurityContextOnMonoSecurityContext(@CurrentSecurityContext Mono context) { } + @SuppressWarnings("unused") + private void currentSecurityContextOnMonoSecurityContextNoAnnotation(Mono context) { + } + + @SuppressWarnings("unused") + private void currentCustomSecurityContextOnMonoSecurityContextNoAnnotation(Mono context) { + } + + @SuppressWarnings("unused") + private void currentSecurityContextOnMonoStringNoAnnotation(Mono context) { + } + @Test public void supportsParameterWhenCurrentUserThenTrue() { assertThat(this.resolver.supportsParameter(arg0("currentUserOnMonoUserDetails"))).isTrue(); @@ -110,6 +140,41 @@ public void supportsParameterWhenNotAnnotatedThenFalse() { private void monoUserDetails(Mono user) { } + @Test + public void supportsParameterWhenSecurityContextNotAnnotatedThenTrue() { + assertThat(this.resolver.supportsParameter(arg0("monoSecurityContext"))).isTrue(); + } + + @Test + public void resolveArgumentWhenMonoSecurityContextNoAnnotationThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + Mono result = (Mono) this.resolver + .resolveArgument(arg0("monoSecurityContext"), null) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal()); + } + + @SuppressWarnings("unused") + private void monoSecurityContext(Mono securityContext) { + } + + @Test + public void resolveArgumentWhenMonoCustomSecurityContextNoAnnotationThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + CustomSecurityContext securityContext = new CustomSecurityContext(); + securityContext.setAuthentication(authentication); + Mono result = (Mono) this.resolver + .resolveArgument(arg0("monoCustomSecurityContext"), null) + .contextWrite(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) + .block(); + assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal()); + } + + @SuppressWarnings("unused") + private void monoCustomSecurityContext(Mono securityContext) { + } + private MethodParameter arg0(String methodName) { ResolvableMethod method = ResolvableMethod.on(getClass()).named(methodName).method(); return new SynthesizingMethodParameter(method.method(), 0); @@ -121,4 +186,20 @@ private MethodParameter arg0(String methodName) { } + static class CustomSecurityContext implements SecurityContext { + + private Authentication authentication; + + @Override + public Authentication getAuthentication() { + return this.authentication; + } + + @Override + public void setAuthentication(Authentication authentication) { + this.authentication = authentication; + } + + } + } diff --git a/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java index 05ef53a59ca..d1a6ba12a7d 100644 --- a/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java @@ -85,7 +85,8 @@ public final class CurrentSecurityContextArgumentResolver implements HandlerMeth @Override public boolean supportsParameter(MethodParameter parameter) { - return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + return SecurityContext.class.isAssignableFrom(parameter.getParameterType()) + || findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; } @Override @@ -95,26 +96,12 @@ public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer m if (securityContext == null) { return null; } - Object securityContextResult = securityContext; CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter); - String expressionToParse = annotation.expression(); - if (StringUtils.hasLength(expressionToParse)) { - StandardEvaluationContext context = new StandardEvaluationContext(); - context.setRootObject(securityContext); - context.setVariable("this", securityContext); - context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); - securityContextResult = expression.getValue(context); - } - if (securityContextResult != null - && !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) { - if (annotation.errorOnInvalidType()) { - throw new ClassCastException( - securityContextResult + " is not assignable to " + parameter.getParameterType()); - } - return null; + if (annotation != null) { + return resolveSecurityContextFromAnnotation(parameter, annotation, securityContext); } - return securityContextResult; + + return securityContext; } /** @@ -137,6 +124,29 @@ public void setBeanResolver(BeanResolver beanResolver) { this.beanResolver = beanResolver; } + private Object resolveSecurityContextFromAnnotation(MethodParameter parameter, CurrentSecurityContext annotation, + SecurityContext securityContext) { + Object securityContextResult = securityContext; + String expressionToParse = annotation.expression(); + if (StringUtils.hasLength(expressionToParse)) { + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(securityContext); + context.setVariable("this", securityContext); + context.setBeanResolver(this.beanResolver); + Expression expression = this.parser.parseExpression(expressionToParse); + securityContextResult = expression.getValue(context); + } + if (securityContextResult != null + && !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) { + if (annotation.errorOnInvalidType()) { + throw new ClassCastException( + securityContextResult + " is not assignable to " + parameter.getParameterType()); + } + return null; + } + return securityContextResult; + } + /** * Obtain the specified {@link Annotation} on the specified {@link MethodParameter}. * @param annotationClass the class of the {@link Annotation} to find on the diff --git a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java index a02a9c30b9f..fd51d8ac533 100644 --- a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java @@ -67,7 +67,21 @@ public void setBeanResolver(BeanResolver beanResolver) { @Override public boolean supportsParameter(MethodParameter parameter) { - return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + return isMonoSecurityContext(parameter) + || findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + } + + private boolean isMonoSecurityContext(MethodParameter parameter) { + boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType()); + if (isParameterPublisher) { + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); + Class genericType = resolvableType.resolveGeneric(0); + if (genericType == null) { + return false; + } + return SecurityContext.class.isAssignableFrom(genericType); + } + return false; } @Override @@ -95,6 +109,14 @@ public Mono resolveArgument(MethodParameter parameter, BindingContext bi */ private Object resolveSecurityContext(MethodParameter parameter, SecurityContext securityContext) { CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter); + if (annotation != null) { + return resolveSecurityContextFromAnnotation(annotation, parameter, securityContext); + } + return securityContext; + } + + private Object resolveSecurityContextFromAnnotation(CurrentSecurityContext annotation, MethodParameter parameter, + Object securityContext) { Object securityContextResult = securityContext; String expressionToParse = annotation.expression(); if (StringUtils.hasLength(expressionToParse)) { diff --git a/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java index f04b5269e5e..80c33ac9eca 100644 --- a/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java @@ -69,9 +69,26 @@ public void cleanup() { SecurityContextHolder.clearContext(); } + @Test + public void supportsParameterNoAnnotationWrongType() { + assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotationTypeMismatch())).isFalse(); + } + @Test public void supportsParameterNoAnnotation() { - assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotation())).isFalse(); + assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotation())).isTrue(); + } + + @Test + public void supportsParameterCustomSecurityContextNoAnnotation() { + assertThat(this.resolver.supportsParameter(showSecurityContextWithCustomSecurityContextNoAnnotation())) + .isTrue(); + } + + @Test + public void supportsParameterNoAnnotationCustomType() { + assertThat(this.resolver.supportsParameter(showSecurityContextWithCustomSecurityContextNoAnnotation())) + .isTrue(); } @Test @@ -88,6 +105,24 @@ public void resolveArgumentWithCustomSecurityContext() { assertThat(customSecurityContext.getAuthentication().getPrincipal()).isEqualTo(principal); } + @Test + public void resolveArgumentWithCustomSecurityContextNoAnnotation() { + String principal = "custom_security_context"; + setAuthenticationPrincipalWithCustomSecurityContext(principal); + CustomSecurityContext customSecurityContext = (CustomSecurityContext) this.resolver + .resolveArgument(showSecurityContextWithCustomSecurityContextNoAnnotation(), null, null, null); + assertThat(customSecurityContext.getAuthentication().getPrincipal()).isEqualTo(principal); + } + + @Test + public void resolveArgumentWithNoAnnotation() { + String principal = "custom_security_context"; + setAuthenticationPrincipal(principal); + SecurityContext securityContext = (SecurityContext) this.resolver + .resolveArgument(showSecurityContextNoAnnotation(), null, null, null); + assertThat(securityContext.getAuthentication().getPrincipal()).isEqualTo(principal); + } + @Test public void resolveArgumentWithCustomSecurityContextTypeMatch() { String principal = "custom_security_context_type_match"; @@ -212,10 +247,14 @@ public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenMisMatch( .resolveArgument(showCurrentSecurityWithErrorOnInvalidTypeMisMatch(), null, null, null)); } - private MethodParameter showSecurityContextNoAnnotation() { + private MethodParameter showSecurityContextNoAnnotationTypeMismatch() { return getMethodParameter("showSecurityContextNoAnnotation", String.class); } + private MethodParameter showSecurityContextNoAnnotation() { + return getMethodParameter("showSecurityContextNoAnnotation", SecurityContext.class); + } + private MethodParameter showSecurityContextAnnotation() { return getMethodParameter("showSecurityContextAnnotation", SecurityContext.class); } @@ -276,6 +315,11 @@ public MethodParameter showCurrentSecurityWithErrorOnInvalidTypeMisMatch() { return getMethodParameter("showCurrentSecurityWithErrorOnInvalidTypeMisMatch", String.class); } + public MethodParameter showSecurityContextWithCustomSecurityContextNoAnnotation() { + return getMethodParameter("showSecurityContextWithCustomSecurityContextNoAnnotation", + CustomSecurityContext.class); + } + private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes); return new MethodParameter(method, 0); @@ -358,6 +402,12 @@ public void showCurrentSecurityWithErrorOnInvalidTypeMisMatch( @CurrentSecurityWithErrorOnInvalidType String typeMisMatch) { } + public void showSecurityContextNoAnnotation(SecurityContext context) { + } + + public void showSecurityContextWithCustomSecurityContextNoAnnotation(CustomSecurityContext context) { + } + } static class CustomSecurityContext implements SecurityContext { diff --git a/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java index 06cc9282bff..5556a25ed1b 100644 --- a/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java @@ -69,6 +69,14 @@ public class CurrentSecurityContextArgumentResolverTests { ResolvableMethod securityContextMethod = ResolvableMethod.on(getClass()).named("securityContext").build(); + ResolvableMethod securityContextNoAnnotationMethod = ResolvableMethod.on(getClass()) + .named("securityContextNoAnnotation") + .build(); + + ResolvableMethod customSecurityContextNoAnnotationMethod = ResolvableMethod.on(getClass()) + .named("customSecurityContextNoAnnotation") + .build(); + ResolvableMethod securityContextWithAuthentication = ResolvableMethod.on(getClass()) .named("securityContextWithAuthentication") .build(); @@ -87,6 +95,19 @@ public void supportsParameterCurrentSecurityContext() { .isTrue(); } + @Test + public void supportsParameterCurrentSecurityContextNoAnnotation() { + assertThat(this.resolver + .supportsParameter(this.securityContextNoAnnotationMethod.arg(Mono.class, SecurityContext.class))).isTrue(); + } + + @Test + public void supportsParameterCurrentCustomSecurityContextNoAnnotation() { + assertThat(this.resolver.supportsParameter( + this.customSecurityContextNoAnnotationMethod.arg(Mono.class, CustomSecurityContext.class))) + .isTrue(); + } + @Test public void supportsParameterWithAuthentication() { assertThat(this.resolver @@ -123,6 +144,40 @@ public void resolveArgumentWithSecurityContext() { ReactiveSecurityContextHolder.clearContext(); } + @Test + public void resolveArgumentWithSecurityContextNoAnnotation() { + MethodParameter parameter = ResolvableMethod.on(getClass()) + .named("securityContextNoAnnotation") + .build() + .arg(Mono.class, SecurityContext.class); + Authentication auth = buildAuthenticationWithPrincipal("hello"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = this.resolver.resolveArgument(parameter, this.bindingContext, this.exchange); + SecurityContext securityContext = (SecurityContext) argument.contextWrite(context) + .cast(Mono.class) + .block() + .block(); + assertThat(securityContext.getAuthentication()).isSameAs(auth); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithCustomSecurityContextNoAnnotation() { + MethodParameter parameter = ResolvableMethod.on(getClass()) + .named("customSecurityContextNoAnnotation") + .build() + .arg(Mono.class, CustomSecurityContext.class); + Authentication auth = buildAuthenticationWithPrincipal("hello"); + Context context = ReactiveSecurityContextHolder.withSecurityContext(Mono.just(new CustomSecurityContext(auth))); + Mono argument = this.resolver.resolveArgument(parameter, this.bindingContext, this.exchange); + CustomSecurityContext securityContext = (CustomSecurityContext) argument.contextWrite(context) + .cast(Mono.class) + .block() + .block(); + assertThat(securityContext.getAuthentication()).isSameAs(auth); + ReactiveSecurityContextHolder.clearContext(); + } + @Test public void resolveArgumentWithCustomSecurityContext() { MethodParameter parameter = ResolvableMethod.on(getClass()) @@ -350,6 +405,12 @@ public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenMisMatch( void securityContext(@CurrentSecurityContext Mono monoSecurityContext) { } + void securityContextNoAnnotation(Mono securityContextMono) { + } + + void customSecurityContextNoAnnotation(Mono securityContextMono) { + } + void customSecurityContext(@CurrentSecurityContext Mono monoSecurityContext) { }