Skip to content

Commit 24160d5

Browse files
mhansencopybara-github
authored andcommitted
CodedOutputStream: Avoid updating position to go beyond end of array.
When writing varints. This has twofold goals: 1. Correctness: if position overruns the array, checking space left may return a negative number. I'm not sure how bad that is, but let's avoid it. 2. Performance. This generates more optimal assembly code which can combine bounds checks, particularly on Android (I haven't looked at the generated assembly on the server JVM; it's possible the server JVM can already performance this hoist). The `position` field is stored on the object, so Android ART generates assembly codes for `this.position++` like "load, add, store": ``` ldr w3, [x1, #12] add w4, w3, #0x1 (1) str w4, [x1, #12] ``` There can be a lot of these loads/stores executed each step of a loop (e.g. writeFixed64NoTag updates position 8 times, and varint encoding could do it even more). It's faster if we can hoist these so we load once at the start of the function, and store once at the end of the function. This also has the nice benefit that it won't store if we've thrown an exception. See before/after in Compiler Explorer: https://godbolt.org/z/bWWYqsxK4. I'm not an assembly expert, but it seems clear that the increment instructions like `add w4, w0, #0x1 (1)` are no longer always surrounded by loads and stores in the new version. PiperOrigin-RevId: 681644516
1 parent d88a3d0 commit 24160d5

File tree

2 files changed

+55
-11
lines changed

2 files changed

+55
-11
lines changed

