@@ -48,7 +48,7 @@ public final class OpenWireFormat implements WireFormat {
48
48
private static final int MARSHAL_CACHE_SIZE = Short .MAX_VALUE / 2 ;
49
49
private static final int MARSHAL_CACHE_FREE_SPACE = 100 ;
50
50
51
- private DataStreamMarshaller dataMarshallers [] ;
51
+ private DataStreamMarshaller [] dataMarshallers ;
52
52
private int version ;
53
53
private boolean stackTraceEnabled ;
54
54
private boolean tcpNoDelayEnabled ;
@@ -61,13 +61,22 @@ public final class OpenWireFormat implements WireFormat {
61
61
// The following fields are used for value caching
62
62
private short nextMarshallCacheIndex ;
63
63
private short nextMarshallCacheEvictionIndex ;
64
- private Map <DataStructure , Short > marshallCacheMap = new HashMap <DataStructure , Short >();
64
+ private Map <DataStructure , Short > marshallCacheMap = new HashMap <>();
65
65
private DataStructure marshallCache [] = null ;
66
66
private DataStructure unmarshallCache [] = null ;
67
- private DataByteArrayOutputStream bytesOut = new DataByteArrayOutputStream ();
68
- private DataByteArrayInputStream bytesIn = new DataByteArrayInputStream ();
67
+ private final DataByteArrayOutputStream bytesOut = new DataByteArrayOutputStream ();
68
+ private final DataByteArrayInputStream bytesIn = new DataByteArrayInputStream ();
69
69
private WireFormatInfo preferedWireFormatInfo ;
70
70
71
+ // Used to track the currentFrameSize for validation during unmarshalling
72
+ // Ideally we would pass the MarshallingContext directly to the marshalling methods,
73
+ // however this would require modifying the DataStreamMarshaller interface which would result
74
+ // in hundreds of existing methods having to be updated so this allows avoiding that and
75
+ // tracking the state without breaking the existing API.
76
+ // Note that while this is currently only used during unmarshalling, but if necessary could
77
+ // be extended in the future to be used during marshalling as well.
78
+ private final ThreadLocal <MarshallingContext > marshallingContext = new ThreadLocal <>();
79
+
71
80
public OpenWireFormat () {
72
81
this (DEFAULT_STORE_VERSION );
73
82
}
@@ -191,26 +200,23 @@ public synchronized ByteSequence marshal(Object command) throws IOException {
191
200
@ Override
192
201
public synchronized Object unmarshal (ByteSequence sequence ) throws IOException {
193
202
bytesIn .restart (sequence );
194
- // DataInputStream dis = new DataInputStream(new
195
- // ByteArrayInputStream(sequence));
196
-
197
- if (!sizePrefixDisabled ) {
198
- int size = bytesIn .readInt ();
199
- if (sequence .getLength () - 4 != size ) {
200
- // throw new IOException("Packet size does not match marshaled
201
- // size");
202
- }
203
203
204
- if (maxFrameSizeEnabled && size > maxFrameSize ) {
205
- throw IOExceptionSupport .createFrameSizeException (size , maxFrameSize );
204
+ try {
205
+ final var context = new MarshallingContext ();
206
+ marshallingContext .set (context );
207
+
208
+ if (!sizePrefixDisabled ) {
209
+ int size = bytesIn .readInt ();
210
+ if (maxFrameSizeEnabled && size > maxFrameSize ) {
211
+ throw IOExceptionSupport .createFrameSizeException (size , maxFrameSize );
212
+ }
213
+ context .setFrameSize (size );
206
214
}
215
+ return doUnmarshal (bytesIn );
216
+ } finally {
217
+ // After we unmarshal we can clear the context
218
+ marshallingContext .remove ();
207
219
}
208
-
209
- Object command = doUnmarshal (bytesIn );
210
- // if( !cacheEnabled && ((DataStructure)command).isMarshallAware() ) {
211
- // ((MarshallAware) command).setCachedMarshalledForm(this, sequence);
212
- // }
213
- return command ;
214
220
}
215
221
216
222
@ Override
@@ -275,19 +281,22 @@ public synchronized void marshal(Object o, DataOutput dataOut) throws IOExceptio
275
281
276
282
@ Override
277
283
public Object unmarshal (DataInput dis ) throws IOException {
278
- DataInput dataIn = dis ;
279
- if (!sizePrefixDisabled ) {
280
- int size = dis .readInt ();
281
- if (maxFrameSizeEnabled && size > maxFrameSize ) {
282
- throw IOExceptionSupport .createFrameSizeException (size , maxFrameSize );
284
+ try {
285
+ final var context = new MarshallingContext ();
286
+ marshallingContext .set (context );
287
+
288
+ if (!sizePrefixDisabled ) {
289
+ int size = dis .readInt ();
290
+ if (maxFrameSizeEnabled && size > maxFrameSize ) {
291
+ throw IOExceptionSupport .createFrameSizeException (size , maxFrameSize );
292
+ }
293
+ context .setFrameSize (size );
283
294
}
284
- // int size = dis.readInt();
285
- // byte[] data = new byte[size];
286
- // dis.readFully(data);
287
- // bytesIn.restart(data);
288
- // dataIn = bytesIn;
295
+ return doUnmarshal (dis );
296
+ } finally {
297
+ // After we unmarshal we can clear
298
+ marshallingContext .remove ();
289
299
}
290
- return doUnmarshal (dataIn );
291
300
}
292
301
293
302
/**
@@ -363,7 +372,7 @@ public void setVersion(int version) {
363
372
this .version = version ;
364
373
}
365
374
366
- public Object doUnmarshal (DataInput dis ) throws IOException {
375
+ private Object doUnmarshal (DataInput dis ) throws IOException {
367
376
byte dataType = dis .readByte ();
368
377
if (dataType != NULL_TYPE ) {
369
378
DataStreamMarshaller dsm = dataMarshallers [dataType & 0xFF ];
@@ -698,4 +707,47 @@ protected long min(long version1, long version2) {
698
707
}
699
708
return version2 ;
700
709
}
710
+
711
+ MarshallingContext getMarshallingContext () {
712
+ return marshallingContext .get ();
713
+ }
714
+
715
+ // Used to track the estimated allocated buffer sizes to validate
716
+ // against the current frame being processed
717
+ static class MarshallingContext {
718
+ // Use primitives to minimize memory footprint
719
+ private int frameSize = -1 ;
720
+ private int estimatedAllocated = 0 ;
721
+
722
+ void setFrameSize (int frameSize ) throws IOException {
723
+ this .frameSize = frameSize ;
724
+ if (frameSize < 0 ) {
725
+ throw error ("Frame size " + frameSize + " can't be negative." );
726
+ }
727
+ }
728
+
729
+ void increment (int size ) throws IOException {
730
+ if (size < 0 ) {
731
+ throw error ("Size " + size + " can't be negative." );
732
+ }
733
+ try {
734
+ estimatedAllocated = Math .addExact (estimatedAllocated , size );
735
+ } catch (ArithmeticException e ) {
736
+ throw error ("Buffer overflow when incrementing size value: " + size );
737
+ }
738
+ }
739
+
740
+ public int getFrameSize () {
741
+ return frameSize ;
742
+ }
743
+
744
+ public int getEstimatedAllocated () {
745
+ return estimatedAllocated ;
746
+ }
747
+
748
+ private static IOException error (String errorMessage ) {
749
+ return new IOException (new IllegalArgumentException (errorMessage ));
750
+ }
751
+ }
752
+
701
753
}
0 commit comments