Skip to content

Commit fc4372b

Browse files
authored
Merge pull request #1399 from cshannon/buffer-validation
AMQ-6596 - Validate size of buffers during unmarshalling
2 parents 78ee343 + 3037ce8 commit fc4372b

File tree

74 files changed

+796
-259
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+796
-259
lines changed

activemq-client/pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@
7979
<artifactId>log4j-slf4j2-impl</artifactId>
8080
<scope>test</scope>
8181
</dependency>
82+
<dependency>
83+
<groupId>org.javassist</groupId>
84+
<artifactId>javassist</artifactId>
85+
<scope>test</scope>
86+
</dependency>
8287

8388
</dependencies>
8489

activemq-client/src/main/java/org/apache/activemq/openwire/OpenWireFormat.java

+85-33
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public final class OpenWireFormat implements WireFormat {
4848
private static final int MARSHAL_CACHE_SIZE = Short.MAX_VALUE / 2;
4949
private static final int MARSHAL_CACHE_FREE_SPACE = 100;
5050

51-
private DataStreamMarshaller dataMarshallers[];
51+
private DataStreamMarshaller[] dataMarshallers;
5252
private int version;
5353
private boolean stackTraceEnabled;
5454
private boolean tcpNoDelayEnabled;
@@ -61,13 +61,22 @@ public final class OpenWireFormat implements WireFormat {
6161
// The following fields are used for value caching
6262
private short nextMarshallCacheIndex;
6363
private short nextMarshallCacheEvictionIndex;
64-
private Map<DataStructure, Short> marshallCacheMap = new HashMap<DataStructure, Short>();
64+
private Map<DataStructure, Short> marshallCacheMap = new HashMap<>();
6565
private DataStructure marshallCache[] = null;
6666
private DataStructure unmarshallCache[] = null;
67-
private DataByteArrayOutputStream bytesOut = new DataByteArrayOutputStream();
68-
private DataByteArrayInputStream bytesIn = new DataByteArrayInputStream();
67+
private final DataByteArrayOutputStream bytesOut = new DataByteArrayOutputStream();
68+
private final DataByteArrayInputStream bytesIn = new DataByteArrayInputStream();
6969
private WireFormatInfo preferedWireFormatInfo;
7070

71+
// Used to track the currentFrameSize for validation during unmarshalling
72+
// Ideally we would pass the MarshallingContext directly to the marshalling methods,
73+
// however this would require modifying the DataStreamMarshaller interface which would result
74+
// in hundreds of existing methods having to be updated so this allows avoiding that and
75+
// tracking the state without breaking the existing API.
76+
// Note that while this is currently only used during unmarshalling, but if necessary could
77+
// be extended in the future to be used during marshalling as well.
78+
private final ThreadLocal<MarshallingContext> marshallingContext = new ThreadLocal<>();
79+
7180
public OpenWireFormat() {
7281
this(DEFAULT_STORE_VERSION);
7382
}
@@ -191,26 +200,23 @@ public synchronized ByteSequence marshal(Object command) throws IOException {
191200
@Override
192201
public synchronized Object unmarshal(ByteSequence sequence) throws IOException {
193202
bytesIn.restart(sequence);
194-
// DataInputStream dis = new DataInputStream(new
195-
// ByteArrayInputStream(sequence));
196-
197-
if (!sizePrefixDisabled) {
198-
int size = bytesIn.readInt();
199-
if (sequence.getLength() - 4 != size) {
200-
// throw new IOException("Packet size does not match marshaled
201-
// size");
202-
}
203203

204-
if (maxFrameSizeEnabled && size > maxFrameSize) {
205-
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
204+
try {
205+
final var context = new MarshallingContext();
206+
marshallingContext.set(context);
207+
208+
if (!sizePrefixDisabled) {
209+
int size = bytesIn.readInt();
210+
if (maxFrameSizeEnabled && size > maxFrameSize) {
211+
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
212+
}
213+
context.setFrameSize(size);
206214
}
215+
return doUnmarshal(bytesIn);
216+
} finally {
217+
// After we unmarshal we can clear the context
218+
marshallingContext.remove();
207219
}
208-
209-
Object command = doUnmarshal(bytesIn);
210-
// if( !cacheEnabled && ((DataStructure)command).isMarshallAware() ) {
211-
// ((MarshallAware) command).setCachedMarshalledForm(this, sequence);
212-
// }
213-
return command;
214220
}
215221

216222
@Override
@@ -275,19 +281,22 @@ public synchronized void marshal(Object o, DataOutput dataOut) throws IOExceptio
275281

276282
@Override
277283
public Object unmarshal(DataInput dis) throws IOException {
278-
DataInput dataIn = dis;
279-
if (!sizePrefixDisabled) {
280-
int size = dis.readInt();
281-
if (maxFrameSizeEnabled && size > maxFrameSize) {
282-
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
284+
try {
285+
final var context = new MarshallingContext();
286+
marshallingContext.set(context);
287+
288+
if (!sizePrefixDisabled) {
289+
int size = dis.readInt();
290+
if (maxFrameSizeEnabled && size > maxFrameSize) {
291+
throw IOExceptionSupport.createFrameSizeException(size, maxFrameSize);
292+
}
293+
context.setFrameSize(size);
283294
}
284-
// int size = dis.readInt();
285-
// byte[] data = new byte[size];
286-
// dis.readFully(data);
287-
// bytesIn.restart(data);
288-
// dataIn = bytesIn;
295+
return doUnmarshal(dis);
296+
} finally {
297+
// After we unmarshal we can clear
298+
marshallingContext.remove();
289299
}
290-
return doUnmarshal(dataIn);
291300
}
292301

293302
/**
@@ -363,7 +372,7 @@ public void setVersion(int version) {
363372
this.version = version;
364373
}
365374

366-
public Object doUnmarshal(DataInput dis) throws IOException {
375+
private Object doUnmarshal(DataInput dis) throws IOException {
367376
byte dataType = dis.readByte();
368377
if (dataType != NULL_TYPE) {
369378
DataStreamMarshaller dsm = dataMarshallers[dataType & 0xFF];
@@ -698,4 +707,47 @@ protected long min(long version1, long version2) {
698707
}
699708
return version2;
700709
}
710+
711+
MarshallingContext getMarshallingContext() {
712+
return marshallingContext.get();
713+
}
714+
715+
// Used to track the estimated allocated buffer sizes to validate
716+
// against the current frame being processed
717+
static class MarshallingContext {
718+
// Use primitives to minimize memory footprint
719+
private int frameSize = -1;
720+
private int estimatedAllocated = 0;
721+
722+
void setFrameSize(int frameSize) throws IOException {
723+
this.frameSize = frameSize;
724+
if (frameSize < 0) {
725+
throw error("Frame size " + frameSize + " can't be negative.");
726+
}
727+
}
728+
729+
void increment(int size) throws IOException {
730+
if (size < 0) {
731+
throw error("Size " + size + " can't be negative.");
732+
}
733+
try {
734+
estimatedAllocated = Math.addExact(estimatedAllocated, size);
735+
} catch (ArithmeticException e) {
736+
throw error("Buffer overflow when incrementing size value: " + size);
737+
}
738+
}
739+
740+
public int getFrameSize() {
741+
return frameSize;
742+
}
743+
744+
public int getEstimatedAllocated() {
745+
return estimatedAllocated;
746+
}
747+
748+
private static IOException error(String errorMessage) {
749+
return new IOException(new IllegalArgumentException(errorMessage));
750+
}
751+
}
752+
701753
}

activemq-client/src/main/java/org/apache/activemq/openwire/OpenWireUtil.java

+49-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
*/
1717
package org.apache.activemq.openwire;
1818

19+
import java.io.IOException;
20+
import org.apache.activemq.util.IOExceptionSupport;
21+
1922
public class OpenWireUtil {
2023

21-
private static final String jmsPackageToReplace = "javax.jms";
22-
private static final String jmsPackageToUse = "jakarta.jms";
24+
static final String jmsPackageToReplace = "javax.jms";
25+
static final String jmsPackageToUse = "jakarta.jms";
2326

2427
/**
2528
* Verify that the provided class extends {@link Throwable} and throw an
@@ -33,6 +36,50 @@ public static void validateIsThrowable(Class<?> clazz) {
3336
}
3437
}
3538

39+
/**
40+
* Verify that the buffer size that will be allocated will not push the total allocated
41+
* size of this frame above the expected frame size. This is an estimate as the current
42+
* size is only tracked when calls to this method are made and is primarily intended
43+
* to prevent large arrays from being created due to an invalid size.
44+
*
45+
* Also verify the size against configured max frame size.
46+
* This check is a sanity check in case of corrupt packets contain invalid size values.
47+
*
48+
* @param wireFormat configured OpenWireFormat
49+
* @param size buffer size to verify
50+
* @throws IOException If size is larger than currentFrameSize or maxFrameSize
51+
*/
52+
public static void validateBufferSize(OpenWireFormat wireFormat, int size) throws IOException {
53+
validateLessThanFrameSize(wireFormat, size);
54+
55+
// if currentFrameSize is set and was checked above then this check should not be needed,
56+
// but it doesn't hurt to verify again in case the max frame size check was missed
57+
// somehow
58+
if (wireFormat.isMaxFrameSizeEnabled() && size > wireFormat.getMaxFrameSize()) {
59+
throw IOExceptionSupport.createFrameSizeException(size, wireFormat.getMaxFrameSize());
60+
}
61+
}
62+
63+
// Verify total tracked sizes will not exceed the overall size of the frame
64+
private static void validateLessThanFrameSize(OpenWireFormat wireFormat, int size)
65+
throws IOException {
66+
final var context = wireFormat.getMarshallingContext();
67+
// No information on current frame size so just return
68+
if (context == null || context.getFrameSize() < 0) {
69+
return;
70+
}
71+
72+
// Increment existing estimated buffer size with new size
73+
context.increment(size);
74+
75+
// We should never be trying to allocate a buffer that is going to push the total
76+
// size greater than the entire frame itself
77+
if (context.getEstimatedAllocated() > context.getFrameSize()) {
78+
throw IOExceptionSupport.createFrameSizeBufferException(
79+
context.getEstimatedAllocated(), context.getFrameSize());
80+
}
81+
}
82+
3683
/**
3784
* This method can be used to convert from javax -> jakarta or
3885
* vice versa depending on the version used by the client

activemq-client/src/main/java/org/apache/activemq/openwire/v1/BaseDataStreamMarshaller.java

+8-4
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,11 @@ protected void tightMarshalByteArray2(byte[] data, DataOutput dataOut, BooleanSt
411411
}
412412
}
413413

414-
protected byte[] tightUnmarshalByteArray(DataInput dataIn, BooleanStream bs) throws IOException {
414+
protected byte[] tightUnmarshalByteArray(OpenWireFormat wireFormat, DataInput dataIn, BooleanStream bs) throws IOException {
415415
byte rc[] = null;
416416
if (bs.readBoolean()) {
417417
int size = dataIn.readInt();
418+
OpenWireUtil.validateBufferSize(wireFormat, size);
418419
rc = new byte[size];
419420
dataIn.readFully(rc);
420421
}
@@ -438,10 +439,11 @@ protected void tightMarshalByteSequence2(ByteSequence data, DataOutput dataOut,
438439
}
439440
}
440441

441-
protected ByteSequence tightUnmarshalByteSequence(DataInput dataIn, BooleanStream bs) throws IOException {
442+
protected ByteSequence tightUnmarshalByteSequence(OpenWireFormat wireFormat, DataInput dataIn, BooleanStream bs) throws IOException {
442443
ByteSequence rc = null;
443444
if (bs.readBoolean()) {
444445
int size = dataIn.readInt();
446+
OpenWireUtil.validateBufferSize(wireFormat, size);
445447
byte[] t = new byte[size];
446448
dataIn.readFully(t);
447449
return new ByteSequence(t, 0, size);
@@ -618,10 +620,11 @@ protected void looseMarshalByteArray(OpenWireFormat wireFormat, byte[] data, Dat
618620
}
619621
}
620622

621-
protected byte[] looseUnmarshalByteArray(DataInput dataIn) throws IOException {
623+
protected byte[] looseUnmarshalByteArray(OpenWireFormat wireFormat, DataInput dataIn) throws IOException {
622624
byte rc[] = null;
623625
if (dataIn.readBoolean()) {
624626
int size = dataIn.readInt();
627+
OpenWireUtil.validateBufferSize(wireFormat, size);
625628
rc = new byte[size];
626629
dataIn.readFully(rc);
627630
}
@@ -637,10 +640,11 @@ protected void looseMarshalByteSequence(OpenWireFormat wireFormat, ByteSequence
637640
}
638641
}
639642

640-
protected ByteSequence looseUnmarshalByteSequence(DataInput dataIn) throws IOException {
643+
protected ByteSequence looseUnmarshalByteSequence(OpenWireFormat wireFormat, DataInput dataIn) throws IOException {
641644
ByteSequence rc = null;
642645
if (dataIn.readBoolean()) {
643646
int size = dataIn.readInt();
647+
OpenWireUtil.validateBufferSize(wireFormat, size);
644648
byte[] t = new byte[size];
645649
dataIn.readFully(t);
646650
rc = new ByteSequence(t, 0, size);

activemq-client/src/main/java/org/apache/activemq/openwire/v1/MessageMarshaller.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
6565
info.setReplyTo((org.apache.activemq.command.ActiveMQDestination)tightUnmarsalNestedObject(wireFormat, dataIn, bs));
6666
info.setTimestamp(tightUnmarshalLong(wireFormat, dataIn, bs));
6767
info.setType(tightUnmarshalString(dataIn, bs));
68-
info.setContent(tightUnmarshalByteSequence(dataIn, bs));
69-
info.setMarshalledProperties(tightUnmarshalByteSequence(dataIn, bs));
68+
info.setContent(tightUnmarshalByteSequence(wireFormat, dataIn, bs));
69+
info.setMarshalledProperties(tightUnmarshalByteSequence(wireFormat, dataIn, bs));
7070
info.setDataStructure((org.apache.activemq.command.DataStructure)tightUnmarsalNestedObject(wireFormat, dataIn, bs));
7171
info.setTargetConsumerId((org.apache.activemq.command.ConsumerId)tightUnmarsalCachedObject(wireFormat, dataIn, bs));
7272
info.setCompressed(bs.readBoolean());
@@ -196,8 +196,8 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
196196
info.setReplyTo((org.apache.activemq.command.ActiveMQDestination)looseUnmarsalNestedObject(wireFormat, dataIn));
197197
info.setTimestamp(looseUnmarshalLong(wireFormat, dataIn));
198198
info.setType(looseUnmarshalString(dataIn));
199-
info.setContent(looseUnmarshalByteSequence(dataIn));
200-
info.setMarshalledProperties(looseUnmarshalByteSequence(dataIn));
199+
info.setContent(looseUnmarshalByteSequence(wireFormat, dataIn));
200+
info.setMarshalledProperties(looseUnmarshalByteSequence(wireFormat, dataIn));
201201
info.setDataStructure((org.apache.activemq.command.DataStructure)looseUnmarsalNestedObject(wireFormat, dataIn));
202202
info.setTargetConsumerId((org.apache.activemq.command.ConsumerId)looseUnmarsalCachedObject(wireFormat, dataIn));
203203
info.setCompressed(dataIn.readBoolean());

activemq-client/src/main/java/org/apache/activemq/openwire/v1/PartialCommandMarshaller.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
6868

6969
PartialCommand info = (PartialCommand)o;
7070
info.setCommandId(dataIn.readInt());
71-
info.setData(tightUnmarshalByteArray(dataIn, bs));
71+
info.setData(tightUnmarshalByteArray(wireFormat, dataIn, bs));
7272

7373
}
7474

@@ -114,7 +114,7 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
114114

115115
PartialCommand info = (PartialCommand)o;
116116
info.setCommandId(dataIn.readInt());
117-
info.setData(looseUnmarshalByteArray(dataIn));
117+
info.setData(looseUnmarshalByteArray(wireFormat, dataIn));
118118

119119
}
120120

activemq-client/src/main/java/org/apache/activemq/openwire/v1/WireFormatInfoMarshaller.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
7272

7373
info.setMagic(tightUnmarshalConstByteArray(dataIn, bs, 8));
7474
info.setVersion(dataIn.readInt());
75-
info.setMarshalledProperties(tightUnmarshalByteSequence(dataIn, bs));
75+
info.setMarshalledProperties(tightUnmarshalByteSequence(wireFormat, dataIn, bs));
7676

7777
info.afterUnmarshall(wireFormat);
7878

@@ -130,7 +130,7 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
130130

131131
info.setMagic(looseUnmarshalConstByteArray(dataIn, 8));
132132
info.setVersion(dataIn.readInt());
133-
info.setMarshalledProperties(looseUnmarshalByteSequence(dataIn));
133+
info.setMarshalledProperties(looseUnmarshalByteSequence(wireFormat, dataIn));
134134

135135
info.afterUnmarshall(wireFormat);
136136

activemq-client/src/main/java/org/apache/activemq/openwire/v1/XATransactionIdMarshaller.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ public void tightUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
6868

6969
XATransactionId info = (XATransactionId)o;
7070
info.setFormatId(dataIn.readInt());
71-
info.setGlobalTransactionId(tightUnmarshalByteArray(dataIn, bs));
72-
info.setBranchQualifier(tightUnmarshalByteArray(dataIn, bs));
71+
info.setGlobalTransactionId(tightUnmarshalByteArray(wireFormat, dataIn, bs));
72+
info.setBranchQualifier(tightUnmarshalByteArray(wireFormat, dataIn, bs));
7373

7474
}
7575

@@ -117,8 +117,8 @@ public void looseUnmarshal(OpenWireFormat wireFormat, Object o, DataInput dataIn
117117

118118
XATransactionId info = (XATransactionId)o;
119119
info.setFormatId(dataIn.readInt());
120-
info.setGlobalTransactionId(looseUnmarshalByteArray(dataIn));
121-
info.setBranchQualifier(looseUnmarshalByteArray(dataIn));
120+
info.setGlobalTransactionId(looseUnmarshalByteArray(wireFormat, dataIn));
121+
info.setBranchQualifier(looseUnmarshalByteArray(wireFormat, dataIn));
122122

123123
}
124124

0 commit comments

Comments
 (0)