java/core/src/main/java/com/google/protobuf/CodedOutputStream.java

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,11 +1347,12 @@ public final void writeInt32NoTag(int value) throws IOException {
13471347

13481348
@Override
13491349
public final void writeUInt32NoTag(int value) throws IOException {
1350+
int position = this.position; // Perf: hoist field to register to avoid load/stores.
13501351
try {
13511352
while (true) {
13521353
if ((value & ~0x7F) == 0) {
13531354
buffer[position++] = (byte) value;
1354-
return;
1355+
break;
13551356
} else {
13561357
buffer[position++] = (byte) ((value | 0x80) & 0xFF);
13571358
value >>>= 7;
@@ -1360,6 +1361,7 @@ public final void writeUInt32NoTag(int value) throws IOException {
13601361
} catch (IndexOutOfBoundsException e) {
13611362
throw new OutOfSpaceException(position, limit, 1, e);
13621363
}
1364+
this.position = position; // Only update position if we stayed within the array bounds.
13631365
}
13641366

13651367
@Override
@@ -1379,11 +1381,12 @@ public final void writeFixed32NoTag(int value) throws IOException {
13791381

13801382
@Override
13811383
public final void writeUInt64NoTag(long value) throws IOException {
1384+
int position = this.position; // Perf: hoist field to register to avoid load/stores.
13821385
if (HAS_UNSAFE_ARRAY_OPERATIONS && spaceLeft() >= MAX_VARINT_SIZE) {
13831386
while (true) {
13841387
if ((value & ~0x7FL) == 0) {
13851388
UnsafeUtil.putByte(buffer, position++, (byte) value);
1386-
return;
1389+
break;
13871390
} else {
13881391
UnsafeUtil.putByte(buffer, position++, (byte) (((int) value | 0x80) & 0xFF));
13891392
value >>>= 7;
@@ -1394,7 +1397,7 @@ public final void writeUInt64NoTag(long value) throws IOException {
13941397
while (true) {
13951398
if ((value & ~0x7FL) == 0) {
13961399
buffer[position++] = (byte) value;
1397-
return;
1400+
break;
13981401
} else {
13991402
buffer[position++] = (byte) (((int) value | 0x80) & 0xFF);
14001403
value >>>= 7;
@@ -1404,6 +1407,7 @@ public final void writeUInt64NoTag(long value) throws IOException {
14041407
throw new OutOfSpaceException(position, limit, 1, e);
14051408
}
14061409
}
1410+
this.position = position; // Only update position if we stayed within the array bounds.
14071411
}
14081412

14091413
@Override
@@ -2034,29 +2038,34 @@ public void writeInt32NoTag(int value) throws IOException {
20342038

20352039
@Override
20362040
public void writeUInt32NoTag(int value) throws IOException {
2041+
long position = this.position; // Perf: hoist field to register to avoid load/stores.
20372042
if (position <= oneVarintLimit) {
20382043
// Optimization to avoid bounds checks on each iteration.
20392044
while (true) {
20402045
if ((value & ~0x7F) == 0) {
20412046
UnsafeUtil.putByte(position++, (byte) value);
2042-
return;
2047+
break;
20432048
} else {
20442049
UnsafeUtil.putByte(position++, (byte) ((value | 0x80) & 0xFF));
20452050
value >>>= 7;
20462051
}
20472052
}
20482053
} else {
2049-
while (position < limit) {
2054+
while (true) {
2055+
if (position >= limit) {
2056+
throw new OutOfSpaceException(
2057+
String.format("Pos: %d, limit: %d, len: %d", position, limit, 1));
2058+
}
20502059
if ((value & ~0x7F) == 0) {
20512060
UnsafeUtil.putByte(position++, (byte) value);
2052-
return;
2061+
break;
20532062
} else {
20542063
UnsafeUtil.putByte(position++, (byte) ((value | 0x80) & 0xFF));
20552064
value >>>= 7;
20562065
}
20572066
}
2058-
throw new OutOfSpaceException(position, limit, 1);
20592067
}
2068+
this.position = position; // Only update position if we stayed within the array bounds.
20602069
}
20612070

20622071
@Override
@@ -2071,29 +2080,33 @@ public void writeFixed32NoTag(int value) throws IOException {
20712080

20722081
@Override
20732082
public void writeUInt64NoTag(long value) throws IOException {
2083+
long position = this.position; // Perf: hoist field to register to avoid load/stores.
20742084
if (position <= oneVarintLimit) {
20752085
// Optimization to avoid bounds checks on each iteration.
20762086
while (true) {
20772087
if ((value & ~0x7FL) == 0) {
20782088
UnsafeUtil.putByte(position++, (byte) value);
2079-
return;
2089+
break;
20802090
} else {
20812091
UnsafeUtil.putByte(position++, (byte) (((int) value | 0x80) & 0xFF));
20822092
value >>>= 7;
20832093
}
20842094
}
20852095
} else {
2086-
while (position < limit) {
2096+
while (true) {
2097+
if (position >= limit) {
2098+
throw new OutOfSpaceException(position, limit, 1);
2099+
}
20872100
if ((value & ~0x7FL) == 0) {
20882101
UnsafeUtil.putByte(position++, (byte) value);
2089-
return;
2102+
break;
20902103
} else {
20912104
UnsafeUtil.putByte(position++, (byte) (((int) value | 0x80) & 0xFF));
20922105
value >>>= 7;
20932106
}
20942107
}
2095-
throw new OutOfSpaceException(position, limit, 1);
20962108
}
2109+
this.position = position; // Only update position if we stayed within the array bounds.
20972110
}
20982111

20992112
@Override

java/core/src/test/java/com/google/protobuf/CodedOutputStreamTest.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package com.google.protobuf;
99

1010
import static com.google.common.truth.Truth.assertThat;
11+
import static com.google.common.truth.Truth.assertWithMessage;
1112
import static com.google.common.truth.TruthJUnit.assume;
1213
import static org.junit.Assert.assertThrows;
1314

@@ -313,6 +314,36 @@ public void testWriteFixed64NoTag_outOfBounds_throws() throws Exception {
313314
}
314315
}
315316

317+
@Test
318+
public void testWriteUInt32NoTag_outOfBounds_throws() throws Exception {
319+
// Streaming's buffering masks out of bounds writes.
320+
assume().that(outputType).isNotEqualTo(OutputType.STREAM);
321+
322+
for (int i = 0; i < 5; i++) {
323+
Coder coder = outputType.newCoder(i);
324+
assertThrows(
325+
OutOfSpaceException.class, () -> coder.stream().writeUInt32NoTag(Integer.MAX_VALUE));
326+
327+
// Space left should not go negative.
328+
assertWithMessage("i=%s", i).that(coder.stream().spaceLeft()).isAtLeast(0);
329+
}
330+
}
331+
332+
@Test
333+
public void testWriteUInt64NoTag_outOfBounds_throws() throws Exception {
334+
// Streaming's buffering masks out of bounds writes.
335+
assume().that(outputType).isNotEqualTo(OutputType.STREAM);
336+
337+
for (int i = 0; i < 9; i++) {
338+
Coder coder = outputType.newCoder(i);
339+
assertThrows(
340+
OutOfSpaceException.class, () -> coder.stream().writeUInt64NoTag(Long.MAX_VALUE));
341+
342+
// Space left should not go negative.
343+
assertWithMessage("i=%s", i).that(coder.stream().spaceLeft()).isAtLeast(0);
344+
}
345+
}
346+
316347
/** Test encodeZigZag32() and encodeZigZag64(). */
317348
@Test
318349
public void testEncodeZigZag() throws Exception {

0 commit comments

Comments
 (0)