
Commit e57d1c5

Reduce memory footprint of P2P shuffling (#8157)
1 parent 9129dae commit e57d1c5

9 files changed: +129 −65 lines changed

continuous_integration/environment-3.9.yaml

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ dependencies:
   - pre-commit
   - prometheus_client
   - psutil
-  - pyarrow=7
+  - pyarrow=12
   - pynvml # Only tested here
   - pytest
   - pytest-cov

distributed/shuffle/_arrow.py

Lines changed: 53 additions & 21 deletions
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
-from io import BytesIO
-from typing import TYPE_CHECKING
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
 
 from packaging.version import parse
 
+from dask.utils import parse_bytes
+
 if TYPE_CHECKING:
     import pandas as pd
     import pyarrow as pa
@@ -29,34 +31,27 @@ def check_minimal_arrow_version() -> None:
     """Verify that the the correct version of pyarrow is installed to support
     the P2P extension.
 
-    Raises a RuntimeError in case pyarrow is not installed or installed version
-    is not recent enough.
+    Raises a ModuleNotFoundError if pyarrow is not installed or an
+    ImportError if the installed version is not recent enough.
     """
-    # First version to introduce Table.sort_by
-    minversion = "7.0.0"
+    # First version that supports concatenating extension arrays (apache/arrow#14463)
+    minversion = "12.0.0"
     try:
         import pyarrow as pa
-    except ImportError:
-        raise RuntimeError(f"P2P shuffling requires pyarrow>={minversion}")
-
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(f"P2P shuffling requires pyarrow>={minversion}")
     if parse(pa.__version__) < parse(minversion):
-        raise RuntimeError(
+        raise ImportError(
             f"P2P shuffling requires pyarrow>={minversion} but only found {pa.__version__}"
         )
 
 
-def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame:
+def convert_shards(shards: list[pa.Table], meta: pd.DataFrame) -> pd.DataFrame:
     import pyarrow as pa
 
     from dask.dataframe.dispatch import from_pyarrow_table_dispatch
 
-    file = BytesIO(data)
-    end = len(data)
-    shards = []
-    while file.tell() < end:
-        sr = pa.RecordBatchStreamReader(file)
-        shards.append(sr.read_all())
-    table = pa.concat_tables(shards, promote=True)
+    table = pa.concat_tables(shards)
 
     df = from_pyarrow_table_dispatch(meta, table, self_destruct=True)
     return df.astype(meta.dtypes, copy=False)
@@ -66,9 +61,7 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
     """Convert a list of arrow buffers and a schema to an Arrow Table"""
     import pyarrow as pa
 
-    return pa.concat_tables(
-        (deserialize_table(buffer) for buffer in data), promote=True
-    )
+    return pa.concat_tables(deserialize_table(buffer) for buffer in data)
 
 
 def serialize_table(table: pa.Table) -> bytes:
@@ -85,3 +78,42 @@ def deserialize_table(buffer: bytes) -> pa.Table:
 
     with pa.ipc.open_stream(pa.py_buffer(buffer)) as reader:
         return reader.read_all()
+
+
+def read_from_disk(path: Path, meta: pd.DataFrame) -> tuple[Any, int]:
+    import pyarrow as pa
+
+    from dask.dataframe.dispatch import pyarrow_schema_dispatch
+
+    batch_size = parse_bytes("1 MiB")
+    batch = []
+    shards = []
+    schema = pyarrow_schema_dispatch(meta, preserve_index=True)
+
+    with pa.OSFile(str(path), mode="rb") as f:
+        size = f.seek(0, whence=2)
+        f.seek(0)
+        prev = 0
+        offset = f.tell()
+        while offset < size:
+            sr = pa.RecordBatchStreamReader(f)
+            shard = sr.read_all()
+            offset = f.tell()
+            batch.append(shard)
+
+            if offset - prev >= batch_size:
+                table = pa.concat_tables(batch)
+                shards.append(_copy_table(table, schema))
+                batch = []
+                prev = offset
+    if batch:
+        table = pa.concat_tables(batch)
+        shards.append(_copy_table(table, schema))
+    return shards, size
+
+
+def _copy_table(table: pa.Table, schema: pa.Schema) -> pa.Table:
+    import pyarrow as pa
+
+    arrs = [pa.concat_arrays(column.chunks) for column in table.columns]
+    return pa.table(data=arrs, schema=schema)
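
For context, a minimal standalone sketch (not part of the commit) of why _copy_table pays off: pa.concat_tables is zero-copy, so a partition assembled from many small spilled shards ends up with every column fragmented into a chunk per shard, while rebuilding each column with pa.concat_arrays collapses them into one contiguous buffer before the table is passed on.

import pyarrow as pa

# Many tiny shards, as produced by one RecordBatchStreamReader per spilled write.
shards = [pa.table({"x": [i]}) for i in range(1_000)]

# Zero-copy concatenation keeps every original chunk alive.
fragmented = pa.concat_tables(shards)
print(fragmented.column("x").num_chunks)  # 1000

# What _copy_table does: re-materialize each column as a single contiguous array.
defragmented = pa.table(
    data=[pa.concat_arrays(col.chunks) for col in fragmented.columns],
    schema=fragmented.schema,
)
print(defragmented.column("x").num_chunks)  # 1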

distributed/shuffle/_core.py

Lines changed: 8 additions & 3 deletions
@@ -11,6 +11,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar
 
 from distributed.core import PooledRPCCall
@@ -62,6 +63,7 @@ def __init__(
 
         self._disk_buffer = DiskShardsBuffer(
             directory=directory,
+            read=self.read,
             memory_limiter=memory_limiter_disk,
         )
 
@@ -180,10 +182,9 @@ def fail(self, exception: Exception) -> None:
         if not self.closed:
             self._exception = exception
 
-    def _read_from_disk(self, id: NDIndex) -> bytes:
+    def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
         self.raise_if_closed()
-        data: bytes = self._disk_buffer.read("_".join(str(i) for i in id))
-        return data
+        return self._disk_buffer.read("_".join(str(i) for i in id))
 
     async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
         await self._receive(data)
@@ -238,6 +239,10 @@ async def _get_output_partition(
     ) -> _T_partition_type:
         """Get an output partition to the shuffle run"""
 
+    @abc.abstractmethod
+    def read(self, path: Path) -> tuple[Any, int]:
+        """Read shards from disk"""
+
 
 def get_worker_plugin() -> ShuffleWorkerPlugin:
     from distributed import get_worker

distributed/shuffle/_disk.py

Lines changed: 5 additions & 6 deletions
@@ -3,6 +3,7 @@
 import contextlib
 import pathlib
 import shutil
+from typing import Any, Callable
 
 from distributed.shuffle._buffer import ShardsBuffer
 from distributed.shuffle._limiter import ResourceLimiter
@@ -41,6 +42,7 @@ class DiskShardsBuffer(ShardsBuffer):
     def __init__(
         self,
         directory: str | pathlib.Path,
+        read: Callable[[pathlib.Path], tuple[Any, int]],
        memory_limiter: ResourceLimiter | None = None,
     ):
         super().__init__(
@@ -50,6 +52,7 @@ def __init__(
         )
         self.directory = pathlib.Path(directory)
         self.directory.mkdir(exist_ok=True)
+        self._read = read
 
     async def _process(self, id: str, shards: list[bytes]) -> None:
         """Write one buffer to file
@@ -74,19 +77,15 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
                for shard in shards:
                    f.write(shard)
 
-    def read(self, id: int | str) -> bytes:
+    def read(self, id: int | str) -> Any:
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
             raise RuntimeError("Tried to read from file before done.")
 
         try:
             with self.time("read"):
-                with open(
-                    self.directory / str(id), mode="rb", buffering=100_000_000
-                ) as f:
-                    data = f.read()
-                    size = f.tell()
+                data, size = self._read((self.directory / str(id)).resolve())
         except FileNotFoundError:
             raise KeyError(id)
 
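
As an aside, a minimal standalone sketch (not from the commit, with hypothetical names) of the callback injection used above: DiskShardsBuffer no longer returns raw bytes, it calls whatever read function the owning ShuffleRun passed in, so each shuffle flavour controls how its spilled files are deserialized and reports how many bytes were read.

from pathlib import Path
from typing import Any, Callable


class ToyDiskBuffer:
    """Stand-in for DiskShardsBuffer: the storage layer stays format-agnostic."""

    def __init__(self, directory: Path, read: Callable[[Path], tuple[Any, int]]):
        self.directory = directory
        self._read = read  # injected by the owning run, cf. read=self.read above

    def read(self, id: str) -> Any:
        data, size = self._read(self.directory / str(id))  # run-specific deserialization
        return data


def read_text_shards(path: Path) -> tuple[Any, int]:
    # Hypothetical reader: treats every line of the spilled file as one shard.
    raw = path.read_bytes()
    return raw.splitlines(), len(raw)


buffer = ToyDiskBuffer(Path("."), read=read_text_shards)  # mirrors ShuffleRun.__init__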

distributed/shuffle/_rechunk.py

Lines changed: 23 additions & 15 deletions
@@ -102,8 +102,8 @@
 from collections.abc import Callable, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from io import BytesIO
 from itertools import product
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, NamedTuple
 
 import dask
@@ -258,26 +258,22 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes:
     return axes
 
 
-def convert_chunk(data: bytes) -> np.ndarray:
+def convert_chunk(shards: list[tuple[NDIndex, np.ndarray]]) -> np.ndarray:
     import numpy as np
 
     from dask.array.core import concatenate3
 
-    file = BytesIO(data)
-    shards: dict[NDIndex, np.ndarray] = {}
+    indexed: dict[NDIndex, np.ndarray] = {}
+    for index, shard in shards:
+        indexed[index] = shard
+    del shards
 
-    while file.tell() < len(data):
-        for index, shard in pickle.load(file):
-            shards[index] = shard
-
-    subshape = [max(dim) + 1 for dim in zip(*shards.keys())]
-    assert len(shards) == np.prod(subshape)
+    subshape = [max(dim) + 1 for dim in zip(*indexed.keys())]
+    assert len(indexed) == np.prod(subshape)
 
     rec_cat_arg = np.empty(subshape, dtype="O")
-    for index, shard in shards.items():
+    for index, shard in indexed.items():
         rec_cat_arg[tuple(index)] = shard
-    del data
-    del file
     arrs = rec_cat_arg.tolist()
     return concatenate3(arrs)
 
@@ -427,8 +423,20 @@ def _() -> dict[str, tuple[NDIndex, bytes]]:
     async def _get_output_partition(
         self, partition_id: NDIndex, key: str, **kwargs: Any
     ) -> np.ndarray:
-        data = self._read_from_disk(partition_id)
-        return await self.offload(convert_chunk, data)
+        def _(partition_id: NDIndex) -> np.ndarray:
+            data = self._read_from_disk(partition_id)
+            return convert_chunk(data)
+
+        return await self.offload(_, partition_id)
+
+    def read(self, path: Path) -> tuple[Any, int]:
+        shards: list[tuple[NDIndex, np.ndarray]] = []
+        with path.open(mode="rb") as f:
+            size = f.seek(0, os.SEEK_END)
+            f.seek(0)
+            while f.tell() < size:
+                shards.extend(pickle.load(f))
+        return shards, size
 
     def _get_assigned_worker(self, id: NDIndex) -> str:
         return self.worker_for[id]
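
A toy round trip (not part of the commit) for the framing that ArrayRechunkRun.read assumes above: every spill appends one pickled list of (index, shard) pairs to the partition file, so reading it back is just repeated pickle.load calls until the file offset reaches the end.

import os
import pickle
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "partition_0_0")  # hypothetical partition file name

    # Writer side: three separate spills appended to the same file.
    with open(path, "ab") as f:
        for i in range(3):
            pickle.dump([((i, 0), np.full((2, 2), i))], f)

    # Reader side: mirrors ArrayRechunkRun.read.
    shards: list[tuple[tuple[int, int], np.ndarray]] = []
    with open(path, "rb") as f:
        size = f.seek(0, os.SEEK_END)
        f.seek(0)
        while f.tell() < size:
            shards.extend(pickle.load(f))

print(len(shards), shards[0][0])  # 3 (0, 0)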

distributed/shuffle/_shuffle.py

Lines changed: 17 additions & 4 deletions
@@ -7,6 +7,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from functools import partial
+from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
 import toolz
@@ -20,8 +21,9 @@
 from distributed.shuffle._arrow import (
     check_dtype_support,
     check_minimal_arrow_version,
-    convert_partition,
+    convert_shards,
     list_of_buffers_to_table,
+    read_from_disk,
     serialize_table,
 )
 from distributed.shuffle._core import (
@@ -321,7 +323,7 @@ def split_by_worker(
     return out
 
 
-def split_by_partition(t: pa.Table, column: str) -> dict[Any, pa.Table]:
+def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]:
     """
     Split data into many arrow batches, partitioned by final partition
     """
@@ -383,6 +385,11 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
     buffer.
     """
 
+    column: str
+    meta: pd.DataFrame
+    partitions_of: dict[str, list[int]]
+    worker_for: pd.Series
+
     def __init__(
         self,
         worker_for: dict[int, str],
@@ -476,16 +483,22 @@ async def _get_output_partition(
         **kwargs: Any,
     ) -> pd.DataFrame:
         try:
-            data = self._read_from_disk((partition_id,))
 
-            out = await self.offload(convert_partition, data, self.meta)
+            def _(partition_id: int, meta: pd.DataFrame) -> pd.DataFrame:
+                data = self._read_from_disk((partition_id,))
+                return convert_shards(data, meta)
+
+            out = await self.offload(_, partition_id, self.meta)
         except KeyError:
             out = self.meta.copy()
         return out
 
     def _get_assigned_worker(self, id: int) -> str:
         return self.worker_for[id]
 
+    def read(self, path: Path) -> tuple[Any, int]:
+        return read_from_disk(path, self.meta)
+
 
 @dataclass(frozen=True)
 class DataFrameShuffleSpec(ShuffleSpec[int]):
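
Finally, a toy illustration (not taken from the commit) of the offload change in _get_output_partition above: rather than reading the spilled partition on the event loop and shipping the result to a worker thread, the blocking read and the conversion are both wrapped in the function handed to the thread pool, so the materialized shards never sit on the event loop.

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)


async def offload(func, *args):
    # Stand-in for ShuffleRun.offload: run blocking work on a thread.
    return await asyncio.get_running_loop().run_in_executor(executor, func, *args)


def read_partition(partition_id: int) -> list[int]:
    # Stand-in for the blocking disk read plus convert_shards step.
    return [partition_id] * 3


async def get_output_partition(partition_id: int) -> list[int]:
    def _(partition_id: int) -> list[int]:
        shards = read_partition(partition_id)  # blocking I/O stays on the worker thread
        return sorted(shards)  # CPU-bound conversion happens there too

    return await offload(_, partition_id)


print(asyncio.run(get_output_partition(7)))  # [7, 7, 7]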
