Skip to content

Commit aa12748

Browse files
committed
Add Request-level CSRF Skip
Fixes gh-7367
1 parent 9920cb4 commit aa12748

File tree

4 files changed

+89
-9
lines changed

4 files changed

+89
-9
lines changed

web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
import org.springframework.util.Assert;
3636
import org.springframework.web.filter.OncePerRequestFilter;
3737

38+
import static java.lang.Boolean.TRUE;
39+
3840
/**
3941
* <p>
4042
* Applies
@@ -63,6 +65,16 @@ public final class CsrfFilter extends OncePerRequestFilter {
6365
*/
6466
public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
6567

68+
/**
69+
* The attribute name to use when marking a given request as one that should not be filtered.
70+
*
71+
* To use, set the attribute on your {@link HttpServletRequest}:
72+
* <pre>
73+
* CsrfFilter.skipRequest(request);
74+
* </pre>
75+
*/
76+
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName();
77+
6678
private final Log logger = LogFactory.getLog(getClass());
6779
private final CsrfTokenRepository tokenRepository;
6880
private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
@@ -73,6 +85,11 @@ public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
7385
this.tokenRepository = csrfTokenRepository;
7486
}
7587

88+
@Override
89+
protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
90+
return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
91+
}
92+
7693
/*
7794
* (non-Javadoc)
7895
*
@@ -124,6 +141,10 @@ protected void doFilterInternal(HttpServletRequest request,
124141
filterChain.doFilter(request, response);
125142
}
126143

144+
public static void skipRequest(HttpServletRequest request) {
145+
request.setAttribute(SHOULD_NOT_FILTER, TRUE);
146+
}
147+
127148
/**
128149
* Specifies a {@link RequestMatcher} that is used to determine if CSRF protection
129150
* should be applied. If the {@link RequestMatcher} returns true for a given request,

web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
import org.springframework.web.server.WebFilter;
3333
import org.springframework.web.server.WebFilterChain;
3434

35+
import static java.lang.Boolean.TRUE;
36+
3537
/**
3638
* <p>
3739
* Applies
@@ -60,6 +62,16 @@
6062
public class CsrfWebFilter implements WebFilter {
6163
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher();
6264

65+
/**
66+
* The attribute name to use when marking a given request as one that should not be filtered.
67+
*
68+
* To use, set the attribute on your {@link ServerWebExchange}:
69+
* <pre>
70+
* CsrfWebFilter.skipExchange(exchange);
71+
* </pre>
72+
*/
73+
private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfWebFilter.class.getName();
74+
6375
private ServerWebExchangeMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
6476

6577
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
@@ -86,6 +98,10 @@ public void setRequireCsrfProtectionMatcher(
8698

8799
@Override
88100
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
101+
if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
102+
return chain.filter(exchange).then(Mono.empty());
103+
}
104+
89105
return this.requireCsrfProtectionMatcher.matches(exchange)
90106
.filter( matchResult -> matchResult.isMatch())
91107
.filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
@@ -96,6 +112,10 @@ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
96112
.handle(exchange, e));
97113
}
98114

115+
public static void skipExchange(ServerWebExchange exchange) {
116+
exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE);
117+
}
118+
99119
private Mono<Void> validateToken(ServerWebExchange exchange) {
100120
return this.csrfTokenRepository.loadToken(exchange)
101121
.switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("CSRF Token has been associated to this client"))))

web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.mockito.Mock;
3232
import org.mockito.junit.MockitoJUnitRunner;
3333

34+
import org.springframework.mock.web.MockFilterChain;
3435
import org.springframework.mock.web.MockHttpServletRequest;
3536
import org.springframework.mock.web.MockHttpServletResponse;
3637
import org.springframework.security.web.access.AccessDeniedHandler;
@@ -39,6 +40,8 @@
3940
import static org.assertj.core.api.Assertions.assertThat;
4041
import static org.mockito.ArgumentMatchers.any;
4142
import static org.mockito.ArgumentMatchers.eq;
43+
import static org.mockito.Mockito.lenient;
44+
import static org.mockito.Mockito.mock;
4245
import static org.mockito.Mockito.never;
4346
import static org.mockito.Mockito.times;
4447
import static org.mockito.Mockito.verify;
@@ -390,6 +393,22 @@ public void doFilterDefaultAccessDenied() throws ServletException, IOException {
390393
verifyZeroInteractions(this.filterChain);
391394
}
392395

