Skip to content

Commit 753fe5e

Browse files
Apply SecurityContextHolderFilter to all dispatcher types
Closes gh-11962
1 parent 74e8fa1 commit 753fe5e

File tree

2 files changed

+60
-25
lines changed

2 files changed

+60
-25
lines changed

web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@
2121

2222
import javax.servlet.FilterChain;
2323
import javax.servlet.ServletException;
24+
import javax.servlet.ServletRequest;
25+
import javax.servlet.ServletResponse;
2426
import javax.servlet.http.HttpServletRequest;
2527
import javax.servlet.http.HttpServletResponse;
2628

2729
import org.springframework.security.core.context.SecurityContext;
2830
import org.springframework.security.core.context.SecurityContextHolder;
2931
import org.springframework.security.core.context.SecurityContextHolderStrategy;
3032
import org.springframework.util.Assert;
31-
import org.springframework.web.filter.OncePerRequestFilter;
33+
import org.springframework.web.filter.GenericFilterBean;
3234

3335
/**
3436
* A {@link javax.servlet.Filter} that uses the {@link SecurityContextRepository} to
@@ -40,17 +42,18 @@
4042
* mechanisms to choose individually if authentication should be persisted.
4143
*
4244
* @author Rob Winch
45+
* @author Marcus da Coregio
4346
* @since 5.7
4447
*/
45-
public class SecurityContextHolderFilter extends OncePerRequestFilter {
48+
public class SecurityContextHolderFilter extends GenericFilterBean {
49+
50+
private static final String FILTER_APPLIED = SecurityContextHolderFilter.class.getName() + ".APPLIED";
4651

4752
private final SecurityContextRepository securityContextRepository;
4853

4954
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
5055
.getContextHolderStrategy();
5156

52-
private boolean shouldNotFilterErrorDispatch;
53-
5457
/**
5558
* Creates a new instance.
5659
* @param securityContextRepository the repository to use. Cannot be null.
@@ -61,23 +64,29 @@ public SecurityContextHolderFilter(SecurityContextRepository securityContextRepo
6164
}
6265

6366
@Override
64-
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
67+
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
68+
throws IOException, ServletException {
69+
doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
70+
}
71+
72+
private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
6573
throws ServletException, IOException {
74+
if (request.getAttribute(FILTER_APPLIED) != null) {
75+
chain.doFilter(request, response);
76+
return;
77+
}
78+
request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
6679
Supplier<SecurityContext> deferredContext = this.securityContextRepository.loadDeferredContext(request);
6780
try {
6881
this.securityContextHolderStrategy.setDeferredContext(deferredContext);
69-
filterChain.doFilter(request, response);
82+
chain.doFilter(request, response);
7083
}
7184
finally {
7285
this.securityContextHolderStrategy.clearContext();
86+
request.removeAttribute(FILTER_APPLIED);
7387
}
7488
}
7589

76-
@Override
77-
protected boolean shouldNotFilterErrorDispatch() {
78-
return this.shouldNotFilterErrorDispatch;
79-
}
80-
8190
/**
8291
* Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
8392
* the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
@@ -89,13 +98,4 @@ public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy secur
8998
this.securityContextHolderStrategy = securityContextHolderStrategy;
9099
}
91100

92-
/**
93-
* Disables {@link SecurityContextHolderFilter} for error dispatch.
94-
* @param shouldNotFilterErrorDispatch if the Filter should be disabled for error
95-
* dispatch. Default is false.
96-
*/
97-
public void setShouldNotFilterErrorDispatch(boolean shouldNotFilterErrorDispatch) {
98-
this.shouldNotFilterErrorDispatch = shouldNotFilterErrorDispatch;
99-
}
100-
101101
}

web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.function.Supplier;
2020

21+
import javax.servlet.DispatcherType;
2122
import javax.servlet.FilterChain;
2223
import javax.servlet.http.HttpServletRequest;
2324
import javax.servlet.http.HttpServletResponse;
@@ -26,11 +27,15 @@
2627
import org.junit.jupiter.api.BeforeEach;
2728
import org.junit.jupiter.api.Test;
2829
import org.junit.jupiter.api.extension.ExtendWith;
30+
import org.junit.jupiter.params.ParameterizedTest;
31+
import org.junit.jupiter.params.provider.EnumSource;
2932
import org.mockito.ArgumentCaptor;
3033
import org.mockito.Captor;
34+
import org.mockito.InOrder;
3135
import org.mockito.Mock;
3236
import org.mockito.junit.jupiter.MockitoExtension;
3337

38+
import org.springframework.mock.web.MockFilterChain;
3439
import org.springframework.security.authentication.TestAuthentication;
3540
import org.springframework.security.core.Authentication;
3641
import org.springframework.security.core.context.SecurityContext;
@@ -40,11 +45,17 @@
4045

4146
import static org.assertj.core.api.Assertions.assertThat;
4247
import static org.mockito.BDDMockito.given;
48+
import static org.mockito.Mockito.inOrder;
49+
import static org.mockito.Mockito.lenient;
50+
import static org.mockito.Mockito.times;
4351
import static org.mockito.Mockito.verify;
52+
import static org.mockito.Mockito.verifyNoInteractions;
4453

4554
@ExtendWith(MockitoExtension.class)
4655
class SecurityContextHolderFilterTests {
4756

57+
private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED";
58+
4859
@Mock
4960
private SecurityContextRepository repository;
5061

@@ -105,14 +116,38 @@ void doFilterThenSetsAndClearsSecurityContextHolderStrategy() throws Exception {
105116
}
106117

107118
@Test
108-
void shouldNotFilterErrorDispatchWhenDefault() {
109-
assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse();
119+
void doFilterWhenFilterAppliedThenDoNothing() throws Exception {
120+
given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true);
121+
this.filter.doFilter(this.request, this.response, new MockFilterChain());
122+
verify(this.request, times(1)).getAttribute(FILTER_APPLIED);
123+
verifyNoInteractions(this.repository, this.response);
110124
}
111125

112126
@Test
113-
void shouldNotFilterErrorDispatchWhenOverridden() {
114-
this.filter.setShouldNotFilterErrorDispatch(true);
115-
assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue();
127+
void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception {
128+
given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn(
129+
new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy));
130+
131+
this.filter.doFilter(this.request, this.response, new MockFilterChain());
132+
133+
InOrder inOrder = inOrder(this.request, this.repository);
134+
inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true);
135+
inOrder.verify(this.repository).loadDeferredContext(this.request);
136+
inOrder.verify(this.request).removeAttribute(FILTER_APPLIED);
137+
}
138+
139+
@ParameterizedTest
140+
@EnumSource(DispatcherType.class)
141+
void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception {
142+
lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType);
143+
Authentication authentication = TestAuthentication.authenticatedUser();
144+
SecurityContext expectedContext = new SecurityContextImpl(authentication);
145+
given(this.repository.loadDeferredContext(this.requestArg.capture()))
146+
.willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy));
147+
FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext())
148+
.isEqualTo(expectedContext);
149+
150+
this.filter.doFilter(this.request, this.response, filterChain);
116151
}
117152

118153
}

0 commit comments

Comments
 (0)