diff --git a/framework-docs/modules/ROOT/pages/web/websocket/stomp/client.adoc b/framework-docs/modules/ROOT/pages/web/websocket/stomp/client.adoc index 1eae09021206..5b908999cff2 100644 --- a/framework-docs/modules/ROOT/pages/web/websocket/stomp/client.adoc +++ b/framework-docs/modules/ROOT/pages/web/websocket/stomp/client.adoc @@ -105,5 +105,20 @@ it handle ERROR frames in addition to the `handleException` callback for exceptions from the handling of messages and `handleTransportError` for transport-level errors including `ConnectionLostException`. +You can also use `setInboundMessageSizeLimit(limit)` and `setOutboundMessageSizeLimit(limit)` +to limit the maximum size of inbound and outbound message size. +When outbound message size exceeds `outboundMessageSizeLimit`, message is split into multiple incomplete frames. +Then receiver buffers these incomplete frames and reassemble to complete message. +When inbound message size exceeds `inboundMessageSizeLimit`, throw `StompConversionException`. +The default value of in&outboundMessageSizeLimit is `64KB`. + +[source,java,indent=0,subs="verbatim,quotes"] +---- + WebSocketClient webSocketClient = new StandardWebSocketClient(); + WebSocketStompClient stompClient = new WebSocketStompClient(webSocketClient); + stompClient.setInboundMessageSizeLimit(64 * 1024); // 64KB + stompClient.setOutboundMessageSizeLimit(64 * 1024); // 64KB +---- + diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/SplittingStompEncoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/SplittingStompEncoder.java new file mode 100644 index 000000000000..eec6e54dfe03 --- /dev/null +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/SplittingStompEncoder.java @@ -0,0 +1,68 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.springframework.util.Assert; + +/** + * An extension of {@link org.springframework.messaging.simp.stomp.StompEncoder} + * that splits the STOMP message to multiple incomplete STOMP frames + * when the encoded bytes length exceeds {@link SplittingStompEncoder#bufferSizeLimit}. + * + * @author Injae Kim + * @since 6.2 + * @see StompEncoder + */ +public class SplittingStompEncoder { + + private final StompEncoder encoder; + + private final int bufferSizeLimit; + + public SplittingStompEncoder(StompEncoder encoder, int bufferSizeLimit) { + Assert.notNull(encoder, "StompEncoder is required"); + Assert.isTrue(bufferSizeLimit > 0, "Buffer size limit must be greater than 0"); + this.encoder = encoder; + this.bufferSizeLimit = bufferSizeLimit; + } + + /** + * Encodes the given payload and headers into a list of one or more {@code byte[]}s. + * @param headers the headers + * @param payload the payload + * @return the list of one or more encoded messages + */ + public List encode(Map headers, byte[] payload) { + byte[] result = this.encoder.encode(headers, payload); + int length = result.length; + + if (length <= this.bufferSizeLimit) { + return List.of(result); + } + + List frames = new ArrayList<>(); + for (int i = 0; i < length; i += this.bufferSizeLimit) { + frames.add(Arrays.copyOfRange(result, i, Math.min(i + this.bufferSizeLimit, length))); + } + return frames; + } +} diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java index 234d9917e06f..18f917ca32df 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/stomp/StompDecoder.java @@ -78,7 +78,7 @@ public MessageHeaderInitializer getHeaderInitializer() { * Decodes one or more STOMP frames from the given {@code ByteBuffer} into a * list of {@link Message Messages}. If the input buffer contains partial STOMP frame * content, or additional content with a partial STOMP frame, the buffer is - * reset and {@code null} is returned. + * reset and an empty list is returned. * @param byteBuffer the buffer to decode the STOMP frame from * @return the decoded messages, or an empty list if none * @throws StompConversionException raised in case of decoding issues diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/SplittingStompEncoderTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/SplittingStompEncoderTests.java new file mode 100644 index 000000000000..8b37d0ea297a --- /dev/null +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/stomp/SplittingStompEncoderTests.java @@ -0,0 +1,382 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.simp.stomp; + +import java.io.ByteArrayOutputStream; +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link SplittingStompEncoder}. + * + * @author Injae Kim + * @since 6.2 + */ +public class SplittingStompEncoderTests { + + private final StompEncoder STOMP_ENCODER = new StompEncoder(); + + private static final int DEFAULT_MESSAGE_MAX_SIZE = 64 * 1024; + + @Test + public void encodeFrameWithNoHeadersAndNoBody() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("DISCONNECT\n\n\0"); + assertThat(actual.size()).isOne(); + } + + @Test + public void encodeFrameWithNoHeadersAndNoBodySplitTwoFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 7); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("DISCONNECT\n\n\0"); + assertThat(actual.size()).isEqualTo(2); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 7)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 7, outputStream.size())); + } + + @Test + public void encodeFrameWithNoHeadersAndNoBodySplitMultipleFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 3); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("DISCONNECT\n\n\0"); + assertThat(actual.size()).isEqualTo(5); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 3)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 3, 6)); + assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 6, 9)); + assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 9, 12)); + assertThat(actual.get(4)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 12, outputStream.size())); + } + + @Test + public void encodeFrameWithHeaders() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setAcceptVersion("1.2"); + headers.setHost("github.org"); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + String actualString = outputStream.toString(); + + assertThat("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0".equals(actualString) || + "CONNECT\nhost:github.org\naccept-version:1.2\n\n\0".equals(actualString)).isTrue(); + assertThat(actual.size()).isOne(); + } + + @Test + public void encodeFrameWithHeadersSplitTwoFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 30); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setAcceptVersion("1.2"); + headers.setHost("github.org"); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + String actualString = outputStream.toString(); + + assertThat("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0".equals(actualString) || + "CONNECT\nhost:github.org\naccept-version:1.2\n\n\0".equals(actualString)).isTrue(); + assertThat(actual.size()).isEqualTo(2); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 30)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, outputStream.size())); + } + + @Test + public void encodeFrameWithHeadersSplitMultipleFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 10); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT); + headers.setAcceptVersion("1.2"); + headers.setHost("github.org"); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + String actualString = outputStream.toString(); + + assertThat("CONNECT\naccept-version:1.2\nhost:github.org\n\n\0".equals(actualString) || + "CONNECT\nhost:github.org\naccept-version:1.2\n\n\0".equals(actualString)).isTrue(); + assertThat(actual.size()).isEqualTo(5); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 10)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 10, 20)); + assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, 30)); + assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, 40)); + assertThat(actual.get(4)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 40, outputStream.size())); + } + + @Test + public void encodeFrameWithHeadersThatShouldBeEscaped() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\"); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0"); + assertThat(actual.size()).isOne(); + } + + @Test + public void encodeFrameWithHeadersThatShouldBeEscapedSplitTwoFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 30); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\"); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0"); + assertThat(actual.size()).isEqualTo(2); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 30)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, outputStream.size())); + } + + + @Test + public void encodeFrameWithHeadersThatShouldBeEscapedSplitMultipleFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 10); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.DISCONNECT); + headers.addNativeHeader("a:\r\n\\b", "alpha:bravo\r\n\\"); + Message frame = MessageBuilder.createMessage(new byte[0], headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + String actualString = outputStream.toString(); + + assertThat(outputStream.toString()).isEqualTo("DISCONNECT\na\\c\\r\\n\\\\b:alpha\\cbravo\\r\\n\\\\\n\n\0"); + assertThat(actual.size()).isEqualTo(5); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 10)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 10, 20)); + assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, 30)); + assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, 40)); + assertThat(actual.get(4)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 40, outputStream.size())); + } + + + @Test + public void encodeFrameWithHeadersBody() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.addNativeHeader("a", "alpha"); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\na:alpha\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isOne(); + } + + @Test + public void encodeFrameWithHeadersBodySplitTwoFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 30); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.addNativeHeader("a", "alpha"); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\na:alpha\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isEqualTo(2); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 30)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, outputStream.size())); + } + + @Test + public void encodeFrameWithHeadersBodySplitMultipleFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 10); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.addNativeHeader("a", "alpha"); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\na:alpha\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isEqualTo(5); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 10)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 10, 20)); + assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, 30)); + assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, 40)); + assertThat(actual.get(4)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 40, outputStream.size())); + } + + @Test + public void encodeFrameWithContentLengthPresent() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, DEFAULT_MESSAGE_MAX_SIZE); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.setContentLength(12); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isOne(); + } + + @Test + public void encodeFrameWithContentLengthPresentSplitTwoFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 20); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.setContentLength(12); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isEqualTo(2); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 20)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, outputStream.size())); + } + + @Test + public void encodeFrameWithContentLengthPresentSplitMultipleFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 10); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.setContentLength(12); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isEqualTo(4); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 10)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 10, 20)); + assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 20, 30)); + assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 30, outputStream.size())); + } + + @Test + public void sameLengthAndBufferSizeLimit() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 44); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.addNativeHeader("a", "1234"); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\na:1234\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isOne(); + assertThat(outputStream.toByteArray().length).isEqualTo(44); + } + + @Test + public void lengthAndBufferSizeLimitExactlySplitTwoFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 22); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.addNativeHeader("a", "1234"); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\na:1234\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isEqualTo(2); + assertThat(outputStream.toByteArray().length).isEqualTo(44); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 22)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 22, 44)); + } + + @Test + public void lengthAndBufferSizeLimitExactlySplitMultipleFrames() { + SplittingStompEncoder encoder = new SplittingStompEncoder(STOMP_ENCODER, 11); + StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND); + headers.addNativeHeader("a", "1234"); + Message frame = MessageBuilder.createMessage( + "Message body".getBytes(), headers.getMessageHeaders()); + + List actual = encoder.encode(frame.getHeaders(), frame.getPayload()); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + actual.forEach(outputStream::writeBytes); + + assertThat(outputStream.toString()).isEqualTo("SEND\na:1234\ncontent-length:12\n\nMessage body\0"); + assertThat(actual.size()).isEqualTo(4); + assertThat(outputStream.toByteArray().length).isEqualTo(44); + assertThat(actual.get(0)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 0, 11)); + assertThat(actual.get(1)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 11, 22)); + assertThat(actual.get(2)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 22, 33)); + assertThat(actual.get(3)).isEqualTo(Arrays.copyOfRange(outputStream.toByteArray(), 33, 44)); + } + + @Test + public void bufferSizeLimitShouldBePositive() { + assertThatThrownBy(() -> new SplittingStompEncoder(STOMP_ENCODER, 0)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new SplittingStompEncoder(STOMP_ENCODER, -1)) + .isInstanceOf(IllegalArgumentException.class); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/WebSocketStompClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/WebSocketStompClient.java index a6ebe75ec20e..a6462d78d746 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/WebSocketStompClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/WebSocketStompClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,6 +36,7 @@ import org.springframework.messaging.Message; import org.springframework.messaging.simp.stomp.BufferingStompDecoder; import org.springframework.messaging.simp.stomp.ConnectionHandlingStompSession; +import org.springframework.messaging.simp.stomp.SplittingStompEncoder; import org.springframework.messaging.simp.stomp.StompClientSupport; import org.springframework.messaging.simp.stomp.StompDecoder; import org.springframework.messaging.simp.stomp.StompEncoder; @@ -68,15 +69,23 @@ * SockJsClient}. * * @author Rossen Stoyanchev + * @author Injae Kim * @since 4.2 */ public class WebSocketStompClient extends StompClientSupport implements SmartLifecycle { private static final Log logger = LogFactory.getLog(WebSocketStompClient.class); + /** + * The default max size for in&outbound STOMP message. + */ + private static final int DEFAULT_MESSAGE_MAX_SIZE = 64 * 1024; + private final WebSocketClient webSocketClient; - private int inboundMessageSizeLimit = 64 * 1024; + private int inboundMessageSizeLimit = DEFAULT_MESSAGE_MAX_SIZE; + + private int outboundMessageSizeLimit = DEFAULT_MESSAGE_MAX_SIZE; private boolean autoStartup = true; @@ -123,7 +132,7 @@ public void setTaskScheduler(@Nullable TaskScheduler taskScheduler) { * Since a STOMP message can be received in multiple WebSocket messages, * buffering may be required and this property determines the maximum buffer * size per message. - *

