Skip to content

Commit 635f7e1

Browse files
committed
CsrfWebFilter supports multipart/form-data
Fixes gh-7576
1 parent 387f765 commit 635f7e1

File tree

5 files changed

+148
-11
lines changed

5 files changed

+148
-11
lines changed

config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2731,6 +2731,19 @@ public CsrfSpec requireCsrfProtectionMatcher(
27312731
return this;
27322732
}
27332733

2734+
/**
2735+
* Specifies if {@link CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
2736+
* data requests.
2737+
*
2738+
* @param enabled true if should read from multipart form body, else false. Default is false
2739+
* @return the {@link CsrfSpec} for additional configuration
2740+
*/
2741+
public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) {
2742+
this.filter.setTokenFromMultipartDataEnabled(enabled);
2743+
return this;
2744+
}
2745+
2746+
27342747
/**
27352748
* Allows method chaining to continue configuring the {@link ServerHttpSecurity}
27362749
* @return the {@link ServerHttpSecurity} to continue configuring

gradle/dependency-management.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ dependencyManagement {
210210
dependency 'org.slf4j:slf4j-nop:1.7.28'
211211
dependency 'org.sonatype.sisu.inject:cglib:2.2.1-v20090111'
212212
dependency 'org.springframework.ldap:spring-ldap-core:2.3.2.RELEASE'
213+
dependency 'org.synchronoss.cloud:nio-multipart-parser:1.1.0'
213214
dependency 'org.thymeleaf:thymeleaf-spring5:3.0.11.RELEASE'
214215
dependency 'org.unbescape:unbescape:1.1.5.RELEASE'
215216
dependency 'org.w3c.css:sac:1.3'

web/spring-security-web.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies {
2525
testCompile 'org.codehaus.groovy:groovy-all'
2626
testCompile 'org.skyscreamer:jsonassert'
2727
testCompile 'org.springframework:spring-webflux'
28+
testCompile 'org.synchronoss.cloud:nio-multipart-parser'
2829
testCompile powerMock2Dependencies
2930
testCompile spockDependencies
3031

web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,24 @@
1616

1717
package org.springframework.security.web.server.csrf;
1818

19-
import java.util.Arrays;
20-
import java.util.HashSet;
21-
import java.util.Set;
22-
23-
import reactor.core.publisher.Mono;
24-
19+
import org.springframework.http.HttpHeaders;
2520
import org.springframework.http.HttpMethod;
2621
import org.springframework.http.HttpStatus;
22+
import org.springframework.http.MediaType;
23+
import org.springframework.http.codec.multipart.FormFieldPart;
24+
import org.springframework.http.server.reactive.ServerHttpRequest;
2725
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
2826
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
2927
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
3028
import org.springframework.util.Assert;
3129
import org.springframework.web.server.ServerWebExchange;
3230
import org.springframework.web.server.WebFilter;
3331
import org.springframework.web.server.WebFilterChain;
32+
import reactor.core.publisher.Mono;
33+
34+
import java.util.Arrays;
35+
import java.util.HashSet;
36+
import java.util.Set;
3437

3538
import static java.lang.Boolean.TRUE;
3639

@@ -78,6 +81,8 @@ public class CsrfWebFilter implements WebFilter {
7881

7982
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
8083

84+
private boolean isTokenFromMultipartDataEnabled;
85+
8186
public void setAccessDeniedHandler(
8287
ServerAccessDeniedHandler accessDeniedHandler) {
8388
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
@@ -96,6 +101,15 @@ public void setRequireCsrfProtectionMatcher(
96101
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
97102
}
98103

104+
/**
105+
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
106+
* data requests.
107+
* @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false
108+
*/
109+
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
110+
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
111+
}
112+
99113
@Override
100114
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
101115
if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
@@ -128,9 +142,26 @@ private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfTok
128142
return exchange.getFormData()
129143
.flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
130144
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
145+
.switchIfEmpty(tokenFromMultipartData(exchange, expected))
131146
.map(actual -> actual.equals(expected.getToken()));
132147
}
133148

