20
20
import org .junit .runner .RunWith ;
21
21
import org .mockito .Mock ;
22
22
import org .mockito .junit .MockitoJUnitRunner ;
23
+ import reactor .core .publisher .Mono ;
24
+ import reactor .test .StepVerifier ;
25
+ import reactor .test .publisher .PublisherProbe ;
26
+
23
27
import org .springframework .http .HttpStatus ;
24
28
import org .springframework .http .MediaType ;
25
29
import org .springframework .mock .http .server .reactive .MockServerHttpRequest ;
26
30
import org .springframework .mock .web .server .MockServerWebExchange ;
31
+ import org .springframework .security .web .server .util .matcher .ServerWebExchangeMatcher ;
27
32
import org .springframework .web .server .WebFilterChain ;
28
33
import org .springframework .web .server .WebSession ;
29
- import reactor .core .publisher .Mono ;
30
- import reactor .test .StepVerifier ;
31
- import reactor .test .publisher .PublisherProbe ;
32
34
33
35
import static org .assertj .core .api .AssertionsForInterfaceTypes .assertThat ;
34
36
import static org .mockito .ArgumentMatchers .any ;
37
+ import static org .mockito .Mockito .mock ;
38
+ import static org .mockito .Mockito .verifyZeroInteractions ;
35
39
import static org .mockito .Mockito .when ;
40
+ import static org .springframework .mock .web .server .MockServerWebExchange .from ;
36
41
37
42
/**
38
43
* @author Rob Winch
@@ -49,10 +54,10 @@ public class CsrfWebFilterTests {
49
54
50
55
private CsrfWebFilter csrfFilter = new CsrfWebFilter ();
51
56
52
- private MockServerWebExchange get = MockServerWebExchange . from (
57
+ private MockServerWebExchange get = from (
53
58
MockServerHttpRequest .get ("/" ));
54
59
55
- private MockServerWebExchange post = MockServerWebExchange . from (
60
+ private MockServerWebExchange post = from (
56
61
MockServerHttpRequest .post ("/" ));
57
62
58
63
@ Test
@@ -104,7 +109,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamInvalidTokenThen
104
109
this .csrfFilter .setCsrfTokenRepository (this .repository );
105
110
when (this .repository .loadToken (any ()))
106
111
.thenReturn (Mono .just (this .token ));
107
- this .post = MockServerWebExchange . from (MockServerHttpRequest .post ("/" )
112
+ this .post = from (MockServerHttpRequest .post ("/" )
108
113
.body (this .token .getParameterName () + "=" +this .token .getToken ()+"INVALID" ));
109
114
110
115
Mono <Void > result = this .csrfFilter .filter (this .post , this .chain );
@@ -125,7 +130,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamValidTokenThenCo
125
130
.thenReturn (Mono .just (this .token ));
126
131
when (this .repository .generateToken (any ()))
127
132
.thenReturn (Mono .just (this .token ));
128
- this .post = MockServerWebExchange . from (MockServerHttpRequest .post ("/" )
133
+ this .post = from (MockServerHttpRequest .post ("/" )
129
134
.contentType (MediaType .APPLICATION_FORM_URLENCODED )
130
135
.body (this .token .getParameterName () + "=" +this .token .getToken ()));
131
136
@@ -142,7 +147,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndHeaderInvalidTokenThenCsrfEx
142
147
this .csrfFilter .setCsrfTokenRepository (this .repository );
143
148
when (this .repository .loadToken (any ()))
144
149
.thenReturn (Mono .just (this .token ));
145
- this .post = MockServerWebExchange . from (MockServerHttpRequest .post ("/" )
150
+ this .post = from (MockServerHttpRequest .post ("/" )
146
151
.header (this .token .getHeaderName (), this .token .getToken ()+"INVALID" ));
147
152
148
153
Mono <Void > result = this .csrfFilter .filter (this .post , this .chain );
@@ -163,7 +168,7 @@ public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinue
163
168
.thenReturn (Mono .just (this .token ));
164
169
when (this .repository .generateToken (any ()))
165
170
.thenReturn (Mono .just (this .token ));
166
- this .post = MockServerWebExchange . from (MockServerHttpRequest .post ("/" )
171
+ this .post = from (MockServerHttpRequest .post ("/" )
167
172
.header (this .token .getHeaderName (), this .token .getToken ()));
168
173
169
174
Mono <Void > result = this .csrfFilter .filter (this .post , this .chain );
@@ -173,4 +178,19 @@ public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinue
173
178
174
179
chainResult .assertWasSubscribed ();
175
180
}
181
+
182
+ @ Test
183
+ public void doFilterWhenSkipExchangeInvokedThenSkips () {
184
+ PublisherProbe <Void > chainResult = PublisherProbe .empty ();
185
+ when (this .chain .filter (any ())).thenReturn (chainResult .mono ());
186
+
187
+ ServerWebExchangeMatcher matcher = mock (ServerWebExchangeMatcher .class );
188
+ this .csrfFilter .setRequireCsrfProtectionMatcher (matcher );
189
+
190
+ MockServerWebExchange exchange = from (MockServerHttpRequest .post ("/post" ).build ());
191
+ CsrfWebFilter .skipExchange (exchange );
192
+ this .csrfFilter .filter (exchange , this .chain ).block ();
193
+
194
+ verifyZeroInteractions (matcher );
195
+ }
176
196
}
0 commit comments