Skip to content

Add XorCsrfChannelInterceptor #12562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -56,6 +56,8 @@ final class WebSocketMessageBrokerSecurityConfiguration

private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";

private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";

private MessageMatcherDelegatingAuthorizationManager b;

private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
Expand All @@ -66,7 +68,7 @@ final class WebSocketMessageBrokerSecurityConfiguration

private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();

private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();

private AuthorizationChannelInterceptor authorizationChannelInterceptor = new AuthorizationChannelInterceptor(
ANY_MESSAGE_AUTHENTICATED);
Expand All @@ -86,6 +88,12 @@ public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentRes

@Override
public void configureClientInboundChannel(ChannelRegistration registration) {
ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME,
ChannelInterceptor.class);
if (csrfChannelInterceptor != null) {
this.csrfChannelInterceptor = csrfChannelInterceptor;
}

this.authorizationChannelInterceptor
.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context));
this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -61,6 +61,7 @@
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.util.ReflectionTestUtils;
Expand All @@ -79,6 +80,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;

public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {

Expand Down Expand Up @@ -284,7 +286,7 @@ public void inboundChannelSecurityDefinedByBean() {

private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
}
Expand All @@ -306,7 +308,7 @@ private MockHttpServletRequest sockjsHttpRequest(String mapping) {
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
request.setAttribute(CsrfToken.class.getName(), this.token);
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
return request;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.config.annotation.web.socket;

import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;

/**
* @author Steve Riesenberg
*/
final class TestDeferredCsrfToken implements DeferredCsrfToken {

private final CsrfToken csrfToken;

TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}

@Override
public CsrfToken get() {
return this.csrfToken;
}

@Override
public boolean isGenerated() {
return false;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -70,6 +70,7 @@
import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.MissingCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.util.ReflectionTestUtils;
Expand All @@ -92,6 +93,7 @@
import static org.assertj.core.api.Assertions.fail;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.verify;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;

public class WebSocketMessageBrokerSecurityConfigurationTests {

Expand Down Expand Up @@ -367,7 +369,7 @@ public void sendMessageWhenAnonymousConfiguredAndLoggedInUserThenAccessDeniedExc

private void assertHandshake(HttpServletRequest request) {
TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
assertThat(handshakeHandler.attributes.get(this.sessionAttr))
.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
}
Expand All @@ -389,7 +391,7 @@ private MockHttpServletRequest sockjsHttpRequest(String mapping) {
request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
request.getSession().setAttribute(this.sessionAttr, "sessionValue");
request.setAttribute(CsrfToken.class.getName(), this.token);
request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
return request;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -61,6 +61,7 @@
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.security.web.csrf.CsrfToken;
import org.springframework.security.web.csrf.DefaultCsrfToken;
import org.springframework.security.web.csrf.DeferredCsrfToken;
import org.springframework.security.web.csrf.InvalidCsrfTokenException;
import org.springframework.stereotype.Controller;
import org.springframework.test.context.junit.jupiter.SpringExtension;
Expand All @@ -77,6 +78,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.verify;
import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;

/**
Expand Down Expand Up @@ -381,12 +383,14 @@ public void requestWhenConnectMessageThenUsesCsrfTokenHandshakeInterceptor() thr
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
String csrfAttributeName = CsrfToken.class.getName();
String customAttributeName = this.getClass().getName();
MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token)
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
MvcResult result = mvc.perform(
get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
.sessionAttr(customAttributeName, "attributeValue"))
.andReturn();
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThat(handshakeValue).isEqualTo(sessionValue)
.withFailMessage("Explicitly listed session variables are not overridden");
}
Expand All @@ -398,12 +402,13 @@ public void requestWhenConnectMessageAndUsingSockJsThenUsesCsrfTokenHandshakeInt
MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
String csrfAttributeName = CsrfToken.class.getName();
String customAttributeName = this.getClass().getName();
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token)
MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket")
.requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
.sessionAttr(customAttributeName, "attributeValue")).andReturn();
CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
assertThat(handshakeValue).isEqualTo(sessionValue)
.withFailMessage("Explicitly listed session variables are not overridden");
}
Expand Down Expand Up @@ -526,6 +531,26 @@ private SecurityContextHolderStrategy getSecurityContextHolderStrategy() {
return SecurityContextHolder.getContextHolderStrategy();
}

private static final class TestDeferredCsrfToken implements DeferredCsrfToken {

private final CsrfToken csrfToken;

TestDeferredCsrfToken(CsrfToken csrfToken) {
this.csrfToken = csrfToken;
}

@Override
public CsrfToken get() {
return this.csrfToken;
}

@Override
public boolean isGenerated() {
return false;
}

}

@Controller
static class MessageController {

Expand Down
62 changes: 62 additions & 0 deletions docs/modules/ROOT/pages/migration/servlet/exploits.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,65 @@ open fun springSecurity(http: HttpSecurity): SecurityFilterChain {
==== I need to opt out of CSRF BREACH protection for another reason

If CSRF BREACH protection does not work for you for another reason, you can opt out using the configuration from the <<servlet-opt-in-defer-loading-csrf-token>> section.

== CSRF BREACH with WebSocket support

If the steps for <<Protect against CSRF BREACH>> work for normal HTTP requests and you are using xref:servlet/integrations/websocket.adoc[WebSocket Security] support, then you can also opt into Spring Security 6's default support for BREACH protection of the `CsrfToken` with xref:servlet/integrations/websocket.adoc#websocket-sameorigin-csrf[Stomp headers].

.WebSocket Security BREACH Protection
====
.Java
[source,java,role="primary"]
----
@Bean
ChannelInterceptor csrfChannelInterceptor() {
return new XorCsrfChannelInterceptor();
}
----

.Kotlin
[source,kotlin,role="secondary"]
----
@Bean
open fun csrfChannelInterceptor(): ChannelInterceptor {
return XorCsrfChannelInterceptor()
}
----

.XML
[source,xml,role="secondary"]
----
<b:bean id="csrfChannelInterceptor"
class="org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor"/>
----
====

If configuring CSRF BREACH protection for WebSocket Security gives you trouble, you can configure the 5.8 default using the following configuration:

.Configure WebSocket Security with 5.8 default
====
.Java
[source,java,role="primary"]
----
@Bean
ChannelInterceptor csrfChannelInterceptor() {
return new CsrfChannelInterceptor();
}
----

.Kotlin
[source,kotlin,role="secondary"]
----
@Bean
open fun csrfChannelInterceptor(): ChannelInterceptor {
return CsrfChannelInterceptor()
}
----

.XML
[source,xml,role="secondary"]
----
<b:bean id="csrfChannelInterceptor"
class="org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor"/>
----
====
Loading