diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java index 8dfc4da381d..1a23f8d5eca 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,6 +56,8 @@ final class WebSocketMessageBrokerSecurityConfiguration private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping"; + private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor"; + private MessageMatcherDelegatingAuthorizationManager b; private static final AuthorizationManager> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager @@ -66,7 +68,7 @@ final class WebSocketMessageBrokerSecurityConfiguration private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor(); - private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor(); + private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor(); private AuthorizationChannelInterceptor authorizationChannelInterceptor = new AuthorizationChannelInterceptor( ANY_MESSAGE_AUTHENTICATED); @@ -86,6 +88,12 @@ public void addArgumentResolvers(List argumentRes @Override public void configureClientInboundChannel(ChannelRegistration registration) { + ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME, + ChannelInterceptor.class); + if (csrfChannelInterceptor != null) { + this.csrfChannelInterceptor = csrfChannelInterceptor; + } + this.authorizationChannelInterceptor .setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context)); this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java index 8d0ad848357..2286ff6f602 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.util.ReflectionTestUtils; @@ -79,6 +80,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @@ -284,7 +286,7 @@ public void inboundChannelSecurityDefinedByBean() { private void assertHandshake(HttpServletRequest request) { TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); - assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token); + assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(handshakeHandler.attributes.get(this.sessionAttr)) .isEqualTo(request.getSession().getAttribute(this.sessionAttr)); } @@ -306,7 +308,7 @@ private MockHttpServletRequest sockjsHttpRequest(String mapping) { request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket"); request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); request.getSession().setAttribute(this.sessionAttr, "sessionValue"); - request.setAttribute(CsrfToken.class.getName(), this.token); + request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)); return request; } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java new file mode 100644 index 00000000000..11aa0de8365 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.annotation.web.socket; + +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; + +/** + * @author Steve Riesenberg + */ +final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java index 8823fe39cbc..6977207218e 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -70,6 +70,7 @@ import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.MissingCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.util.ReflectionTestUtils; @@ -92,6 +93,7 @@ import static org.assertj.core.api.Assertions.fail; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; public class WebSocketMessageBrokerSecurityConfigurationTests { @@ -367,7 +369,7 @@ public void sendMessageWhenAnonymousConfiguredAndLoggedInUserThenAccessDeniedExc private void assertHandshake(HttpServletRequest request) { TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class); - assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token); + assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token); assertThat(handshakeHandler.attributes.get(this.sessionAttr)) .isEqualTo(request.getSession().getAttribute(this.sessionAttr)); } @@ -389,7 +391,7 @@ private MockHttpServletRequest sockjsHttpRequest(String mapping) { request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket"); request.setRequestURI(mapping + "/289/tpyx6mde/websocket"); request.getSession().setAttribute(this.sessionAttr, "sessionValue"); - request.setAttribute(CsrfToken.class.getName(), this.token); + request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)); return request; } diff --git a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java index d39e2c93710..f9243db209b 100644 --- a/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java +++ b/config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -61,6 +61,7 @@ import org.springframework.security.test.context.support.WithMockUser; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.security.web.csrf.InvalidCsrfTokenException; import org.springframework.stereotype.Controller; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -77,6 +78,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.verify; +import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; /** @@ -381,12 +383,14 @@ public void requestWhenConnectMessageThenUsesCsrfTokenHandshakeInterceptor() thr MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build(); String csrfAttributeName = CsrfToken.class.getName(); String customAttributeName = this.getClass().getName(); - MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token) - .sessionAttr(customAttributeName, "attributeValue")).andReturn(); + MvcResult result = mvc.perform( + get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)) + .sessionAttr(customAttributeName, "attributeValue")) + .andReturn(); CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName); String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName); String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName); - assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); + assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); assertThat(handshakeValue).isEqualTo(sessionValue) .withFailMessage("Explicitly listed session variables are not overridden"); } @@ -398,12 +402,13 @@ public void requestWhenConnectMessageAndUsingSockJsThenUsesCsrfTokenHandshakeInt MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build(); String csrfAttributeName = CsrfToken.class.getName(); String customAttributeName = this.getClass().getName(); - MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token) + MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket") + .requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token)) .sessionAttr(customAttributeName, "attributeValue")).andReturn(); CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName); String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName); String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName); - assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); + assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated"); assertThat(handshakeValue).isEqualTo(sessionValue) .withFailMessage("Explicitly listed session variables are not overridden"); } @@ -526,6 +531,26 @@ private SecurityContextHolderStrategy getSecurityContextHolderStrategy() { return SecurityContextHolder.getContextHolderStrategy(); } + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + + } + @Controller static class MessageController { diff --git a/docs/modules/ROOT/pages/migration/servlet/exploits.adoc b/docs/modules/ROOT/pages/migration/servlet/exploits.adoc index 0e41a8a263b..379c1f83ac7 100644 --- a/docs/modules/ROOT/pages/migration/servlet/exploits.adoc +++ b/docs/modules/ROOT/pages/migration/servlet/exploits.adoc @@ -243,3 +243,65 @@ open fun springSecurity(http: HttpSecurity): SecurityFilterChain { ==== I need to opt out of CSRF BREACH protection for another reason If CSRF BREACH protection does not work for you for another reason, you can opt out using the configuration from the <> section. + +== CSRF BREACH with WebSocket support + +If the steps for <> work for normal HTTP requests and you are using xref:servlet/integrations/websocket.adoc[WebSocket Security] support, then you can also opt into Spring Security 6's default support for BREACH protection of the `CsrfToken` with xref:servlet/integrations/websocket.adoc#websocket-sameorigin-csrf[Stomp headers]. + +.WebSocket Security BREACH Protection +==== +.Java +[source,java,role="primary"] +---- +@Bean +ChannelInterceptor csrfChannelInterceptor() { + return new XorCsrfChannelInterceptor(); +} +---- + +.Kotlin +[source,kotlin,role="secondary"] +---- +@Bean +open fun csrfChannelInterceptor(): ChannelInterceptor { + return XorCsrfChannelInterceptor() +} +---- + +.XML +[source,xml,role="secondary"] +---- + +---- +==== + +If configuring CSRF BREACH protection for WebSocket Security gives you trouble, you can configure the 5.8 default using the following configuration: + +.Configure WebSocket Security with 5.8 default +==== +.Java +[source,java,role="primary"] +---- +@Bean +ChannelInterceptor csrfChannelInterceptor() { + return new CsrfChannelInterceptor(); +} +---- + +.Kotlin +[source,kotlin,role="secondary"] +---- +@Bean +open fun csrfChannelInterceptor(): ChannelInterceptor { + return CsrfChannelInterceptor() +} +---- + +.XML +[source,xml,role="secondary"] +---- + +---- +==== diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java new file mode 100644 index 00000000000..2d7b3d1c8c6 --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.security.MessageDigest; +import java.util.Map; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.ChannelInterceptor; +import org.springframework.security.crypto.codec.Utf8; +import org.springframework.security.messaging.util.matcher.MessageMatcher; +import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.InvalidCsrfTokenException; +import org.springframework.security.web.csrf.MissingCsrfTokenException; + +/** + * {@link ChannelInterceptor} that validates a CSRF token masked by the + * {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in + * the header of any {@link SimpMessageType#CONNECT} message. + * + * @author Steve Riesenberg + * @since 5.8 + */ +public final class XorCsrfChannelInterceptor implements ChannelInterceptor { + + private final MessageMatcher matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT); + + @Override + public Message preSend(Message message, MessageChannel channel) { + if (!this.matcher.matches(message)) { + return message; + } + Map sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders()); + CsrfToken expectedToken = (sessionAttributes != null) + ? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null; + if (expectedToken == null) { + throw new MissingCsrfTokenException(null); + } + String actualToken = SimpMessageHeaderAccessor.wrap(message) + .getFirstNativeHeader(expectedToken.getHeaderName()); + String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken()); + boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue); + if (!csrfCheckPassed) { + throw new InvalidCsrfTokenException(expectedToken, actualToken); + } + return message; + } + + /** + * Constant time comparison to prevent against timing attacks. + * @param expected + * @param actual + * @return + */ + private static boolean equalsConstantTime(String expected, String actual) { + if (expected == actual) { + return true; + } + if (expected == null || actual == null) { + return false; + } + // Encode after ensure that the string is not null + byte[] expectedBytes = Utf8.encode(expected); + byte[] actualBytes = Utf8.encode(actual); + return MessageDigest.isEqual(expectedBytes, actualBytes); + } + +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java new file mode 100644 index 00000000000..46a67cc4d39 --- /dev/null +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java @@ -0,0 +1,72 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.util.Base64; + +import org.springframework.security.crypto.codec.Utf8; + +/** + * Copied from + * {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler}. + * + * @see gh-12378 + */ +final class XorCsrfTokenUtils { + + private XorCsrfTokenUtils() { + } + + static String getTokenValue(String actualToken, String token) { + byte[] actualBytes; + try { + actualBytes = Base64.getUrlDecoder().decode(actualToken); + } + catch (Exception ex) { + return null; + } + + byte[] tokenBytes = Utf8.encode(token); + int tokenSize = tokenBytes.length; + if (actualBytes.length < tokenSize) { + return null; + } + + // extract token and random bytes + int randomBytesSize = actualBytes.length - tokenSize; + byte[] xoredCsrf = new byte[tokenSize]; + byte[] randomBytes = new byte[randomBytesSize]; + + System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize); + System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize); + + byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf); + return Utf8.decode(csrfBytes); + } + + private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) { + int len = Math.min(randomBytes.length, csrfBytes.length); + byte[] xoredCsrf = new byte[len]; + System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length); + for (int i = 0; i < len; i++) { + xoredCsrf[i] ^= randomBytes[i]; + } + return xoredCsrf; + } + +} diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java index aa40975f2f6..1c917d82ee2 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2015 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,15 +24,18 @@ import org.springframework.http.server.ServerHttpResponse; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.server.HandshakeInterceptor; /** - * Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket - * attributes. This is used as the expected CsrfToken when validating connection requests - * to ensure only the same origin connects. + * Loads a CsrfToken from the HttpServletRequest and HttpServletResponse to populate the + * WebSocket attributes. This is used as the expected CsrfToken when validating connection + * requests to ensure only the same origin connects. * * @author Rob Winch + * @author Steve Riesenberg * @since 4.0 */ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor { @@ -41,11 +44,19 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map attributes) { HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest(); - CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName()); - if (token == null) { + DeferredCsrfToken deferredCsrfToken = (DeferredCsrfToken) httpRequest + .getAttribute(DeferredCsrfToken.class.getName()); + if (deferredCsrfToken == null) { return true; } - attributes.put(CsrfToken.class.getName(), token); + CsrfToken csrfToken = deferredCsrfToken.get(); + // Ensure the values of the CsrfToken are copied into a new token so the old token + // is available for garbage collection. + // This is required because the original token could hold a reference to the + // HttpServletRequest/Response of the handshake request. + CsrfToken resolvedCsrfToken = new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(), + csrfToken.getToken()); + attributes.put(CsrfToken.class.getName(), resolvedCsrfToken); return true; } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java new file mode 100644 index 00000000000..884c3d2fc20 --- /dev/null +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2002-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.messaging.web.csrf; + +import java.util.HashMap; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessageType; +import org.springframework.messaging.support.MessageBuilder; +import org.springframework.security.web.csrf.CsrfToken; +import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.InvalidCsrfTokenException; +import org.springframework.security.web.csrf.MissingCsrfTokenException; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link XorCsrfChannelInterceptor}. + * + * @author Steve Riesenberg + */ +public class XorCsrfChannelInterceptorTests { + + private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA=="; + + private static final String INVALID_XOR_CSRF_TOKEN_VALUE = "KneoaygbRZtfHQ=="; + + private CsrfToken token; + + private SimpMessageHeaderAccessor messageHeaders; + + private MessageChannel channel; + + private XorCsrfChannelInterceptor interceptor; + + @BeforeEach + public void setup() { + this.token = new DefaultCsrfToken("header", "param", "token"); + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); + this.messageHeaders.setSessionAttributes(new HashMap<>()); + this.channel = mock(MessageChannel.class); + this.interceptor = new XorCsrfChannelInterceptor(); + } + + @Test + public void preSendWhenConnectWithValidTokenThenSuccess() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenConnectWithInvalidTokenThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.setNativeHeader(this.token.getHeaderName(), INVALID_XOR_CSRF_TOKEN_VALUE); + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithNoTokenThenThrowsInvalidCsrfTokenException() { + this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token); + // @formatter:off + assertThatExceptionOfType(InvalidCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithMissingTokenThenThrowsMissingCsrfTokenException() { + // @formatter:off + assertThatExceptionOfType(MissingCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenConnectWithNullSessionAttributesThenThrowsMissingCsrfTokenException() { + this.messageHeaders.setSessionAttributes(null); + // @formatter:off + assertThatExceptionOfType(MissingCsrfTokenException.class) + .isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class))); + // @formatter:on + } + + @Test + public void preSendWhenAckThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenDisconnectThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenHeartbeatThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenMessageThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenOtherThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER); + this.interceptor.preSend(message(), this.channel); + } + + @Test + public void preSendWhenUnsubscribeThenIgnores() { + this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); + this.interceptor.preSend(message(), this.channel); + } + + private Message message() { + return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build(); + } + +} diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java index a3c38d14d91..760fe3aa1b7 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.web.csrf.CsrfToken; import org.springframework.security.web.csrf.DefaultCsrfToken; +import org.springframework.security.web.csrf.DeferredCsrfToken; import org.springframework.web.socket.WebSocketHandler; import static org.assertj.core.api.Assertions.assertThat; @@ -72,10 +73,38 @@ public void beforeHandshakeNoAttribute() throws Exception { @Test public void beforeHandshake() throws Exception { CsrfToken token = new DefaultCsrfToken("header", "param", "token"); - this.httpRequest.setAttribute(CsrfToken.class.getName(), token); + this.httpRequest.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(token)); this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes); assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName()); - assertThat(this.attributes.values()).containsOnly(token); + CsrfToken csrfToken = (CsrfToken) this.attributes.get(CsrfToken.class.getName()); + assertThat(csrfToken.getHeaderName()).isEqualTo(token.getHeaderName()); + assertThat(csrfToken.getParameterName()).isEqualTo(token.getParameterName()); + assertThat(csrfToken.getToken()).isEqualTo(token.getToken()); + // Ensure the values of the CsrfToken are copied into a new token so the old token + // is available for garbage collection. + // This is required because the original token could hold a reference to the + // HttpServletRequest/Response of the handshake request. + assertThat(csrfToken).isNotSameAs(token); + } + + private static final class TestDeferredCsrfToken implements DeferredCsrfToken { + + private final CsrfToken csrfToken; + + private TestDeferredCsrfToken(CsrfToken csrfToken) { + this.csrfToken = csrfToken; + } + + @Override + public CsrfToken get() { + return this.csrfToken; + } + + @Override + public boolean isGenerated() { + return false; + } + } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 5f3b94b6c91..3f966832a44 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -108,6 +108,7 @@ protected boolean shouldNotFilter(HttpServletRequest request) throws ServletExce protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response); + request.setAttribute(DeferredCsrfToken.class.getName(), deferredCsrfToken); this.requestHandler.handle(request, response, deferredCsrfToken::get); if (!this.requireCsrfProtectionMatcher.matches(request)) { if (this.logger.isTraceEnabled()) { diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 4ad810329a3..68875e05a87 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -126,11 +126,12 @@ public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException, @Test public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -138,12 +139,13 @@ public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOExce @Test public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -151,12 +153,13 @@ public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, @Test public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -165,13 +168,14 @@ public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletExce public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID"); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -179,11 +183,12 @@ public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParamete @Test public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -191,11 +196,12 @@ public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOExc @Test public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -203,12 +209,13 @@ public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOExc @Test public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -217,13 +224,14 @@ public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID"); this.request.addHeader(this.token.getHeaderName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); } @@ -231,12 +239,13 @@ public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() @Test public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); verifyNoMoreInteractions(this.deniedHandler); verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), @@ -246,12 +255,13 @@ public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOExce @Test public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); // LazyCsrfTokenRepository requires the response as an attribute assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response); verify(this.filterChain).doFilter(this.request, this.response); @@ -316,11 +326,12 @@ public void doFilterDefaultAccessDenied() throws ServletException, IOException { this.filter = new CsrfFilter(this.tokenRepository); this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher); given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); this.filter.doFilter(this.request, this.response, this.filterChain); assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token); assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN); verifyNoMoreInteractions(this.filterChain); } @@ -344,22 +355,24 @@ public void doFilterWhenTokenIsNullThenNoNullPointer() throws Exception { given(token.getToken()).willReturn(null); given(token.getHeaderName()).willReturn(this.token.getHeaderName()); given(token.getParameterName()).willReturn(this.token.getParameterName()); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); given(this.requestMatcher.matches(this.request)).willReturn(true); filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); } @Test public void doFilterWhenRequestHandlerThenUsed() throws Exception { - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class); this.filter = createCsrfFilter(this.tokenRepository); this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.tokenRepository).loadDeferredToken(this.request, this.response); verify(requestHandler).handle(eq(this.request), eq(this.response), any()); verify(this.filterChain).doFilter(this.request, this.response); @@ -368,14 +381,15 @@ public void doFilterWhenRequestHandlerThenUsed() throws Exception { @Test public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception { given(this.requestMatcher.matches(this.request)).willReturn(false); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); requestHandler.setCsrfRequestAttributeName(this.token.getParameterName()); this.filter.setRequestHandler(requestHandler); this.filter.doFilter(this.request, this.response, this.filterChain); assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull(); assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull(); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.filterChain).doFilter(this.request, this.response); assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); @@ -394,12 +408,13 @@ public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSucc @Test public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception { given(this.requestMatcher.matches(this.request)).willReturn(true); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(this.token, false)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler(); this.filter.setRequestHandler(requestHandler); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class)); verifyNoMoreInteractions(this.filterChain); } @@ -424,10 +439,11 @@ public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGe requestHandler.setCsrfRequestAttributeName(csrfAttrName); filter.setRequestHandler(requestHandler); CsrfToken expectedCsrfToken = mock(CsrfToken.class); - given(this.tokenRepository.loadDeferredToken(this.request, this.response)) - .willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true)); + DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(expectedCsrfToken, true); + given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken); filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken); verifyNoInteractions(expectedCsrfToken); CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);