diff --git a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
index 6be1af3d0ed..9fad62f5aa2 100644
--- a/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
+++ b/tensorflow_datasets/core/dataset_builders/huggingface_dataset_builder.py
@@ -406,6 +406,7 @@ def _compute_shard_specs(
   # HF split size is good enough for estimating the number of shards.
   num_shards = shard_utils.ShardConfig.calculate_number_shards(
       total_size=hf_split_info.num_bytes,
+      max_example_size=None,
       num_examples=hf_split_info.num_examples,
       uses_precise_sharding=False,
   )
diff --git a/tensorflow_datasets/core/reader_test.py b/tensorflow_datasets/core/reader_test.py
index e478e26bfd7..fdc88ad42f7 100644
--- a/tensorflow_datasets/core/reader_test.py
+++ b/tensorflow_datasets/core/reader_test.py
@@ -97,6 +97,7 @@ def _write_tfrecord(self, split_name, shards_number, records):
     shard_specs = writer_lib._get_shard_specs(
         num_examples=num_examples,
         total_size=0,
+        max_example_size=None,
         bucket_lengths=[num_examples],
         filename_template=filename_template,
         shard_config=shard_utils.ShardConfig(num_shards=shards_number),
diff --git a/tensorflow_datasets/core/shuffle.py b/tensorflow_datasets/core/shuffle.py
index 2df14df0474..a54adbfa586 100644
--- a/tensorflow_datasets/core/shuffle.py
+++ b/tensorflow_datasets/core/shuffle.py
@@ -244,7 +244,7 @@ def __init__(
     self._total_bytes = 0
     # To keep data in memory until enough data has been gathered.
     self._in_memory = True
-    self._mem_buffer = []
+    self._mem_buffer: list[type_utils.KeySerializedExample] = []
     self._seen_keys: set[int] = set()
     self._num_examples = 0
 
@@ -272,10 +272,10 @@ def _add_to_mem_buffer(self, hkey: int, data: bytes) -> None:
     if self._total_bytes > MAX_MEM_BUFFER_SIZE:
       for hkey, data in self._mem_buffer:
         self._add_to_bucket(hkey, data)
-      self._mem_buffer = None
+      self._mem_buffer = []
       self._in_memory = False
 
-  def add(self, key: type_utils.Key, data: bytes) -> bool:
+  def add(self, key: type_utils.Key, data: bytes) -> None:
     """Add (key, data) to shuffler."""
     if self._read_only:
       raise AssertionError('add() cannot be called after __iter__.')
diff --git a/tensorflow_datasets/core/utils/shard_utils.py b/tensorflow_datasets/core/utils/shard_utils.py
index a5a700608e2..db018e8a2e5 100644
--- a/tensorflow_datasets/core/utils/shard_utils.py
+++ b/tensorflow_datasets/core/utils/shard_utils.py
@@ -57,19 +57,22 @@ class ShardConfig:
   def calculate_number_shards(
       cls,
       total_size: int,
+      max_example_size: int | Sequence[int] | None,
       num_examples: int,
       uses_precise_sharding: bool = True,
   ) -> int:
     """Returns number of shards for num_examples of total_size in bytes.
 
     Args:
-      total_size: the size of the data (serialized, not couting any overhead).
+      total_size: the size of the data (serialized, not counting any overhead).
+      max_example_size: the maximum size of a single example (serialized, not
+        counting any overhead).
       num_examples: the number of records in the data.
       uses_precise_sharding: whether a mechanism is used to exactly control
        how many examples go in each shard.
""" - total_size += num_examples * cls.overhead - max_shards_number = total_size // cls.min_shard_size + total_overhead = num_examples * cls.overhead + total_size_with_overhead = total_size + total_overhead if uses_precise_sharding: max_shard_size = cls.max_shard_size else: @@ -77,7 +80,24 @@ def calculate_number_shards( # shard (called 'precise sharding' here), we use a smaller max shard size # so that the pipeline doesn't fail if a shard gets some more examples. max_shard_size = 0.9 * cls.max_shard_size - min_shards_number = total_size // max_shard_size + max_shard_size = max(1, max_shard_size) + + if max_example_size is None: + min_shards_number = max(1, total_size_with_overhead // max_shard_size) + max_shards_number = max(1, total_size_with_overhead // cls.min_shard_size) + else: + if isinstance(max_example_size, Sequence): + if len(max_example_size) == 1: + max_example_size = max_example_size[0] + else: + raise ValueError( + 'max_example_size must be a single value or None, got' + f' {max_example_size}' + ) + pessimistic_total_size = num_examples * (max_example_size + cls.overhead) + min_shards_number = max(1, pessimistic_total_size // max_shard_size) + max_shards_number = max(1, pessimistic_total_size // cls.min_shard_size) + if min_shards_number <= 1024 <= max_shards_number and num_examples >= 1024: return 1024 elif min_shards_number > 1024: @@ -96,15 +116,22 @@ def calculate_number_shards( def get_number_shards( self, total_size: int, + max_example_size: int | None, num_examples: int, uses_precise_sharding: bool = True, ) -> int: if self.num_shards: return self.num_shards return self.calculate_number_shards( - total_size, num_examples, uses_precise_sharding + total_size=total_size, + max_example_size=max_example_size, + num_examples=num_examples, + uses_precise_sharding=uses_precise_sharding, ) + def replace(self, **kwargs: Any) -> ShardConfig: + return dataclasses.replace(self, **kwargs) + def get_shard_boundaries( num_examples: int, diff --git a/tensorflow_datasets/core/utils/shard_utils_test.py b/tensorflow_datasets/core/utils/shard_utils_test.py index 1882b178e37..325d1d15cf1 100644 --- a/tensorflow_datasets/core/utils/shard_utils_test.py +++ b/tensorflow_datasets/core/utils/shard_utils_test.py @@ -22,16 +22,94 @@ class ShardConfigTest(parameterized.TestCase): @parameterized.named_parameters( - ('imagenet train, 137 GiB', 137 << 30, 1281167, True, 1024), - ('imagenet evaluation, 6.3 GiB', 6300 * (1 << 20), 50000, True, 64), - ('very large, but few examples, 52 GiB', 52 << 30, 512, True, 512), - ('xxl, 10 TiB', 10 << 40, 10**9, True, 11264), - ('xxl, 10 PiB, 100B examples', 10 << 50, 10**11, True, 10487808), - ('xs, 100 MiB, 100K records', 10 << 20, 100 * 10**3, True, 1), - ('m, 499 MiB, 200K examples', 400 << 20, 200 * 10**3, True, 4), + dict( + testcase_name='imagenet train, 137 GiB', + total_size=137 << 30, + num_examples=1281167, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=1024, + ), + dict( + testcase_name='imagenet evaluation, 6.3 GiB', + total_size=6300 * (1 << 20), + num_examples=50000, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=64, + ), + dict( + testcase_name='very large, but few examples, 52 GiB', + total_size=52 << 30, + num_examples=512, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=512, + ), + dict( + testcase_name='xxl, 10 TiB', + total_size=10 << 40, + num_examples=10**9, + uses_precise_sharding=True, + max_size=None, + expected_num_shards=11264, + ), + dict( + testcase_name='xxl, 10 PiB, 
+          testcase_name='xxl, 10 PiB, 100B examples',
+          total_size=10 << 50,
+          num_examples=10**11,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=10487808,
+      ),
+      dict(
+          testcase_name='xs, 100 MiB, 100K records',
+          total_size=10 << 20,
+          num_examples=100 * 10**3,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=1,
+      ),
+      dict(
+          testcase_name='m, 499 MiB, 200K examples',
+          total_size=400 << 20,
+          num_examples=200 * 10**3,
+          uses_precise_sharding=True,
+          max_size=None,
+          expected_num_shards=4,
+      ),
+      dict(
+          testcase_name='1TB, even example sizes',
+          num_examples=10**9,  # 1B examples
+          total_size=10**9 * 1000,  # On average 1000 bytes per example
+          max_size=1000,  # Max example size equals the average
+          uses_precise_sharding=True,
+          expected_num_shards=1024,
+      ),
+      dict(
+          testcase_name='1TB, uneven example sizes',
+          num_examples=10**9,  # 1B examples
+          total_size=10**9 * 1000,  # On average 1000 bytes per example
+          max_size=4 * 1000,  # Max example size is 4x the average
+          uses_precise_sharding=True,
+          expected_num_shards=4096,
+      ),
+      dict(
+          testcase_name='1TB, very uneven example sizes',
+          num_examples=10**9,  # 1B examples
+          total_size=10**9 * 1000,  # On average 1000 bytes per example
+          max_size=16 * 1000,  # Max example size is 16x the average
+          uses_precise_sharding=True,
+          expected_num_shards=15360,
+      ),
   )
   def test_get_number_shards_default_config(
-      self, total_size, num_examples, uses_precise_sharding, expected_num_shards
+      self,
+      total_size: int,
+      num_examples: int,
+      uses_precise_sharding: bool,
+      max_size: int | None,
+      expected_num_shards: int,
   ):
     shard_config = shard_utils.ShardConfig()
     self.assertEqual(
@@ -39,6 +117,7 @@ def test_get_number_shards_default_config(
         shard_config.get_number_shards(
             total_size=total_size,
             num_examples=num_examples,
+            max_example_size=max_size,
             uses_precise_sharding=uses_precise_sharding,
         ),
     )
@@ -48,7 +127,10 @@ def test_get_number_shards_if_specified(self):
     self.assertEqual(
         42,
         shard_config.get_number_shards(
-            total_size=100, num_examples=1, uses_precise_sharding=True
+            total_size=100,
+            max_example_size=100,
+            num_examples=1,
+            uses_precise_sharding=True,
         ),
     )
 
diff --git a/tensorflow_datasets/core/writer.py b/tensorflow_datasets/core/writer.py
index aa756bbcfb5..abd461717ac 100644
--- a/tensorflow_datasets/core/writer.py
+++ b/tensorflow_datasets/core/writer.py
@@ -116,6 +116,7 @@ def _get_index_path(path: str) -> epath.PathLike:
 def _get_shard_specs(
     num_examples: int,
     total_size: int,
+    max_example_size: int | None,
     bucket_lengths: Sequence[int],
     filename_template: naming.ShardedFileTemplate,
     shard_config: shard_utils.ShardConfig,
@@ -123,13 +124,18 @@ def _get_shard_specs(
   """Returns list of _ShardSpec instances, corresponding to shards to write.
 
   Args:
-    num_examples: int, number of examples in split.
-    total_size: int (bytes), sum of example sizes.
+    num_examples: number of examples in split.
+    total_size: total size in bytes, i.e., the sum of example sizes.
+    max_example_size: maximum size in bytes of a single example.
     bucket_lengths: list of ints, number of examples in each bucket.
     filename_template: template to format sharded filenames.
     shard_config: the configuration for creating shards.
""" - num_shards = shard_config.get_number_shards(total_size, num_examples) + num_shards = shard_config.get_number_shards( + total_size=total_size, + max_example_size=max_example_size, + num_examples=num_examples, + ) shard_boundaries = shard_utils.get_shard_boundaries(num_examples, num_shards) shard_specs = [] bucket_indexes = [str(i) for i in range(len(bucket_lengths))] @@ -350,6 +356,7 @@ def __init__( self._filename_template = filename_template self._shard_config = shard_config or shard_utils.ShardConfig() self._example_writer = example_writer + self._max_example_size = 0 def write(self, key: int | bytes, example: Example): """Writes given example. @@ -363,6 +370,9 @@ def write(self, key: int | bytes, example: Example): """ serialized_example = self._serializer.serialize_example(example=example) self._shuffler.add(key, serialized_example) + self._max_example_size = max( + self._max_example_size, len(serialized_example) + ) def finalize(self) -> tuple[list[int], int]: """Effectively writes examples to the shards.""" @@ -372,6 +382,7 @@ def finalize(self) -> tuple[list[int], int]: shard_specs = _get_shard_specs( num_examples=self._shuffler.num_examples, total_size=self._shuffler.size, + max_example_size=self._max_example_size, bucket_lengths=self._shuffler.bucket_lengths, filename_template=self._filename_template, shard_config=self._shard_config, @@ -589,10 +600,13 @@ def _write_final_shard( id=shard_id, num_examples=len(example_by_key), size=shard_size ) - def _number_of_shards(self, num_examples: int, total_size: int) -> int: + def _number_of_shards( + self, num_examples: int, total_size: int, max_example_size: int + ) -> int: """Returns the number of shards.""" num_shards = self._shard_config.get_number_shards( total_size=total_size, + max_example_size=max_example_size, num_examples=num_examples, uses_precise_sharding=False, ) @@ -658,16 +672,26 @@ def write_from_pcollection(self, examples_pcollection): | "CountExamples" >> beam.combiners.Count.Globally() | "CheckValidNumExamples" >> beam.Map(self._check_num_examples) ) + serialized_example_sizes = ( + serialized_examples | beam.Values() | beam.Map(len) + ) total_size = beam.pvalue.AsSingleton( - serialized_examples - | beam.Values() - | beam.Map(len) - | "TotalSize" >> beam.CombineGlobally(sum) + serialized_example_sizes | "TotalSize" >> beam.CombineGlobally(sum) + ) + + max_example_size = beam.pvalue.AsSingleton( + serialized_example_sizes + | "TopExampleSize" >> beam.combiners.Top.Largest(1) + | "MaxExampleSize" >> beam.CombineGlobally(_get_max_size) ) ideal_num_shards = beam.pvalue.AsSingleton( num_examples | "NumberOfShards" - >> beam.Map(self._number_of_shards, total_size=total_size) + >> beam.Map( + self._number_of_shards, + total_size=total_size, + max_example_size=max_example_size, + ) ) examples_per_shard = ( @@ -826,3 +850,10 @@ def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]: ) return shard_lengths, total_size_bytes + + +def _get_max_size(sizes: Iterable[int]) -> int | None: + sizes = list(sizes) + if not sizes: + return None + return max(sizes) diff --git a/tensorflow_datasets/core/writer_test.py b/tensorflow_datasets/core/writer_test.py index c0b5fd30068..79b36124636 100644 --- a/tensorflow_datasets/core/writer_test.py +++ b/tensorflow_datasets/core/writer_test.py @@ -48,6 +48,7 @@ def test_1bucket_6shards(self): filetype_suffix='tfrecord', ), shard_config=shard_utils.ShardConfig(num_shards=6), + max_example_size=2, ) self.assertEqual( specs, @@ -134,6 +135,7 @@ def test_4buckets_2shards(self): 
             filetype_suffix='tfrecord',
         ),
         shard_config=shard_utils.ShardConfig(num_shards=2),
+        max_example_size=2,
     )
     self.assertEqual(
         specs,
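
Reviewer note: the sketch below is not part of the patch. It is a rough, self-contained illustration of the pessimistic shard-count estimate that calculate_number_shards now applies when max_example_size is known. The constants OVERHEAD, MIN_SHARD_SIZE and MAX_SHARD_SIZE and the helper estimate_shard_bounds are hypothetical placeholders for illustration only; the real defaults, clamping, and the rounding to the final shard count live in shard_utils.ShardConfig.

from __future__ import annotations

# Illustrative sketch only (assumed constants, hypothetical helper).
OVERHEAD = 16              # assumed per-example overhead, in bytes
MIN_SHARD_SIZE = 64 << 20  # assumed smallest useful shard, 64 MiB
MAX_SHARD_SIZE = 1 << 30   # assumed largest allowed shard, 1 GiB


def estimate_shard_bounds(
    total_size: int,
    num_examples: int,
    max_example_size: int | None = None,
) -> tuple[int, int]:
  """Returns (min_shards, max_shards) before any rounding is applied."""
  if max_example_size is None:
    # Size shards from the actual total, i.e. the average example size.
    sized_total = total_size + num_examples * OVERHEAD
  else:
    # Pessimistic estimate: pretend every example is as large as the largest
    # one, so no shard can exceed MAX_SHARD_SIZE however examples are split.
    sized_total = num_examples * (max_example_size + OVERHEAD)
  return (
      max(1, sized_total // MAX_SHARD_SIZE),
      max(1, sized_total // MIN_SHARD_SIZE),
  )


# 1B examples, 1000 bytes on average, largest example 4000 bytes: the
# worst-case estimate roughly quadruples the lower bound on the shard count,
# which is what the '1TB, uneven example sizes' test case above exercises.
print(estimate_shard_bounds(10**9 * 1000, 10**9))            # average-based
print(estimate_shard_bounds(10**9 * 1000, 10**9, 4 * 1000))  # worst-case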