149+
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
150+
if (!this.isTokenFromMultipartDataEnabled) {
151+
return Mono.empty();
152+
}
153+
ServerHttpRequest request = exchange.getRequest();
154+
HttpHeaders headers = request.getHeaders();
155+
MediaType contentType = headers.getContentType();
156+
if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) {
157+
return Mono.empty();
158+
}
159+
return exchange.getMultipartData()
160+
.map(d -> d.getFirst(expected.getParameterName()))
161+
.cast(FormFieldPart.class)
162+
.map(FormFieldPart::value);
163+
}
164+
134165
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
135166
return Mono.defer(() ->{
136167
Mono<CsrfToken> csrfToken = csrfToken(exchange);

web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,28 @@
2020
import org.junit.runner.RunWith;
2121
import org.mockito.Mock;
2222
import org.mockito.junit.MockitoJUnitRunner;
23-
import reactor.core.publisher.Mono;
24-
import reactor.test.StepVerifier;
25-
import reactor.test.publisher.PublisherProbe;
26-
2723
import org.springframework.http.HttpStatus;
2824
import org.springframework.http.MediaType;
2925
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
3026
import org.springframework.mock.web.server.MockServerWebExchange;
3127
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
28+
import org.springframework.test.web.reactive.server.WebTestClient;
29+
import org.springframework.web.bind.annotation.RequestMapping;
30+
import org.springframework.web.bind.annotation.RestController;
31+
import org.springframework.web.server.ServerWebExchange;
3232
import org.springframework.web.server.WebFilterChain;
3333
import org.springframework.web.server.WebSession;
34+
import reactor.core.publisher.Mono;
35+
import reactor.test.StepVerifier;
36+
import reactor.test.publisher.PublisherProbe;
3437

3538
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
3639
import static org.mockito.ArgumentMatchers.any;
3740
import static org.mockito.Mockito.mock;
3841
import static org.mockito.Mockito.verifyZeroInteractions;
3942
import static org.mockito.Mockito.when;
4043
import static org.springframework.mock.web.server.MockServerWebExchange.from;
44+
import static org.springframework.web.reactive.function.BodyInserters.fromMultipartData;
4145

4246
/**
4347
* @author Rob Winch
@@ -57,7 +61,7 @@ public class CsrfWebFilterTests {
5761
private MockServerWebExchange get = from(
5862
MockServerHttpRequest.get("/"));
5963

60-
private MockServerWebExchange post = from(
64+
private ServerWebExchange post = from(
6165
MockServerHttpRequest.post("/"));
6266

6367
@Test
@@ -193,4 +197,91 @@ public void doFilterWhenSkipExchangeInvokedThenSkips() {
193197

194198
verifyZeroInteractions(matcher);
195199
}
200+
201+
@Test
202+
public void filterWhenMultipartFormDataAndNotEnabledThenDenied() {
203+
this.csrfFilter.setCsrfTokenRepository(this.repository);
204+
when(this.repository.loadToken(any()))
205+
.thenReturn(Mono.just(this.token));
206+
207+
WebTestClient client = WebTestClient.bindToController(new OkController())
208+
.webFilter(this.csrfFilter)
209+
.build();
210+
211+
client.post()
212+
.uri("/")
213+
.contentType(MediaType.MULTIPART_FORM_DATA)
214+
.body(fromMultipartData(this.token.getParameterName(), this.token.getToken()))
215+
.exchange()
216+
.expectStatus().isForbidden();
217+
}
218+
219+
@Test
220+
public void filterWhenMultipartFormDataAndEnabledThenGranted() {
221+
this.csrfFilter.setCsrfTokenRepository(this.repository);
222+
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
223+
when(this.repository.loadToken(any()))
224+
.thenReturn(Mono.just(this.token));
225+
when(this.repository.generateToken(any()))
226+
.thenReturn(Mono.just(this.token));
227+
228+
WebTestClient client = WebTestClient.bindToController(new OkController())
229+
.webFilter(this.csrfFilter)
230+
.build();
231+
232+
client.post()
233+
.uri("/")
234+
.contentType(MediaType.MULTIPART_FORM_DATA)
235+
.body(fromMultipartData(this.token.getParameterName(), this.token.getToken()))
236+
.exchange()
237+
.expectStatus().is2xxSuccessful();
238+
}
239+
240+
@Test
241+
public void filterWhenFormDataAndEnabledThenGranted() {
242+
this.csrfFilter.setCsrfTokenRepository(this.repository);
243+
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
244+
when(this.repository.loadToken(any()))
245+
.thenReturn(Mono.just(this.token));
246+
when(this.repository.generateToken(any()))
247+
.thenReturn(Mono.just(this.token));
248+
249+
WebTestClient client = WebTestClient.bindToController(new OkController())
250+
.webFilter(this.csrfFilter)
251+
.build();
252+
253+
client.post()
254+
.uri("/")
255+
.contentType(MediaType.APPLICATION_FORM_URLENCODED)
256+
.bodyValue(this.token.getParameterName() + "="+this.token.getToken())
257+
.exchange()
258+
.expectStatus().is2xxSuccessful();
259+
}
260+
261+
@Test
262+
public void filterWhenMultipartMixedAndEnabledThenNotRead() {
263+
this.csrfFilter.setCsrfTokenRepository(this.repository);
264+
this.csrfFilter.setTokenFromMultipartDataEnabled(true);
265+
when(this.repository.loadToken(any()))
266+
.thenReturn(Mono.just(this.token));
267+
268+
WebTestClient client = WebTestClient.bindToController(new OkController())
269+
.webFilter(this.csrfFilter)
270+
.build();
271+
272+
client.post()
273+
.uri("/")
274+
.contentType(MediaType.MULTIPART_MIXED)
275+
.bodyValue(this.token.getParameterName() + "="+this.token.getToken())
276+
.exchange()
277+
.expectStatus().isForbidden();
278+
}
279+
280+
@RestController
281+
static class OkController {
282+
@RequestMapping("/**")
283+
String ok() {
284+
return "ok";
285+
}
286+
}
196287
}

0 commit comments

Comments
 (0)