By default this is set to 64 * 1024 (64K). + *

By default this is set to 64 * 1024 (64K), see {@link WebSocketStompClient#DEFAULT_MESSAGE_MAX_SIZE}. */ public void setInboundMessageSizeLimit(int inboundMessageSizeLimit) { this.inboundMessageSizeLimit = inboundMessageSizeLimit; @@ -136,6 +145,25 @@ public int getInboundMessageSizeLimit() { return this.inboundMessageSizeLimit; } + /** + * Configure the maximum size allowed for outbound STOMP message. + * If STOMP message's size exceeds {@link WebSocketStompClient#outboundMessageSizeLimit}, + * STOMP message is split into multiple frames. + *

By default this is set to 64 * 1024 (64K), see {@link WebSocketStompClient#DEFAULT_MESSAGE_MAX_SIZE}. + * @since 6.2 + */ + public void setOutboundMessageSizeLimit(int outboundMessageSizeLimit) { + this.outboundMessageSizeLimit = outboundMessageSizeLimit; + } + + /** + * Get the configured outbound message buffer size in bytes. + * @since 6.2 + */ + public int getOutboundMessageSizeLimit() { + return this.outboundMessageSizeLimit; + } + /** * Set whether to auto-start the contained WebSocketClient when the Spring * context has been refreshed. @@ -374,7 +402,8 @@ private class WebSocketTcpConnectionHandlerAdapter implements BiConsumer connectionHandler; - private final StompWebSocketMessageCodec codec = new StompWebSocketMessageCodec(getInboundMessageSizeLimit()); + private final StompWebSocketMessageCodec codec = + new StompWebSocketMessageCodec(getInboundMessageSizeLimit(),getOutboundMessageSizeLimit()); @Nullable private volatile WebSocketSession session; @@ -462,7 +491,9 @@ public CompletableFuture sendAsync(Message message) { try { WebSocketSession session = this.session; Assert.state(session != null, "No WebSocketSession available"); - session.sendMessage(this.codec.encode(message, session.getClass())); + for (WebSocketMessage webSocketMessage : this.codec.encode(message, session.getClass())) { + session.sendMessage(webSocketMessage); + } future.complete(null); } catch (Throwable ex) { @@ -547,8 +578,11 @@ private static class StompWebSocketMessageCodec { private final BufferingStompDecoder bufferingDecoder; - public StompWebSocketMessageCodec(int messageSizeLimit) { - this.bufferingDecoder = new BufferingStompDecoder(DECODER, messageSizeLimit); + private final SplittingStompEncoder splittingEncoder; + + public StompWebSocketMessageCodec(int inboundMessageSizeLimit, int outboundMessageSizeLimit) { + this.bufferingDecoder = new BufferingStompDecoder(DECODER, inboundMessageSizeLimit); + this.splittingEncoder = new SplittingStompEncoder(ENCODER, outboundMessageSizeLimit); } public List> decode(WebSocketMessage webSocketMessage) { @@ -574,17 +608,21 @@ else if (webSocketMessage instanceof BinaryMessage binaryMessage) { return result; } - public WebSocketMessage encode(Message message, Class sessionType) { + public List> encode(Message message, Class sessionType) { StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class); Assert.notNull(accessor, "No StompHeaderAccessor available"); byte[] payload = message.getPayload(); - byte[] bytes = ENCODER.encode(accessor.getMessageHeaders(), payload); + List frames = splittingEncoder.encode(accessor.getMessageHeaders(), payload); boolean useBinary = (payload.length > 0 && !(SockJsSession.class.isAssignableFrom(sessionType)) && MimeTypeUtils.APPLICATION_OCTET_STREAM.isCompatibleWith(accessor.getContentType())); - return (useBinary ? new BinaryMessage(bytes) : new TextMessage(bytes)); + List> messages = new ArrayList<>(); + for (byte[] frame : frames) { + messages.add(useBinary ? new BinaryMessage(frame) : new TextMessage(frame)); + } + return messages; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/WebSocketStompClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/WebSocketStompClientTests.java index 019bed803967..e28d29b9e194 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/WebSocketStompClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/WebSocketStompClientTests.java @@ -65,6 +65,7 @@ * Tests for {@link WebSocketStompClient}. * * @author Rossen Stoyanchev + * @author Injae Kim */ @MockitoSettings(strictness = Strictness.LENIENT) class WebSocketStompClientTests { @@ -211,6 +212,29 @@ void sendWebSocketMessage() throws Exception { assertThat(textMessage.getPayload()).isEqualTo("SEND\ndestination:/topic/foo\ncontent-length:7\n\npayload\0"); } + @Test + void sendWebSocketMessageExceedOutboundMessageSizeLimit() throws Exception { + stompClient.setOutboundMessageSizeLimit(30); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); + accessor.setDestination("/topic/foo"); + byte[] payload = "payload".getBytes(StandardCharsets.UTF_8); + + getTcpConnection().sendAsync(MessageBuilder.createMessage(payload, accessor.getMessageHeaders())); + + ArgumentCaptor textMessageCaptor = ArgumentCaptor.forClass(TextMessage.class); + verify(this.webSocketSession, times(2)).sendMessage(textMessageCaptor.capture()); + TextMessage textMessage = textMessageCaptor.getAllValues().get(0); + assertThat(textMessage).isNotNull(); + assertThat(textMessage.getPayload()).isEqualTo("SEND\ndestination:/topic/foo\nco"); + assertThat(textMessage.getPayload().getBytes().length).isEqualTo(30); + + textMessage = textMessageCaptor.getAllValues().get(1); + assertThat(textMessage).isNotNull(); + assertThat(textMessage.getPayload()).isEqualTo("ntent-length:7\n\npayload\0"); + assertThat(textMessage.getPayload().getBytes().length).isEqualTo(24); + } + + @Test void sendWebSocketBinary() throws Exception { StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); @@ -228,6 +252,49 @@ void sendWebSocketBinary() throws Exception { .isEqualTo("SEND\ndestination:/b\ncontent-type:application/octet-stream\ncontent-length:7\n\npayload\0"); } + @Test + void sendWebSocketBinaryExceedOutboundMessageSizeLimit() throws Exception { + stompClient.setOutboundMessageSizeLimit(50); + StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND); + accessor.setDestination("/b"); + accessor.setContentType(MimeTypeUtils.APPLICATION_OCTET_STREAM); + byte[] payload = "payload".getBytes(StandardCharsets.UTF_8); + + getTcpConnection().sendAsync(MessageBuilder.createMessage(payload, accessor.getMessageHeaders())); + + ArgumentCaptor binaryMessageCaptor = ArgumentCaptor.forClass(BinaryMessage.class); + verify(this.webSocketSession, times(2)).sendMessage(binaryMessageCaptor.capture()); + BinaryMessage binaryMessage = binaryMessageCaptor.getAllValues().get(0); + assertThat(binaryMessage).isNotNull(); + assertThat(new String(binaryMessage.getPayload().array(), StandardCharsets.UTF_8)) + .isEqualTo("SEND\ndestination:/b\ncontent-type:application/octet"); + assertThat(binaryMessage.getPayload().array().length).isEqualTo(50); + + binaryMessage = binaryMessageCaptor.getAllValues().get(1); + assertThat(binaryMessage).isNotNull(); + assertThat(new String(binaryMessage.getPayload().array(), StandardCharsets.UTF_8)) + .isEqualTo("-stream\ncontent-length:7\n\npayload\0"); + assertThat(binaryMessage.getPayload().array().length).isEqualTo(34); + } + + @Test + void reassembleReceivedIFragmentedFrames() throws Exception { + WebSocketHandler handler = connect(); + handler.handleMessage(this.webSocketSession, new TextMessage("SEND\ndestination:/topic/foo\nco")); + handler.handleMessage(this.webSocketSession, new TextMessage("ntent-length:7\n\npayload\0")); + + ArgumentCaptor receiveMessageCaptor = ArgumentCaptor.forClass(Message.class); + verify(this.stompSession).handleMessage(receiveMessageCaptor.capture()); + Message receiveMessage = receiveMessageCaptor.getValue(); + assertThat(receiveMessage).isNotNull(); + + StompHeaderAccessor headers = StompHeaderAccessor.wrap(receiveMessage); + assertThat(headers.toNativeHeaderMap()).hasSize(2); + assertThat(headers.getContentLength()).isEqualTo(7); + assertThat(headers.getDestination()).isEqualTo("/topic/foo"); + assertThat(new String(receiveMessage.getPayload())).isEqualTo("payload"); + } + @Test void heartbeatDefaultValue() { WebSocketStompClient stompClient = new WebSocketStompClient(mock());