|
18 | 18 |
|
19 | 19 | import java.util.function.Supplier;
|
20 | 20 |
|
| 21 | +import javax.servlet.DispatcherType; |
21 | 22 | import javax.servlet.FilterChain;
|
22 | 23 | import javax.servlet.http.HttpServletRequest;
|
23 | 24 | import javax.servlet.http.HttpServletResponse;
|
|
26 | 27 | import org.junit.jupiter.api.BeforeEach;
|
27 | 28 | import org.junit.jupiter.api.Test;
|
28 | 29 | import org.junit.jupiter.api.extension.ExtendWith;
|
| 30 | +import org.junit.jupiter.params.ParameterizedTest; |
| 31 | +import org.junit.jupiter.params.provider.EnumSource; |
29 | 32 | import org.mockito.ArgumentCaptor;
|
30 | 33 | import org.mockito.Captor;
|
| 34 | +import org.mockito.InOrder; |
31 | 35 | import org.mockito.Mock;
|
32 | 36 | import org.mockito.junit.jupiter.MockitoExtension;
|
33 | 37 |
|
| 38 | +import org.springframework.mock.web.MockFilterChain; |
34 | 39 | import org.springframework.security.authentication.TestAuthentication;
|
35 | 40 | import org.springframework.security.core.Authentication;
|
36 | 41 | import org.springframework.security.core.context.SecurityContext;
|
|
40 | 45 |
|
41 | 46 | import static org.assertj.core.api.Assertions.assertThat;
|
42 | 47 | import static org.mockito.BDDMockito.given;
|
| 48 | +import static org.mockito.Mockito.inOrder; |
| 49 | +import static org.mockito.Mockito.lenient; |
| 50 | +import static org.mockito.Mockito.times; |
43 | 51 | import static org.mockito.Mockito.verify;
|
| 52 | +import static org.mockito.Mockito.verifyNoInteractions; |
44 | 53 |
|
45 | 54 | @ExtendWith(MockitoExtension.class)
|
46 | 55 | class SecurityContextHolderFilterTests {
|
47 | 56 |
|
| 57 | + private static final String FILTER_APPLIED = "org.springframework.security.web.context.SecurityContextHolderFilter.APPLIED"; |
| 58 | + |
48 | 59 | @Mock
|
49 | 60 | private SecurityContextRepository repository;
|
50 | 61 |
|
@@ -105,14 +116,38 @@ void doFilterThenSetsAndClearsSecurityContextHolderStrategy() throws Exception {
|
105 | 116 | }
|
106 | 117 |
|
107 | 118 | @Test
|
108 |
| - void shouldNotFilterErrorDispatchWhenDefault() { |
109 |
| - assertThat(this.filter.shouldNotFilterErrorDispatch()).isFalse(); |
| 119 | + void doFilterWhenFilterAppliedThenDoNothing() throws Exception { |
| 120 | + given(this.request.getAttribute(FILTER_APPLIED)).willReturn(true); |
| 121 | + this.filter.doFilter(this.request, this.response, new MockFilterChain()); |
| 122 | + verify(this.request, times(1)).getAttribute(FILTER_APPLIED); |
| 123 | + verifyNoInteractions(this.repository, this.response); |
110 | 124 | }
|
111 | 125 |
|
112 | 126 | @Test
|
113 |
| - void shouldNotFilterErrorDispatchWhenOverridden() { |
114 |
| - this.filter.setShouldNotFilterErrorDispatch(true); |
115 |
| - assertThat(this.filter.shouldNotFilterErrorDispatch()).isTrue(); |
| 127 | + void doFilterWhenNotAppliedThenSetsAndRemovesAttribute() throws Exception { |
| 128 | + given(this.repository.loadDeferredContext(this.requestArg.capture())).willReturn( |
| 129 | + new SupplierDeferredSecurityContext(SecurityContextHolder::createEmptyContext, this.strategy)); |
| 130 | + |
| 131 | + this.filter.doFilter(this.request, this.response, new MockFilterChain()); |
| 132 | + |
| 133 | + InOrder inOrder = inOrder(this.request, this.repository); |
| 134 | + inOrder.verify(this.request).setAttribute(FILTER_APPLIED, true); |
| 135 | + inOrder.verify(this.repository).loadDeferredContext(this.request); |
| 136 | + inOrder.verify(this.request).removeAttribute(FILTER_APPLIED); |
| 137 | + } |
| 138 | + |
| 139 | + @ParameterizedTest |
| 140 | + @EnumSource(DispatcherType.class) |
| 141 | + void doFilterWhenAnyDispatcherTypeThenFilter(DispatcherType dispatcherType) throws Exception { |
| 142 | + lenient().when(this.request.getDispatcherType()).thenReturn(dispatcherType); |
| 143 | + Authentication authentication = TestAuthentication.authenticatedUser(); |
| 144 | + SecurityContext expectedContext = new SecurityContextImpl(authentication); |
| 145 | + given(this.repository.loadDeferredContext(this.requestArg.capture())) |
| 146 | + .willReturn(new SupplierDeferredSecurityContext(() -> expectedContext, this.strategy)); |
| 147 | + FilterChain filterChain = (request, response) -> assertThat(SecurityContextHolder.getContext()) |
| 148 | + .isEqualTo(expectedContext); |
| 149 | + |
| 150 | + this.filter.doFilter(this.request, this.response, filterChain); |
116 | 151 | }
|
117 | 152 |
|
118 | 153 | }
|
0 commit comments