diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 0e299632902..d09b5e435ad 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -100,6 +100,14 @@ def test_dumps_serialize_numpy(x): np.testing.assert_equal(x, y) +def test_dumps_numpy_writable(): + a1 = np.arange(1000) + a1.flags.writeable = False + (a2,) = loads(dumps([to_serialize(a1)])) + assert (a1 == a2).all() + assert a2.flags.writeable + + @pytest.mark.parametrize( "x", [ diff --git a/distributed/protocol/tests/test_protocol_utils.py b/distributed/protocol/tests/test_protocol_utils.py index d4250fb3c05..5f18bc1e6e6 100644 --- a/distributed/protocol/tests/test_protocol_utils.py +++ b/distributed/protocol/tests/test_protocol_utils.py @@ -1,19 +1,33 @@ +import pytest + from distributed.protocol.utils import merge_frames, pack_frames, unpack_frames from distributed.utils import ensure_bytes -def test_merge_frames(): - result = merge_frames({"lengths": [3, 4]}, [b"12", b"34", b"567"]) - expected = [b"123", b"4567"] - +@pytest.mark.parametrize( + "lengths,frames", + [ + ([3], [b"123"]), + ([3, 3], [b"123", b"456"]), + ([2, 3, 2], [b"12345", b"67"]), + ([5, 2], [b"123", b"45", b"67"]), + ([3, 4], [b"12", b"34", b"567"]), + ], +) +def test_merge_frames(lengths, frames): + header = {"lengths": lengths} + result = merge_frames(header, frames) + + data = b"".join(frames) + expected = [] + for i in lengths: + expected.append(data[:i]) + data = data[i:] + + assert all(isinstance(f, memoryview) for f in result) + assert all(not f.readonly for f in result) assert list(map(ensure_bytes, result)) == expected - b = b"123" - assert merge_frames({"lengths": [3]}, [b])[0] is b - - L = [b"123", b"456"] - assert merge_frames({"lengths": [3, 3]}, L) is L - def test_pack_frames(): frames = [b"123", b"asdf"] diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index defbda2ba4f..e94d21584d2 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -58,34 +58,36 @@ def merge_frames(header, frames): [b'123456'] """ lengths = list(header["lengths"]) + frames = list(map(memoryview, frames)) assert sum(lengths) == sum(map(nbytes, frames)) - if all(len(f) == l for f, l in zip(frames, lengths)): - return frames - - frames = frames[::-1] - lengths = lengths[::-1] - - out = [] - while lengths: - l = lengths.pop() - L = [] - while l: - frame = frames.pop() - if nbytes(frame) <= l: - L.append(frame) - l -= nbytes(frame) + if not all(len(f) == l for f, l in zip(frames, lengths)): + frames = frames[::-1] + lengths = lengths[::-1] + + out = [] + while lengths: + l = lengths.pop() + L = [] + while l: + frame = frames.pop() + if nbytes(frame) <= l: + L.append(frame) + l -= nbytes(frame) + else: + L.append(frame[:l]) + frames.append(frame[l:]) + l = 0 + if len(L) == 1: # no work necessary + out.append(L[0]) else: - mv = memoryview(frame) - L.append(mv[:l]) - frames.append(mv[l:]) - l = 0 - if len(L) == 1: # no work necessary - out.extend(L) - else: - out.append(b"".join(L)) - return out + out.append(memoryview(bytearray().join(L))) + frames = out + + frames = [memoryview(bytearray(f)) if f.readonly else f for f in frames] + + return frames def pack_frames_prelude(frames):