396+
@Test
397+
public void doFilterWhenSkipRequestInvokedThenSkips()
398+
throws Exception {
399+
400+
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
401+
CsrfFilter filter = new CsrfFilter(repository);
402+
403+
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
404+
405+
MockHttpServletRequest request = new MockHttpServletRequest();
406+
CsrfFilter.skipRequest(request);
407+
filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain());
408+
409+
verifyZeroInteractions(repository);
410+
}
411+
393412
@Test(expected = IllegalArgumentException.class)
394413
public void setRequireCsrfProtectionMatcherNull() {
395414
this.filter.setRequireCsrfProtectionMatcher(null);

web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,24 @@
2020
import org.junit.runner.RunWith;
2121
import org.mockito.Mock;
2222
import org.mockito.junit.MockitoJUnitRunner;
23+
import reactor.core.publisher.Mono;
24+
import reactor.test.StepVerifier;
25+
import reactor.test.publisher.PublisherProbe;
26+
2327
import org.springframework.http.HttpStatus;
2428
import org.springframework.http.MediaType;
2529
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
2630
import org.springframework.mock.web.server.MockServerWebExchange;
31+
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
2732
import org.springframework.web.server.WebFilterChain;
2833
import org.springframework.web.server.WebSession;
29-
import reactor.core.publisher.Mono;
30-
import reactor.test.StepVerifier;
31-
import reactor.test.publisher.PublisherProbe;
3234

3335
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
3436
import static org.mockito.ArgumentMatchers.any;
37+
import static org.mockito.Mockito.mock;
38+
import static org.mockito.Mockito.verifyZeroInteractions;
3539
import static org.mockito.Mockito.when;
40+
import static org.springframework.mock.web.server.MockServerWebExchange.from;
3641

3742
/**
3843
* @author Rob Winch
@@ -49,10 +54,10 @@ public class CsrfWebFilterTests {
4954

5055
private CsrfWebFilter csrfFilter = new CsrfWebFilter();
5156

52-
private MockServerWebExchange get = MockServerWebExchange.from(
57+
private MockServerWebExchange get = from(
5358
MockServerHttpRequest.get("/"));
5459

55-
private MockServerWebExchange post = MockServerWebExchange.from(
60+
private MockServerWebExchange post = from(
5661
MockServerHttpRequest.post("/"));
5762

5863
@Test
@@ -104,7 +109,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamInvalidTokenThen
104109
this.csrfFilter.setCsrfTokenRepository(this.repository);
105110
when(this.repository.loadToken(any()))
106111
.thenReturn(Mono.just(this.token));
107-
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
112+
this.post = from(MockServerHttpRequest.post("/")
108113
.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
109114

110115
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@@ -125,7 +130,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamValidTokenThenCo
125130
.thenReturn(Mono.just(this.token));
126131
when(this.repository.generateToken(any()))
127132
.thenReturn(Mono.just(this.token));
128-
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
133+
this.post = from(MockServerHttpRequest.post("/")
129134
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
130135
.body(this.token.getParameterName() + "="+this.token.getToken()));
131136

@@ -142,7 +147,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndHeaderInvalidTokenThenCsrfEx
142147
this.csrfFilter.setCsrfTokenRepository(this.repository);
143148
when(this.repository.loadToken(any()))
144149
.thenReturn(Mono.just(this.token));
145-
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
150+
this.post = from(MockServerHttpRequest.post("/")
146151
.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));
147152

148153
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@@ -163,7 +168,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinue
163168
.thenReturn(Mono.just(this.token));
164169
when(this.repository.generateToken(any()))
165170
.thenReturn(Mono.just(this.token));
166-
this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
171+
this.post = from(MockServerHttpRequest.post("/")
167172
.header(this.token.getHeaderName(), this.token.getToken()));
168173

169174
Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@@ -173,4 +178,19 @@ public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinue
173178

174179
chainResult.assertWasSubscribed();
175180
}
181+
182+
@Test
183+
public void doFilterWhenSkipExchangeInvokedThenSkips() {
184+
PublisherProbe<Void> chainResult = PublisherProbe.empty();
185+
when(this.chain.filter(any())).thenReturn(chainResult.mono());
186+
187+
ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class);
188+
this.csrfFilter.setRequireCsrfProtectionMatcher(matcher);
189+
190+
MockServerWebExchange exchange = from(MockServerHttpRequest.post("/post").build());
191+
CsrfWebFilter.skipExchange(exchange);
192+
this.csrfFilter.filter(exchange, this.chain).block();
193+
194+
verifyZeroInteractions(matcher);
195+
}
176196
}

0 commit comments

Comments
 (0)