20
20
import org .junit .runner .RunWith ;
21
21
import org .mockito .Mock ;
22
22
import org .mockito .junit .MockitoJUnitRunner ;
23
- import reactor .core .publisher .Mono ;
24
- import reactor .test .StepVerifier ;
25
- import reactor .test .publisher .PublisherProbe ;
26
-
27
23
import org .springframework .http .HttpStatus ;
28
24
import org .springframework .http .MediaType ;
29
25
import org .springframework .mock .http .server .reactive .MockServerHttpRequest ;
30
26
import org .springframework .mock .web .server .MockServerWebExchange ;
31
27
import org .springframework .security .web .server .util .matcher .ServerWebExchangeMatcher ;
28
+ import org .springframework .test .web .reactive .server .WebTestClient ;
29
+ import org .springframework .web .bind .annotation .RequestMapping ;
30
+ import org .springframework .web .bind .annotation .RestController ;
31
+ import org .springframework .web .server .ServerWebExchange ;
32
32
import org .springframework .web .server .WebFilterChain ;
33
33
import org .springframework .web .server .WebSession ;
34
+ import reactor .core .publisher .Mono ;
35
+ import reactor .test .StepVerifier ;
36
+ import reactor .test .publisher .PublisherProbe ;
34
37
35
38
import static org .assertj .core .api .AssertionsForInterfaceTypes .assertThat ;
36
39
import static org .mockito .ArgumentMatchers .any ;
37
40
import static org .mockito .Mockito .mock ;
38
41
import static org .mockito .Mockito .verifyZeroInteractions ;
39
42
import static org .mockito .Mockito .when ;
40
43
import static org .springframework .mock .web .server .MockServerWebExchange .from ;
44
+ import static org .springframework .web .reactive .function .BodyInserters .fromMultipartData ;
41
45
42
46
/**
43
47
* @author Rob Winch
@@ -57,7 +61,7 @@ public class CsrfWebFilterTests {
57
61
private MockServerWebExchange get = from (
58
62
MockServerHttpRequest .get ("/" ));
59
63
60
- private MockServerWebExchange post = from (
64
+ private ServerWebExchange post = from (
61
65
MockServerHttpRequest .post ("/" ));
62
66
63
67
@ Test
@@ -193,4 +197,91 @@ public void doFilterWhenSkipExchangeInvokedThenSkips() {
193
197
194
198
verifyZeroInteractions (matcher );
195
199
}
200
+
201
+ @ Test
202
+ public void filterWhenMultipartFormDataAndNotEnabledThenDenied () {
203
+ this .csrfFilter .setCsrfTokenRepository (this .repository );
204
+ when (this .repository .loadToken (any ()))
205
+ .thenReturn (Mono .just (this .token ));
206
+
207
+ WebTestClient client = WebTestClient .bindToController (new OkController ())
208
+ .webFilter (this .csrfFilter )
209
+ .build ();
210
+
211
+ client .post ()
212
+ .uri ("/" )
213
+ .contentType (MediaType .MULTIPART_FORM_DATA )
214
+ .body (fromMultipartData (this .token .getParameterName (), this .token .getToken ()))
215
+ .exchange ()
216
+ .expectStatus ().isForbidden ();
217
+ }
218
+
219
+ @ Test
220
+ public void filterWhenMultipartFormDataAndEnabledThenGranted () {
221
+ this .csrfFilter .setCsrfTokenRepository (this .repository );
222
+ this .csrfFilter .setTokenFromMultipartDataEnabled (true );
223
+ when (this .repository .loadToken (any ()))
224
+ .thenReturn (Mono .just (this .token ));
225
+ when (this .repository .generateToken (any ()))
226
+ .thenReturn (Mono .just (this .token ));
227
+
228
+ WebTestClient client = WebTestClient .bindToController (new OkController ())
229
+ .webFilter (this .csrfFilter )
230
+ .build ();
231
+
232
+ client .post ()
233
+ .uri ("/" )
234
+ .contentType (MediaType .MULTIPART_FORM_DATA )
235
+ .body (fromMultipartData (this .token .getParameterName (), this .token .getToken ()))
236
+ .exchange ()
237
+ .expectStatus ().is2xxSuccessful ();
238
+ }
239
+
240
+ @ Test
241
+ public void filterWhenFormDataAndEnabledThenGranted () {
242
+ this .csrfFilter .setCsrfTokenRepository (this .repository );
243
+ this .csrfFilter .setTokenFromMultipartDataEnabled (true );
244
+ when (this .repository .loadToken (any ()))
245
+ .thenReturn (Mono .just (this .token ));
246
+ when (this .repository .generateToken (any ()))
247
+ .thenReturn (Mono .just (this .token ));
248
+
249
+ WebTestClient client = WebTestClient .bindToController (new OkController ())
250
+ .webFilter (this .csrfFilter )
251
+ .build ();
252
+
253
+ client .post ()
254
+ .uri ("/" )
255
+ .contentType (MediaType .APPLICATION_FORM_URLENCODED )
256
+ .bodyValue (this .token .getParameterName () + "=" +this .token .getToken ())
257
+ .exchange ()
258
+ .expectStatus ().is2xxSuccessful ();
259
+ }
260
+
261
+ @ Test
262
+ public void filterWhenMultipartMixedAndEnabledThenNotRead () {
263
+ this .csrfFilter .setCsrfTokenRepository (this .repository );
264
+ this .csrfFilter .setTokenFromMultipartDataEnabled (true );
265
+ when (this .repository .loadToken (any ()))
266
+ .thenReturn (Mono .just (this .token ));
267
+
268
+ WebTestClient client = WebTestClient .bindToController (new OkController ())
269
+ .webFilter (this .csrfFilter )
270
+ .build ();
271
+
272
+ client .post ()
273
+ .uri ("/" )
274
+ .contentType (MediaType .MULTIPART_MIXED )
275
+ .bodyValue (this .token .getParameterName () + "=" +this .token .getToken ())
276
+ .exchange ()
277
+ .expectStatus ().isForbidden ();
278
+ }
279
+
280
+ @ RestController
281
+ static class OkController {
282
+ @ RequestMapping ("/**" )
283
+ String ok () {
284
+ return "ok" ;
285
+ }
286
+ }
196
287
}
0 commit comments