Skip to content

Restore @AuthenticationPrincipal/@CurrentSecurityContext Interface Support #16245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

package org.springframework.security.config.annotation.web.reactive;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.nio.charset.StandardCharsets;

import org.junit.jupiter.api.Test;
Expand All @@ -28,6 +32,7 @@
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.annotation.Order;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
Expand Down Expand Up @@ -404,11 +409,28 @@ public String username(UserDetails user) {

}

@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@AuthenticationPrincipal
@interface Property {

@AliasFor(attribute = "expression", annotation = AuthenticationPrincipal.class)
String value() default "id";

}

interface UsernameResolver {

String username(@Property("@principalBean.username(#this)") String username);

}

@RestController
static class AuthenticationPrincipalResolver {
static class AuthenticationPrincipalResolver implements UsernameResolver {

@Override
@GetMapping("/spel")
String username(@AuthenticationPrincipal(expression = "@principalBean.username(#this)") String username) {
public String username(String username) {
return username;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import java.lang.annotation.Annotation;

import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
Expand Down Expand Up @@ -95,8 +97,12 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet

private ExpressionParser parser = new SpelExpressionParser();

private final Class<AuthenticationPrincipal> annotationType = AuthenticationPrincipal.class;

private SecurityAnnotationScanner<AuthenticationPrincipal> scanner = SecurityAnnotationScanners
.requireUnique(AuthenticationPrincipal.class);
.requireUnique(this.annotationType);

private boolean useAnnotationTemplate = false;

@Override
public boolean supportsParameter(MethodParameter parameter) {
Expand Down Expand Up @@ -149,6 +155,7 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
* @since 6.4
*/
public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDefaults) {
this.useAnnotationTemplate = templateDefaults != null;
this.scanner = SecurityAnnotationScanners.requireUnique(AuthenticationPrincipal.class, templateDefaults);
}

Expand All @@ -158,9 +165,22 @@ public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDef
* @param parameter the {@link MethodParameter} to search for an {@link Annotation}
* @return the {@link Annotation} that was found or null.
*/
@SuppressWarnings("unchecked")
private <T extends Annotation> T findMethodAnnotation(MethodParameter parameter) {
return (T) this.scanner.scan(parameter.getParameter());
private AuthenticationPrincipal findMethodAnnotation(MethodParameter parameter) {
if (this.useAnnotationTemplate) {
return this.scanner.scan(parameter.getParameter());
}
AuthenticationPrincipal annotation = parameter.getParameterAnnotation(this.annotationType);
if (annotation != null) {
return annotation;
}
Annotation[] annotationsToSearch = parameter.getParameterAnnotations();
for (Annotation toSearch : annotationsToSearch) {
annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), this.annotationType);
if (annotation != null) {
return MergedAnnotations.from(toSearch).get(this.annotationType).synthesize();
}
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.expression.BeanResolver;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
Expand Down Expand Up @@ -99,8 +101,12 @@ public class AuthenticationPrincipalArgumentResolver implements HandlerMethodArg

private ExpressionParser parser = new SpelExpressionParser();

private final Class<AuthenticationPrincipal> annotationType = AuthenticationPrincipal.class;

private SecurityAnnotationScanner<AuthenticationPrincipal> scanner = SecurityAnnotationScanners
.requireUnique(AuthenticationPrincipal.class);
.requireUnique(this.annotationType);

private boolean useAnnotationTemplate = false;

private BeanResolver beanResolver;

Expand Down Expand Up @@ -190,6 +196,7 @@ private boolean isInvalidType(MethodParameter parameter, Object principal) {
* @since 6.4
*/
public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDefaults) {
this.useAnnotationTemplate = templateDefaults != null;
this.scanner = SecurityAnnotationScanners.requireUnique(AuthenticationPrincipal.class, templateDefaults);
}

Expand All @@ -199,9 +206,22 @@ public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDef
* @param parameter the {@link MethodParameter} to search for an {@link Annotation}
* @return the {@link Annotation} that was found or null.
*/
@SuppressWarnings("unchecked")
private <T extends Annotation> T findMethodAnnotation(MethodParameter parameter) {
return (T) this.scanner.scan(parameter.getParameter());
private AuthenticationPrincipal findMethodAnnotation(MethodParameter parameter) {
if (this.useAnnotationTemplate) {
return this.scanner.scan(parameter.getParameter());
}
AuthenticationPrincipal annotation = parameter.getParameterAnnotation(this.annotationType);
if (annotation != null) {
return annotation;
}
Annotation[] annotationsToSearch = parameter.getParameterAnnotations();
for (Annotation toSearch : annotationsToSearch) {
annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), this.annotationType);
if (annotation != null) {
return MergedAnnotations.from(toSearch).get(this.annotationType).synthesize();
}
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.MergedAnnotations;
import org.springframework.expression.BeanResolver;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
Expand Down Expand Up @@ -97,8 +99,12 @@ public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgu

private ExpressionParser parser = new SpelExpressionParser();

private final Class<CurrentSecurityContext> annotationType = CurrentSecurityContext.class;

private SecurityAnnotationScanner<CurrentSecurityContext> scanner = SecurityAnnotationScanners
.requireUnique(CurrentSecurityContext.class);
.requireUnique(this.annotationType);

private boolean useAnnotationTemplate = false;

private BeanResolver beanResolver;

Expand Down Expand Up @@ -208,6 +214,7 @@ private boolean isInvalidType(MethodParameter parameter, Object value) {
* @since 6.4
*/
public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDefaults) {
this.useAnnotationTemplate = templateDefaults != null;
this.scanner = SecurityAnnotationScanners.requireUnique(CurrentSecurityContext.class, templateDefaults);
}

Expand All @@ -216,9 +223,22 @@ public void setTemplateDefaults(AnnotationTemplateExpressionDefaults templateDef
* @param parameter the {@link MethodParameter} to search for an {@link Annotation}
* @return the {@link Annotation} that was found or null.
*/
@SuppressWarnings("unchecked")
private <T extends Annotation> T findMethodAnnotation(MethodParameter parameter) {
return (T) this.scanner.scan(parameter.getParameter());
private CurrentSecurityContext findMethodAnnotation(MethodParameter parameter) {
if (this.useAnnotationTemplate) {
return this.scanner.scan(parameter.getParameter());
}
CurrentSecurityContext annotation = parameter.getParameterAnnotation(this.annotationType);
if (annotation != null) {
return annotation;
}
Annotation[] annotationsToSearch = parameter.getParameterAnnotations();
for (Annotation toSearch : annotationsToSearch) {
annotation = AnnotationUtils.findAnnotation(toSearch.annotationType(), this.annotationType);
if (annotation != null) {
return MergedAnnotations.from(toSearch).get(this.annotationType).synthesize();
}
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.annotation.AnnotatedMethod;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
Expand Down Expand Up @@ -186,10 +187,21 @@ public void resolveArgumentCustomMetaAnnotationTpl() throws Exception {
assertThat(this.resolver.resolveArgument(showUserCustomMetaAnnotationTpl(), null)).isEqualTo(principal.id);
}

@Test
public void resolveArgumentWhenAliasForOnInterfaceThenInherits() {
CustomUserPrincipal principal = new CustomUserPrincipal();
setAuthenticationPrincipal(principal);
assertThat(this.resolver.resolveArgument(showUserNoConcreteAnnotation(), null)).isEqualTo(principal.property);
}

private MethodParameter showUserNoAnnotation() {
return getMethodParameter("showUserNoAnnotation", String.class);
}

private MethodParameter showUserNoConcreteAnnotation() {
return getMethodParameter("showUserNoConcreteAnnotation", String.class);
}

private MethodParameter showUserAnnotationString() {
return getMethodParameter("showUserAnnotation", String.class);
}
Expand Down Expand Up @@ -240,7 +252,7 @@ private MethodParameter showUserAnnotationObject() {

private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes);
return new MethodParameter(method, 0);
return new AnnotatedMethod(method).getMethodParameters()[0];
}

private void setAuthenticationPrincipal(Object principal) {
Expand Down Expand Up @@ -280,11 +292,32 @@ private void setAuthenticationPrincipal(Object principal) {

}

public static class TestController {
@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@AuthenticationPrincipal
@interface Property {

@AliasFor(attribute = "expression", annotation = AuthenticationPrincipal.class)
String value() default "id";

}

private interface TestInterface {

void showUserNoConcreteAnnotation(@Property("property") String property);

}

public static class TestController implements TestInterface {

public void showUserNoAnnotation(String user) {
}

@Override
public void showUserNoConcreteAnnotation(String user) {

}

public void showUserAnnotation(@AuthenticationPrincipal String user) {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@

package org.springframework.security.messaging.handler.invocation.reactive;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;

import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AliasFor;
import org.springframework.core.annotation.AnnotatedMethod;
import org.springframework.core.annotation.SynthesizingMethodParameter;
import org.springframework.security.authentication.TestAuthentication;
import org.springframework.security.authentication.TestingAuthenticationToken;
Expand Down Expand Up @@ -128,6 +131,19 @@ public void supportsParameterWhenNotAnnotatedThenFalse() {
assertThat(this.resolver.supportsParameter(arg0("monoUserDetails"))).isFalse();
}

@Test
public void resolveArgumentWhenAliasForOnInterfaceThenInherits() {
CustomUserPrincipal principal = new CustomUserPrincipal();
Authentication authentication = new TestingAuthenticationToken(principal, "password", "ROLE_USER");
ResolvableMethod method = ResolvableMethod.on(TestController.class)
.named("showUserNoConcreteAnnotation")
.method();
MethodParameter parameter = new AnnotatedMethod(method.method()).getMethodParameters()[0];
Mono<Object> result = this.resolver.resolveArgument(parameter, null)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication));
assertThat(result.block()).isEqualTo(principal.property);
}

@SuppressWarnings("unused")
private void monoUserDetails(Mono<UserDetails> user) {
}
Expand Down Expand Up @@ -172,6 +188,8 @@ static class CustomUserPrincipal {

public final int id = 1;

public final String property = "property";

public Object getPrincipal() {
return this;
}
Expand All @@ -195,4 +213,29 @@ public Object getPrincipal() {

}

@Target({ ElementType.PARAMETER })
@Retention(RetentionPolicy.RUNTIME)
@AuthenticationPrincipal
@interface Property {

@AliasFor(attribute = "expression", annotation = AuthenticationPrincipal.class)
String value() default "id";

}

private interface TestInterface {

void showUserNoConcreteAnnotation(@Property("property") String property);

}

private static class TestController implements TestInterface {

@Override
public void showUserNoConcreteAnnotation(String user) {

}

}

}
Loading
Loading