
Use the max size of serialized examples to find a safe number of shards #11005

Open
wants to merge 1 commit into master
@@ -406,6 +406,7 @@ def _compute_shard_specs(
# HF split size is good enough for estimating the number of shards.
num_shards = shard_utils.ShardConfig.calculate_number_shards(
total_size=hf_split_info.num_bytes,
max_example_size=None,
num_examples=hf_split_info.num_examples,
uses_precise_sharding=False,
)
1 change: 1 addition & 0 deletions tensorflow_datasets/core/reader_test.py
@@ -97,6 +97,7 @@ def _write_tfrecord(self, split_name, shards_number, records):
shard_specs = writer_lib._get_shard_specs(
num_examples=num_examples,
total_size=0,
max_example_size=None,
bucket_lengths=[num_examples],
filename_template=filename_template,
shard_config=shard_utils.ShardConfig(num_shards=shards_number),
6 changes: 3 additions & 3 deletions tensorflow_datasets/core/shuffle.py
@@ -244,7 +244,7 @@ def __init__(
self._total_bytes = 0
# To keep data in memory until enough data has been gathered.
self._in_memory = True
self._mem_buffer = []
self._mem_buffer: list[type_utils.KeySerializedExample] = []
self._seen_keys: set[int] = set()
self._num_examples = 0

@@ -272,10 +272,10 @@ def _add_to_mem_buffer(self, hkey: int, data: bytes) -> None:
if self._total_bytes > MAX_MEM_BUFFER_SIZE:
for hkey, data in self._mem_buffer:
self._add_to_bucket(hkey, data)
self._mem_buffer = None
self._mem_buffer = []
self._in_memory = False

def add(self, key: type_utils.Key, data: bytes) -> bool:
def add(self, key: type_utils.Key, data: bytes) -> None:
"""Add (key, data) to shuffler."""
if self._read_only:
raise AssertionError('add() cannot be called after __iter__.')
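The shuffle.py changes are small but worth unpacking: the in-memory buffer is now reset to an empty list rather than None when it spills to the on-disk buckets, so the attribute keeps a stable list type, and add() no longer claims to return bool. Below is a minimal sketch of that spill pattern, not the TFDS implementation: the TinyShuffler class, its in-memory bucket list, and the threshold value are invented for illustration; only the role of MAX_MEM_BUFFER_SIZE mirrors the real code.

```python
MAX_MEM_BUFFER_SIZE = 1_000_000  # assumed threshold in bytes; the real value differs


class TinyShuffler:
  """Toy model of the spill-to-bucket behavior; not the TFDS class."""

  def __init__(self) -> None:
    self._total_bytes = 0
    self._in_memory = True
    self._mem_buffer: list[tuple[int, bytes]] = []
    self._buckets: list[tuple[int, bytes]] = []  # stand-in for disk buckets

  def _add_to_bucket(self, hkey: int, data: bytes) -> None:
    self._buckets.append((hkey, data))  # the real code writes to disk files

  def add(self, hkey: int, data: bytes) -> None:
    self._total_bytes += len(data)
    if not self._in_memory:
      self._add_to_bucket(hkey, data)
      return
    self._mem_buffer.append((hkey, data))
    if self._total_bytes > MAX_MEM_BUFFER_SIZE:
      # Spill everything gathered so far, then keep writing to buckets.
      for k, d in self._mem_buffer:
        self._add_to_bucket(k, d)
      # Reset to [] rather than None so the attribute keeps its list type.
      self._mem_buffer = []
      self._in_memory = False
```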
37 changes: 32 additions & 5 deletions tensorflow_datasets/core/utils/shard_utils.py
@@ -57,27 +57,47 @@ class ShardConfig:
def calculate_number_shards(
cls,
total_size: int,
max_example_size: int | Sequence[int] | None,
num_examples: int,
uses_precise_sharding: bool = True,
) -> int:
"""Returns number of shards for num_examples of total_size in bytes.

Args:
total_size: the size of the data (serialized, not couting any overhead).
total_size: the size of the data (serialized, not counting any overhead).
max_example_size: the maximum size of a single example (serialized, not
counting any overhead).
num_examples: the number of records in the data.
uses_precise_sharding: whether a mechanism is used to exactly control how
many examples go in each shard.
"""
total_size += num_examples * cls.overhead
max_shards_number = total_size // cls.min_shard_size
total_overhead = num_examples * cls.overhead
total_size_with_overhead = total_size + total_overhead
if uses_precise_sharding:
max_shard_size = cls.max_shard_size
else:
# When the pipeline does not control exactly how many rows go into each
# shard (called 'precise sharding' here), we use a smaller max shard size
# so that the pipeline doesn't fail if a shard gets some more examples.
max_shard_size = 0.9 * cls.max_shard_size
min_shards_number = total_size // max_shard_size
max_shard_size = max(1, max_shard_size)

if max_example_size is None:
min_shards_number = max(1, total_size_with_overhead // max_shard_size)
max_shards_number = max(1, total_size_with_overhead // cls.min_shard_size)
else:
if isinstance(max_example_size, Sequence):
if len(max_example_size) == 1:
max_example_size = max_example_size[0]
else:
raise ValueError(
'max_example_size must be a single value or None, got'
f' {max_example_size}'
)
pessimistic_total_size = num_examples * (max_example_size + cls.overhead)
min_shards_number = max(1, pessimistic_total_size // max_shard_size)
max_shards_number = max(1, pessimistic_total_size // cls.min_shard_size)

if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024:
return 1024
elif min_shards_number > 1024:
@@ -96,15 +116,22 @@ def calculate_number_shards(
def get_number_shards(
self,
total_size: int,
max_example_size: int | None,
num_examples: int,
uses_precise_sharding: bool = True,
) -> int:
if self.num_shards:
return self.num_shards
return self.calculate_number_shards(
total_size, num_examples, uses_precise_sharding
total_size=total_size,
max_example_size=max_example_size,
num_examples=num_examples,
uses_precise_sharding=uses_precise_sharding,
)

def replace(self, **kwargs: Any) -> ShardConfig:
return dataclasses.replace(self, **kwargs)


def get_shard_boundaries(
num_examples: int,
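The heart of the change is the pessimistic branch: when the largest serialized example is known, the shard count is derived from num_examples * (max_example_size + overhead) instead of the actual total, so no shard can overflow even if the large examples cluster in one shard. The sketch below re-derives the estimate as a standalone function. It assumes TFDS-default constants (16-byte per-example overhead, 64 MiB minimum and 1 GiB maximum shard size) and simplifies the final rounding, which is only partly visible in this hunk, to "round up to the next multiple of 1024"; under those assumptions it reproduces the expected values in the tests below.

```python
# Standalone sketch of the new estimate; the constants and the rounding rule
# are assumptions (TFDS defaults at the time of writing), not a verbatim copy.
OVERHEAD = 16                # assumed per-example record overhead, bytes
MIN_SHARD_SIZE = 64 << 20    # assumed 64 MiB
MAX_SHARD_SIZE = 1024 << 20  # assumed 1 GiB


def estimate_num_shards(
    total_size: int,
    num_examples: int,
    max_example_size: int | None = None,
    uses_precise_sharding: bool = True,
) -> int:
  # Without precise sharding, leave 10% headroom per shard.
  max_shard = MAX_SHARD_SIZE if uses_precise_sharding else 0.9 * MAX_SHARD_SIZE
  max_shard = max(1, max_shard)
  if max_example_size is None:
    # Optimistic path: size the shards by the actual total.
    size = total_size + num_examples * OVERHEAD
  else:
    # Pessimistic path: pretend every example is as large as the biggest one.
    size = num_examples * (max_example_size + OVERHEAD)
  min_shards = max(1, int(size // max_shard))
  max_shards = max(1, int(size // MIN_SHARD_SIZE))
  if min_shards <= 1024 <= max_shards and num_examples >= 1024:
    return 1024
  if min_shards > 1024:
    # Assumed rounding rule: next multiple of 1024.
    return -(-min_shards // 1024) * 1024
  return min_shards  # simplified; the real code picks a power of two here


# 1B examples averaging 1000 bytes: evenly sized data still lands on 1024.
assert estimate_num_shards(10**12, 10**9, max_example_size=1000) == 1024
```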
100 changes: 91 additions & 9 deletions tensorflow_datasets/core/utils/shard_utils_test.py
@@ -22,23 +22,102 @@
class ShardConfigTest(parameterized.TestCase):

@parameterized.named_parameters(
('imagenet train, 137 GiB', 137 << 30, 1281167, True, 1024),
('imagenet evaluation, 6.3 GiB', 6300 * (1 << 20), 50000, True, 64),
('very large, but few examples, 52 GiB', 52 << 30, 512, True, 512),
('xxl, 10 TiB', 10 << 40, 10**9, True, 11264),
('xxl, 10 PiB, 100B examples', 10 << 50, 10**11, True, 10487808),
('xs, 100 MiB, 100K records', 10 << 20, 100 * 10**3, True, 1),
('m, 499 MiB, 200K examples', 400 << 20, 200 * 10**3, True, 4),
dict(
testcase_name='imagenet train, 137 GiB',
total_size=137 << 30,
num_examples=1281167,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=1024,
),
dict(
testcase_name='imagenet evaluation, 6.3 GiB',
total_size=6300 * (1 << 20),
num_examples=50000,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=64,
),
dict(
testcase_name='very large, but few examples, 52 GiB',
total_size=52 << 30,
num_examples=512,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=512,
),
dict(
testcase_name='xxl, 10 TiB',
total_size=10 << 40,
num_examples=10**9,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=11264,
),
dict(
testcase_name='xxl, 10 PiB, 100B examples',
total_size=10 << 50,
num_examples=10**11,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=10487808,
),
dict(
testcase_name='xs, 100 MiB, 100K records',
total_size=10 << 20,
num_examples=100 * 10**3,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=1,
),
dict(
testcase_name='m, 499 MiB, 200K examples',
total_size=400 << 20,
num_examples=200 * 10**3,
uses_precise_sharding=True,
max_size=None,
expected_num_shards=4,
),
dict(
testcase_name='1TB, even example sizes',
num_examples=1e9, # 1B examples
total_size=1e9 * 1000, # On average 1000 bytes per example
max_size=1000, # Max example size equals the 1000-byte average
uses_precise_sharding=True,
expected_num_shards=1024,
),
dict(
testcase_name='1TB, uneven example sizes',
num_examples=1e9, # 1B examples
total_size=1e9 * 1000, # On average 1000 bytes per example
max_size=4 * 1000, # Max example size is 4000 bytes
uses_precise_sharding=True,
expected_num_shards=4096,
),
dict(
testcase_name='1TB, very uneven example sizes',
num_examples=1e9, # 1B examples
total_size=1e9 * 1000, # On average 1000 bytes per example
max_size=16 * 1000, # Max example size is 16000 bytes, 16x the average
uses_precise_sharding=True,
expected_num_shards=15360,
),
)
def test_get_number_shards_default_config(
self, total_size, num_examples, uses_precise_sharding, expected_num_shards
self,
total_size: int,
num_examples: int,
uses_precise_sharding: bool,
max_size: int | None,
expected_num_shards: int,
):
shard_config = shard_utils.ShardConfig()
self.assertEqual(
expected_num_shards,
shard_config.get_number_shards(
total_size=total_size,
num_examples=num_examples,
max_example_size=max_size,
uses_precise_sharding=uses_precise_sharding,
),
)
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
self.assertEqual(
42,
shard_config.get_number_shards(
total_size=100, num_examples=1, uses_precise_sharding=True
total_size=100,
max_example_size=100,
num_examples=1,
uses_precise_sharding=True,
),
)

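As a sanity check on the new expectations, the uneven cases follow directly from the pessimistic formula under the same assumed defaults as the sketch above (16-byte overhead, 1 GiB max shard size), with the minimum shard count rounded up to the next multiple of 1024:

```python
GiB = 1 << 30
# 'uneven': 1e9 examples, each counted at 4000 + 16 bytes.
print(10**9 * 4016 // GiB)   # 3740 -> next multiple of 1024 is 4096
# 'very uneven': 1e9 examples, each counted at 16000 + 16 bytes.
print(10**9 * 16016 // GiB)  # 14916 -> next multiple of 1024 is 15360
```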
49 changes: 40 additions & 9 deletions tensorflow_datasets/core/writer.py
@@ -116,20 +116,26 @@ def _get_index_path(path: str) -> epath.PathLike:
def _get_shard_specs(
num_examples: int,
total_size: int,
max_example_size: int | None,
bucket_lengths: Sequence[int],
filename_template: naming.ShardedFileTemplate,
shard_config: shard_utils.ShardConfig,
) -> Sequence[_ShardSpec]:
"""Returns list of _ShardSpec instances, corresponding to shards to write.

Args:
num_examples: int, number of examples in split.
total_size: int (bytes), sum of example sizes.
num_examples: number of examples in split.
total_size: total size in bytes, i.e., the sum of example sizes.
max_example_size: maximum size in bytes of a single example.
bucket_lengths: list of ints, number of examples in each bucket.
filename_template: template to format sharded filenames.
shard_config: the configuration for creating shards.
"""
num_shards = shard_config.get_number_shards(total_size, num_examples)
num_shards = shard_config.get_number_shards(
total_size=total_size,
max_example_size=max_example_size,
num_examples=num_examples,
)
shard_boundaries = shard_utils.get_shard_boundaries(num_examples, num_shards)
shard_specs = []
bucket_indexes = [str(i) for i in range(len(bucket_lengths))]
@@ -350,6 +356,7 @@ def __init__(
self._filename_template = filename_template
self._shard_config = shard_config or shard_utils.ShardConfig()
self._example_writer = example_writer
self._max_example_size = 0

def write(self, key: int | bytes, example: Example):
"""Writes given example.
@@ -363,6 +370,9 @@ def write(self, key: int | bytes, example: Example):
"""
serialized_example = self._serializer.serialize_example(example=example)
self._shuffler.add(key, serialized_example)
self._max_example_size = max(
self._max_example_size, len(serialized_example)
)

def finalize(self) -> tuple[list[int], int]:
"""Effectively writes examples to the shards."""
@@ -372,6 +382,7 @@ def finalize(self) -> tuple[list[int], int]:
shard_specs = _get_shard_specs(
num_examples=self._shuffler.num_examples,
total_size=self._shuffler.size,
max_example_size=self._max_example_size,
bucket_lengths=self._shuffler.bucket_lengths,
filename_template=self._filename_template,
shard_config=self._shard_config,
@@ -589,10 +600,13 @@ def _write_final_shard(
id=shard_id, num_examples=len(example_by_key), size=shard_size
)

def _number_of_shards(self, num_examples: int, total_size: int) -> int:
def _number_of_shards(
self, num_examples: int, total_size: int, max_example_size: int
) -> int:
"""Returns the number of shards."""
num_shards = self._shard_config.get_number_shards(
total_size=total_size,
max_example_size=max_example_size,
num_examples=num_examples,
uses_precise_sharding=False,
)
@@ -658,16 +672,26 @@ def write_from_pcollection(self, examples_pcollection):
| "CountExamples" >> beam.combiners.Count.Globally()
| "CheckValidNumExamples" >> beam.Map(self._check_num_examples)
)
serialized_example_sizes = (
serialized_examples | beam.Values() | beam.Map(len)
)
total_size = beam.pvalue.AsSingleton(
serialized_examples
| beam.Values()
| beam.Map(len)
| "TotalSize" >> beam.CombineGlobally(sum)
serialized_example_sizes | "TotalSize" >> beam.CombineGlobally(sum)
)

max_example_size = beam.pvalue.AsSingleton(
serialized_example_sizes
| "TopExampleSize" >> beam.combiners.Top.Largest(1)
| "MaxExampleSize" >> beam.CombineGlobally(_get_max_size)
)
ideal_num_shards = beam.pvalue.AsSingleton(
num_examples
| "NumberOfShards"
>> beam.Map(self._number_of_shards, total_size=total_size)
>> beam.Map(
self._number_of_shards,
total_size=total_size,
max_example_size=max_example_size,
)
)

examples_per_shard = (
@@ -826,3 +850,10 @@ def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]:
)

return shard_lengths, total_size_bytes


def _get_max_size(sizes: Iterable[int]) -> int | None:
sizes = list(sizes)
if not sizes:
return None
return max(sizes)
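One subtlety in the pipeline: Top.Largest(1) emits a single one-element list, and CombineGlobally(_get_max_size) then reduces a PCollection whose only element is that list, so the side input that reaches _number_of_shards is a length-1 sequence such as [4000] rather than a bare int. That is what the Sequence unwrapping in calculate_number_shards is for. A runnable sketch of this behavior, assuming apache_beam is installed (the three sizes are invented):

```python
from collections.abc import Iterable

import apache_beam as beam


def _get_max_size(sizes: Iterable[int]):
  # Same body as in writer.py; at the CombineGlobally stage it actually
  # receives lists, so the result stays a one-element list.
  sizes = list(sizes)
  if not sizes:
    return None
  return max(sizes)


with beam.Pipeline() as p:
  sizes = p | beam.Create([120, 4000, 950])  # serialized example sizes, bytes
  max_size = (
      sizes
      | 'Top' >> beam.combiners.Top.Largest(1)        # one element: [4000]
      | 'Max' >> beam.CombineGlobally(_get_max_size)  # still [4000], a list
  )
  max_size | beam.Map(print)  # prints [4000]
```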
2 changes: 2 additions & 0 deletions tensorflow_datasets/core/writer_test.py
@@ -48,6 +48,7 @@ def test_1bucket_6shards(self):
filetype_suffix='tfrecord',
),
shard_config=shard_utils.ShardConfig(num_shards=6),
max_example_size=2,
)
self.assertEqual(
specs,
@@ -134,6 +135,7 @@ def test_4buckets_2shards(self):
filetype_suffix='tfrecord',
),
shard_config=shard_utils.ShardConfig(num_shards=2),
max_example_size=2,
)
self.assertEqual(
specs,